// mcp_proxy/cache.rs

1//! Response caching middleware for the proxy.
2//!
3//! Caches `ReadResource` and `CallTool` responses with per-backend TTL.
4//! Cache keys are derived from the request type, name/URI, and arguments.
5//!
6//! # Cache Backends
7//!
8//! The cache backend is configurable via `[cache]` in the proxy config:
9//!
10//! - `"memory"` (default): In-process moka cache. Fast, no external deps.
11//! - `"redis"`: External Redis cache. Shared across instances. Requires the
12//!   `redis-cache` feature flag.
13//! - `"sqlite"`: Local SQLite cache. Persistent across restarts. Requires the
14//!   `sqlite-cache` feature flag.
15//!
16//! # Per-Backend Configuration
17//!
18//! ```toml
19//! [[backends]]
20//! name = "slow-api"
21//! transport = "http"
22//! url = "http://localhost:8080"
23//!
24//! [backends.cache]
25//! resource_ttl_seconds = 300
26//! tool_ttl_seconds = 60
27//! max_entries = 1000
28//! ```
29
30use std::convert::Infallible;
31use std::future::Future;
32use std::pin::Pin;
33use std::sync::Arc;
34use std::sync::atomic::{AtomicU64, Ordering};
35use std::task::{Context, Poll};
36use std::time::Duration;
37
38use moka::future::Cache;
39use serde::Serialize;
40use tower::{Layer, Service};
41use tower_mcp::router::{RouterRequest, RouterResponse};
42use tower_mcp_types::protocol::McpRequest;
43
44use crate::config::{BackendCacheConfig, CacheBackendConfig};
45
/// Pluggable cache storage backend.
///
/// Each variant provides the same logical operations (get, insert, invalidate,
/// count) but differs in where entries are stored:
///
/// - [`Memory`](CacheStore::Memory): in-process moka cache (default)
/// - [`Redis`](CacheStore::Redis): external Redis server (requires `redis-cache` feature)
/// - [`Sqlite`](CacheStore::Sqlite): local SQLite database (requires `sqlite-cache` feature)
///
/// Cloning duplicates handles (the moka cache handle, the Redis client, the
/// `Arc` around the SQLite connection), not the cached data itself.
#[derive(Clone)]
pub(crate) enum CacheStore {
    /// In-process moka cache. TTL and max capacity are baked in at
    /// construction time (see `build_cache_store`).
    Memory(Cache<String, RouterResponse>),
    /// Redis-backed cache.
    #[cfg(feature = "redis-cache")]
    Redis {
        /// Redis client handle; connections are opened per operation.
        client: redis::Client,
        /// Key prefix used to namespace this proxy's entries.
        prefix: String,
        /// Per-entry expiry applied via `SET ... EX` on insert.
        ttl: Duration,
    },
    /// SQLite-backed cache.
    #[cfg(feature = "sqlite-cache")]
    Sqlite {
        /// Shared connection, guarded by a synchronous mutex.
        conn: Arc<std::sync::Mutex<rusqlite::Connection>>,
        /// Per-entry expiry written into the `expires_at` column on insert.
        ttl: Duration,
    },
}
72
impl CacheStore {
    /// Retrieve a cached response by key.
    ///
    /// Any backend failure (unreachable Redis, poisoned SQLite mutex, a stored
    /// value that no longer deserializes) is treated as a cache miss: this
    /// method returns `None` rather than surfacing the error.
    async fn get(&self, key: &str) -> Option<RouterResponse> {
        match self {
            CacheStore::Memory(cache) => cache.get(key).await,
            #[cfg(feature = "redis-cache")]
            CacheStore::Redis {
                client,
                prefix,
                ttl: _,
            } => {
                // Namespace the key so multiple proxies can share one Redis.
                let full_key = format!("{prefix}{key}");
                let mut conn = client.get_multiplexed_async_connection().await.ok()?;
                let data: Option<String> =
                    redis::AsyncCommands::get(&mut conn, &full_key).await.ok()?;
                // A stored value that fails to deserialize counts as a miss.
                data.and_then(|s| serde_json::from_str(&s).ok())
            }
            #[cfg(feature = "sqlite-cache")]
            CacheStore::Sqlite { conn, ttl: _ } => {
                let key = key.to_string();
                // NOTE(review): synchronous mutex + blocking query on the
                // async path; presumably fine for a small local DB, but
                // confirm it cannot stall the executor under load.
                let conn = conn.lock().ok()?;
                let now = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs() as i64;
                // Expired rows are filtered out here but never deleted; they
                // linger until overwritten by an insert with the same key.
                let result: Option<String> = conn
                    .query_row(
                        "SELECT value FROM cache_entries WHERE key = ?1 AND expires_at > ?2",
                        rusqlite::params![key, now],
                        |row| row.get(0),
                    )
                    .ok();
                result.and_then(|s| serde_json::from_str(&s).ok())
            }
        }
    }

    /// Insert a response into the cache.
    ///
    /// Best-effort: serialization, connection, or lock failures are silently
    /// dropped so that caching can never fail the request path.
    async fn insert(&self, key: String, value: RouterResponse) {
        match self {
            CacheStore::Memory(cache) => {
                cache.insert(key, value).await;
            }
            #[cfg(feature = "redis-cache")]
            CacheStore::Redis {
                client,
                prefix,
                ttl,
            } => {
                let full_key = format!("{prefix}{key}");
                if let Ok(json) = serde_json::to_string(&value)
                    && let Ok(mut conn) = client.get_multiplexed_async_connection().await
                {
                    // SET with EX applies value and TTL in one command.
                    let _: Result<(), _> =
                        redis::AsyncCommands::set_ex(&mut conn, &full_key, &json, ttl.as_secs())
                            .await;
                }
            }
            #[cfg(feature = "sqlite-cache")]
            CacheStore::Sqlite { conn, ttl } => {
                if let Ok(json) = serde_json::to_string(&value) {
                    // Absolute unix-seconds deadline stored next to the value.
                    let expires_at = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap_or_default()
                        .as_secs() as i64
                        + ttl.as_secs() as i64;
                    if let Ok(conn) = conn.lock() {
                        // Upsert: re-inserting a key also refreshes its expiry.
                        let _ = conn.execute(
                            "INSERT OR REPLACE INTO cache_entries (key, value, expires_at) VALUES (?1, ?2, ?3)",
                            rusqlite::params![key, json, expires_at],
                        );
                    }
                }
            }
        }
    }

    /// Remove all entries from the cache.
    async fn invalidate_all(&self) {
        match self {
            CacheStore::Memory(cache) => {
                // NOTE(review): moka's `invalidate_all` is documented as lazy;
                // `entry_count` may briefly report stale totals afterwards.
                cache.invalidate_all();
            }
            #[cfg(feature = "redis-cache")]
            CacheStore::Redis {
                client,
                prefix,
                ttl: _,
            } => {
                if let Ok(mut conn) = client.get_multiplexed_async_connection().await {
                    // NOTE(review): KEYS is O(N) and blocks the Redis server;
                    // consider a SCAN loop if the shared instance is large.
                    let pattern = format!("{prefix}*");
                    let keys: Vec<String> = redis::AsyncCommands::keys(&mut conn, &pattern)
                        .await
                        .unwrap_or_default();
                    if !keys.is_empty() {
                        let _: Result<(), _> = redis::AsyncCommands::del(&mut conn, &keys).await;
                    }
                }
            }
            #[cfg(feature = "sqlite-cache")]
            CacheStore::Sqlite { conn, ttl: _ } => {
                if let Ok(conn) = conn.lock() {
                    let _ = conn.execute("DELETE FROM cache_entries", []);
                }
            }
        }
    }

    /// Return the approximate number of entries in the cache.
    ///
    /// "Approximate" because moka's count may lag behind recent operations,
    /// Redis is scanned with a pattern match, and SQLite counts only rows
    /// that have not yet expired.
    async fn entry_count(&self) -> u64 {
        match self {
            CacheStore::Memory(cache) => cache.entry_count(),
            #[cfg(feature = "redis-cache")]
            CacheStore::Redis {
                client,
                prefix,
                ttl: _,
            } => {
                if let Ok(mut conn) = client.get_multiplexed_async_connection().await {
                    // NOTE(review): KEYS blocks Redis; SCAN would be kinder.
                    let pattern = format!("{prefix}*");
                    let keys: Vec<String> = redis::AsyncCommands::keys(&mut conn, &pattern)
                        .await
                        .unwrap_or_default();
                    keys.len() as u64
                } else {
                    // Connection failure degrades to "empty".
                    0
                }
            }
            #[cfg(feature = "sqlite-cache")]
            CacheStore::Sqlite { conn, ttl: _ } => {
                let now = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs() as i64;
                if let Ok(conn) = conn.lock() {
                    conn.query_row(
                        "SELECT COUNT(*) FROM cache_entries WHERE expires_at > ?1",
                        rusqlite::params![now],
                        |row| row.get::<_, i64>(0),
                    )
                    .unwrap_or(0) as u64
                } else {
                    0
                }
            }
        }
    }
}
221
/// Build a [`CacheStore`] from the global cache backend configuration and
/// a per-backend TTL.
///
/// Backend selection is driven by `backend_config.backend`:
///
/// - `"redis"` / `"sqlite"` when the matching feature is compiled in;
/// - anything else falls through to the in-memory moka cache. Note this
///   silent fallback also applies when the config names `"redis"`/`"sqlite"`
///   but the corresponding feature flag is disabled, and for typos.
///
/// `max_entries` is only honored by the memory backend; the Redis and SQLite
/// stores are unbounded apart from TTL expiry.
///
/// # Panics
///
/// Panics on an invalid Redis URL or an unopenable SQLite database path —
/// both are treated as fatal configuration errors at startup.
fn build_cache_store(
    backend_config: &CacheBackendConfig,
    ttl: Duration,
    max_entries: u64,
) -> CacheStore {
    match backend_config.backend.as_str() {
        #[cfg(feature = "redis-cache")]
        "redis" => {
            let url = backend_config.url.as_deref().unwrap_or("redis://127.0.0.1");
            let client =
                redis::Client::open(url).expect("invalid Redis URL in cache configuration");
            CacheStore::Redis {
                client,
                prefix: backend_config.prefix.clone(),
                ttl,
            }
        }
        #[cfg(feature = "sqlite-cache")]
        "sqlite" => {
            // `url` doubles as the database file path for the SQLite backend.
            let path = backend_config.url.as_deref().unwrap_or("cache.db");
            let conn =
                rusqlite::Connection::open(path).expect("failed to open SQLite cache database");
            conn.execute_batch(
                "CREATE TABLE IF NOT EXISTS cache_entries (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL,
                    expires_at INTEGER NOT NULL
                )",
            )
            .expect("failed to create SQLite cache table");
            CacheStore::Sqlite {
                conn: Arc::new(std::sync::Mutex::new(conn)),
                ttl,
            }
        }
        // Default: memory backend (also handles "memory" explicitly)
        _ => CacheStore::Memory(
            Cache::builder()
                .max_capacity(max_entries)
                .time_to_live(ttl)
                .build(),
        ),
    }
}
268
/// Per-backend cache with separate resource and tool caches (different TTLs).
///
/// A cache slot is `None` when its TTL was configured as 0, which disables
/// caching for that request type (see `build_caches`).
#[derive(Clone)]
struct BackendCache {
    /// Namespace prefix matched against resource URIs and tool names.
    namespace: String,
    /// Cache for `ReadResource` responses (`None` = disabled).
    resource_cache: Option<CacheStore>,
    /// Cache for `CallTool` responses (`None` = disabled).
    tool_cache: Option<CacheStore>,
    /// Hit/miss counters shared across every clone of this cache.
    stats: Arc<CacheStats>,
}
277
/// Atomic hit/miss counters for a backend cache.
///
/// Counters use `Ordering::Relaxed` at every call site: they are independent
/// statistics, so no ordering with other memory operations is required.
#[derive(Default)]
struct CacheStats {
    /// Requests answered from the cache.
    hits: AtomicU64,
    /// Requests that fell through to the inner service.
    misses: AtomicU64,
}

impl CacheStats {
    /// Create a zeroed counter pair.
    fn new() -> Self {
        // `AtomicU64::default()` is 0, so this matches explicit zeroing.
        Self::default()
    }
}
292
/// Snapshot of cache statistics for a single backend.
///
/// Returned by [`CacheHandle::stats()`] to report hit/miss rates
/// and entry counts per cached namespace.
#[derive(Serialize, Clone)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct CacheStatsSnapshot {
    /// Backend namespace this cache covers.
    pub namespace: String,
    /// Total cache hits.
    pub hits: u64,
    /// Total cache misses.
    pub misses: u64,
    /// Hit rate as a fraction (0.0-1.0); reported as 0.0 before any request
    /// has been observed.
    pub hit_rate: f64,
    /// Current number of cached entries: resource + tool counts combined,
    /// and approximate (see `CacheStore::entry_count`).
    pub entry_count: u64,
}
311
/// Shared handle for querying cache stats and clearing caches.
///
/// Cheap to clone; every clone observes the same underlying cache state.
#[derive(Clone)]
pub struct CacheHandle {
    /// Shared per-backend cache state, also held by the layer/service side.
    caches: Arc<Vec<BackendCache>>,
}
317
318impl CacheHandle {
319    /// Get a snapshot of cache statistics for all backends.
320    pub async fn stats(&self) -> Vec<CacheStatsSnapshot> {
321        let mut snapshots = Vec::with_capacity(self.caches.len());
322        for bc in self.caches.iter() {
323            let hits = bc.stats.hits.load(Ordering::Relaxed);
324            let misses = bc.stats.misses.load(Ordering::Relaxed);
325            let total = hits + misses;
326            let resource_count = match &bc.resource_cache {
327                Some(store) => store.entry_count().await,
328                None => 0,
329            };
330            let tool_count = match &bc.tool_cache {
331                Some(store) => store.entry_count().await,
332                None => 0,
333            };
334            snapshots.push(CacheStatsSnapshot {
335                namespace: bc.namespace.clone(),
336                hits,
337                misses,
338                hit_rate: if total > 0 {
339                    hits as f64 / total as f64
340                } else {
341                    0.0
342                },
343                entry_count: resource_count + tool_count,
344            });
345        }
346        snapshots
347    }
348
349    /// Clear all cache entries and reset stats.
350    pub async fn clear(&self) {
351        for bc in self.caches.iter() {
352            if let Some(store) = &bc.resource_cache {
353                store.invalidate_all().await;
354            }
355            if let Some(store) = &bc.tool_cache {
356                store.invalidate_all().await;
357            }
358            bc.stats.hits.store(0, Ordering::Relaxed);
359            bc.stats.misses.store(0, Ordering::Relaxed);
360        }
361    }
362}
363
364/// Build the shared cache state from per-backend configs.
365///
366/// Returns an `Arc<Vec<BackendCache>>` that can be shared between a
367/// [`CacheLayer`] (or [`CacheService`]) and its [`CacheHandle`].
368fn build_caches(
369    configs: Vec<(String, &BackendCacheConfig)>,
370    backend_config: &CacheBackendConfig,
371) -> Arc<Vec<BackendCache>> {
372    let caches: Vec<BackendCache> = configs
373        .into_iter()
374        .map(|(namespace, cfg)| {
375            let resource_cache = if cfg.resource_ttl_seconds > 0 {
376                Some(build_cache_store(
377                    backend_config,
378                    Duration::from_secs(cfg.resource_ttl_seconds),
379                    cfg.max_entries,
380                ))
381            } else {
382                None
383            };
384            let tool_cache = if cfg.tool_ttl_seconds > 0 {
385                Some(build_cache_store(
386                    backend_config,
387                    Duration::from_secs(cfg.tool_ttl_seconds),
388                    cfg.max_entries,
389                ))
390            } else {
391                None
392            };
393            BackendCache {
394                namespace,
395                resource_cache,
396                tool_cache,
397                stats: Arc::new(CacheStats::new()),
398            }
399        })
400        .collect();
401    Arc::new(caches)
402}
403
404/// Tower [`Layer`] that produces [`CacheService`] instances sharing the same
405/// cache state and [`CacheHandle`].
406///
407/// Because `CacheService::new()` returns a `(CacheService, CacheHandle)` tuple,
408/// a standard `Layer` cannot propagate the side-channel handle. `CacheLayer`
409/// solves this by creating the shared cache state up-front and handing out an
410/// `Arc`-cloned handle to the caller while cloning the same `Arc` into every
411/// service produced by [`Layer::layer`].
412///
413/// # Example
414///
415/// ```rust
416/// use mcp_proxy::cache::{CacheLayer, CacheHandle};
417/// use mcp_proxy::config::{BackendCacheConfig, CacheBackendConfig};
418///
419/// let cfg = BackendCacheConfig {
420///     resource_ttl_seconds: 300,
421///     tool_ttl_seconds: 60,
422///     max_entries: 1000,
423/// };
424/// let backend_cfg = CacheBackendConfig::default();
425///
426/// let (layer, handle) = CacheLayer::new(
427///     vec![("api/".to_string(), &cfg)],
428///     &backend_cfg,
429/// );
430///
431/// // `layer` implements `tower::Layer<S>` and can be used in a middleware stack.
432/// // `handle` can be used to query stats or clear the cache.
433/// ```
#[derive(Clone)]
pub struct CacheLayer {
    /// Shared per-backend cache state, cloned (via `Arc`) into every service
    /// this layer produces.
    caches: Arc<Vec<BackendCache>>,
}
438
439impl CacheLayer {
440    /// Create a new cache layer and return it with a shareable [`CacheHandle`].
441    ///
442    /// The handle provides [`CacheHandle::stats()`] and [`CacheHandle::clear()`]
443    /// over the same underlying cache state used by every service the layer
444    /// produces.
445    pub fn new(
446        configs: Vec<(String, &BackendCacheConfig)>,
447        backend_config: &CacheBackendConfig,
448    ) -> (Self, CacheHandle) {
449        let caches = build_caches(configs, backend_config);
450        let handle = CacheHandle {
451            caches: Arc::clone(&caches),
452        };
453        (Self { caches }, handle)
454    }
455}
456
457impl<S> Layer<S> for CacheLayer {
458    type Service = CacheService<S>;
459
460    fn layer(&self, inner: S) -> Self::Service {
461        CacheService {
462            inner,
463            caches: Arc::clone(&self.caches),
464        }
465    }
466}
467
/// Tower service that caches resource reads and tool call results.
#[derive(Clone)]
pub struct CacheService<S> {
    /// The wrapped downstream service.
    inner: S,
    /// Shared per-backend cache state (also visible through [`CacheHandle`]).
    caches: Arc<Vec<BackendCache>>,
}
474
475impl<S> CacheService<S> {
476    /// Create a new cache service and return it with a shareable handle.
477    pub fn new(
478        inner: S,
479        configs: Vec<(String, &BackendCacheConfig)>,
480        backend_config: &CacheBackendConfig,
481    ) -> (Self, CacheHandle) {
482        let caches = build_caches(configs, backend_config);
483        let handle = CacheHandle {
484            caches: Arc::clone(&caches),
485        };
486        (Self { inner, caches }, handle)
487    }
488}
489
490/// Extract cache key and find the matching backend cache + stats.
491fn resolve_cache<'a>(
492    caches: &'a [BackendCache],
493    req: &McpRequest,
494) -> Option<(&'a CacheStore, String, &'a Arc<CacheStats>)> {
495    match req {
496        McpRequest::ReadResource(params) => {
497            let key = format!("res:{}", params.uri);
498            for bc in caches {
499                if params.uri.starts_with(&bc.namespace) {
500                    return bc.resource_cache.as_ref().map(|c| (c, key, &bc.stats));
501                }
502            }
503            None
504        }
505        McpRequest::CallTool(params) => {
506            let args = serde_json::to_string(&params.arguments).unwrap_or_default();
507            let key = format!("tool:{}:{}", params.name, args);
508            for bc in caches {
509                if params.name.starts_with(&bc.namespace) {
510                    return bc.tool_cache.as_ref().map(|c| (c, key, &bc.stats));
511                }
512            }
513            None
514        }
515        _ => None,
516    }
517}
518
519impl<S> Service<RouterRequest> for CacheService<S>
520where
521    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
522        + Clone
523        + Send
524        + 'static,
525    S::Future: Send,
526{
527    type Response = RouterResponse;
528    type Error = Infallible;
529    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
530
531    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
532        self.inner.poll_ready(cx)
533    }
534
535    fn call(&mut self, req: RouterRequest) -> Self::Future {
536        let caches = Arc::clone(&self.caches);
537
538        if let Some((store, key, stats)) = resolve_cache(&caches, &req.inner) {
539            let store = store.clone();
540            let stats = Arc::clone(stats);
541            let mut inner = self.inner.clone();
542
543            return Box::pin(async move {
544                // Cache hit -- return with current request ID
545                if let Some(cached) = store.get(&key).await {
546                    stats.hits.fetch_add(1, Ordering::Relaxed);
547                    return Ok(RouterResponse {
548                        id: req.id,
549                        inner: cached.inner,
550                    });
551                }
552
553                stats.misses.fetch_add(1, Ordering::Relaxed);
554                let result = inner.call(req).await;
555
556                // Only cache successful MCP responses
557                let Ok(ref resp) = result;
558                if resp.inner.is_ok() {
559                    store.insert(key, resp.clone()).await;
560                }
561
562                result
563            });
564        }
565
566        // No caching for this request type or backend
567        let fut = self.inner.call(req);
568        Box::pin(fut)
569    }
570}
571
#[cfg(test)]
mod tests {
    // Moved from the bottom of the module: imports belong in the top import
    // group per Rust convention.
    use std::time::Duration;

    use tower_mcp::protocol::{McpRequest, McpResponse};

    use super::CacheService;
    use crate::config::{BackendCacheConfig, CacheBackendConfig};
    use crate::test_util::{MockService, call_service};

    fn tool_call(name: &str) -> McpRequest {
        McpRequest::CallTool(tower_mcp::protocol::CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({"key": "value"}),
            meta: None,
            task: None,
        })
    }

    fn default_backend_config() -> CacheBackendConfig {
        CacheBackendConfig::default()
    }

    #[tokio::test]
    async fn test_cache_hit_returns_same_result() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (mut svc, _handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        let resp1 = call_service(&mut svc, tool_call("fs/read")).await;
        let resp2 = call_service(&mut svc, tool_call("fs/read")).await;

        // Both should succeed with same content
        match (resp1.inner.unwrap(), resp2.inner.unwrap()) {
            (McpResponse::CallTool(r1), McpResponse::CallTool(r2)) => {
                assert_eq!(r1.all_text(), r2.all_text());
            }
            _ => panic!("expected CallTool responses"),
        }
    }

    #[tokio::test]
    async fn test_cache_disabled_passes_through() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 0,
            tool_ttl_seconds: 0,
            max_entries: 100,
        };
        let (mut svc, _handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        let resp = call_service(&mut svc, tool_call("fs/read")).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_cache_non_matching_namespace_passes_through() {
        let mock = MockService::with_tools(&["db/query"]);
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (mut svc, _handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        let resp = call_service(&mut svc, tool_call("db/query")).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_cache_list_tools_not_cached() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (mut svc, _handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok(), "list_tools should pass through");
    }

    #[tokio::test]
    async fn test_cache_stats_tracks_hits_and_misses() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (mut svc, handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        // First call = miss
        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let stats = handle.stats().await;
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].hits, 0);
        assert_eq!(stats[0].misses, 1);

        // Second call = hit
        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let stats = handle.stats().await;
        assert_eq!(stats[0].hits, 1);
        assert_eq!(stats[0].misses, 1);
        assert!((stats[0].hit_rate - 0.5).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_cache_clear_resets_stats() {
        let mock = MockService::with_tools(&["fs/read"]);
        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (mut svc, handle) = CacheService::new(
            mock,
            vec![("fs/".to_string(), &cfg)],
            &default_backend_config(),
        );

        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let _ = call_service(&mut svc, tool_call("fs/read")).await;

        handle.clear().await;
        let stats = handle.stats().await;
        assert_eq!(stats[0].hits, 0);
        assert_eq!(stats[0].misses, 0);
    }

    #[tokio::test]
    async fn test_cache_layer_produces_working_service() {
        use super::CacheLayer;
        use tower::Layer;

        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (layer, handle) =
            CacheLayer::new(vec![("fs/".to_string(), &cfg)], &default_backend_config());

        let mock = MockService::with_tools(&["fs/read"]);
        let mut svc = layer.layer(mock);

        // First call = miss
        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let stats = handle.stats().await;
        assert_eq!(stats[0].misses, 1);
        assert_eq!(stats[0].hits, 0);

        // Second call = hit (cached)
        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let stats = handle.stats().await;
        assert_eq!(stats[0].hits, 1);
        assert_eq!(stats[0].misses, 1);
    }

    #[tokio::test]
    async fn test_cache_layer_shares_state_across_services() {
        use super::CacheLayer;
        use tower::Layer;

        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (layer, handle) =
            CacheLayer::new(vec![("fs/".to_string(), &cfg)], &default_backend_config());

        // Create two services from the same layer
        let mock1 = MockService::with_tools(&["fs/read"]);
        let mut svc1 = layer.layer(mock1);

        let mock2 = MockService::with_tools(&["fs/read"]);
        let mut svc2 = layer.layer(mock2);

        // Miss on svc1
        let _ = call_service(&mut svc1, tool_call("fs/read")).await;
        assert_eq!(handle.stats().await[0].misses, 1);

        // Hit on svc2 (same underlying cache)
        let _ = call_service(&mut svc2, tool_call("fs/read")).await;
        assert_eq!(handle.stats().await[0].hits, 1);
        assert_eq!(handle.stats().await[0].misses, 1);
    }

    #[tokio::test]
    async fn test_cache_layer_handle_clear() {
        use super::CacheLayer;
        use tower::Layer;

        let cfg = BackendCacheConfig {
            resource_ttl_seconds: 60,
            tool_ttl_seconds: 60,
            max_entries: 100,
        };
        let (layer, handle) =
            CacheLayer::new(vec![("fs/".to_string(), &cfg)], &default_backend_config());

        let mock = MockService::with_tools(&["fs/read"]);
        let mut svc = layer.layer(mock);

        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        let _ = call_service(&mut svc, tool_call("fs/read")).await;
        assert_eq!(handle.stats().await[0].hits, 1);

        handle.clear().await;
        let stats = handle.stats().await;
        assert_eq!(stats[0].hits, 0);
        assert_eq!(stats[0].misses, 0);
    }

    #[tokio::test]
    async fn test_cache_store_memory_get_insert() {
        use super::{CacheStore, build_cache_store};

        let store = build_cache_store(&default_backend_config(), Duration::from_secs(60), 100);
        assert!(matches!(store, CacheStore::Memory(_)));

        // Initially empty
        assert!(store.get("key1").await.is_none());
        assert_eq!(store.entry_count().await, 0);
    }

    #[cfg(feature = "redis-cache")]
    #[test]
    fn test_cache_store_redis_construction() {
        use super::build_cache_store;

        let cfg = CacheBackendConfig {
            backend: "redis".to_string(),
            url: Some("redis://127.0.0.1:6379".to_string()),
            prefix: "test:".to_string(),
        };
        let store = build_cache_store(&cfg, Duration::from_secs(60), 100);
        assert!(matches!(store, super::CacheStore::Redis { .. }));
    }

    #[cfg(feature = "sqlite-cache")]
    #[tokio::test]
    async fn test_cache_store_sqlite_construction() {
        use super::build_cache_store;

        let dir = std::env::temp_dir().join(format!("mcp-proxy-test-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let db_path = dir.join("test_cache.db");

        let cfg = CacheBackendConfig {
            backend: "sqlite".to_string(),
            url: Some(db_path.to_string_lossy().to_string()),
            prefix: "test:".to_string(),
        };
        let store = build_cache_store(&cfg, Duration::from_secs(60), 100);
        assert!(matches!(store, super::CacheStore::Sqlite { .. }));
        assert_eq!(store.entry_count().await, 0);

        // Clean up
        let _ = std::fs::remove_dir_all(&dir);
    }
}