use super::request::{McpRequest, McpResponse};
use futures_util::future::BoxFuture;
use parking_lot::RwLock;
use serde_json::Value;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use tower_layer::Layer;
use tower_service::Service;
use turbomcp_protocol::McpError;
/// Configuration for the response cache: capacity, time-to-live and which methods are cached.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries held at once; inserting past this triggers eviction.
    pub max_entries: usize,
    /// How long an entry stays valid, measured from the moment it was inserted.
    pub ttl: Duration,
    /// Method names or prefixes to cache. When empty, a built-in read-only policy is used.
    pub cache_methods: Vec<String>,
    /// Method names or prefixes that are never cached; takes precedence over `cache_methods`.
    pub exclude_methods: Vec<String>,
}

impl Default for CacheConfig {
    /// 1000 entries, a 5-minute TTL, the built-in method policy, and `tools/call`,
    /// `sampling/createMessage` and every `notifications/` method excluded.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            ttl: Duration::from_secs(300),
            cache_methods: Vec::new(),
            exclude_methods: vec![
                "tools/call".to_string(),
                "sampling/createMessage".to_string(),
                "notifications/".to_string(),
            ],
        }
    }
}

impl CacheConfig {
    /// Returns `true` if `method` starts with any of `patterns`.
    ///
    /// A pattern equal to the full method name is a prefix of it, so exact names and
    /// prefixes are handled by the same check.
    fn matches_any(method: &str, patterns: &[String]) -> bool {
        patterns.iter().any(|pattern| method.starts_with(pattern.as_str()))
    }

    /// Decides whether responses to `method` may be cached.
    ///
    /// Order of precedence:
    /// 1. anything matching `exclude_methods` is never cached;
    /// 2. if `cache_methods` is non-empty, only methods matching it are cached;
    /// 3. otherwise the built-in policy applies: `resources/*` (except subscription
    ///    management), `prompts/*`, and `tools/list`.
    fn should_cache(&self, method: &str) -> bool {
        if Self::matches_any(method, &self.exclude_methods) {
            return false;
        }
        if !self.cache_methods.is_empty() {
            return Self::matches_any(method, &self.cache_methods);
        }
        // Per the MCP spec, `resources/subscribe` / `resources/unsubscribe` change what the
        // server will notify about. Replaying a cached acknowledgement would silently skip
        // the real (un)subscription, so they are kept out of the implicit policy.
        let is_cacheable_resource_method = method.starts_with("resources/")
            && !matches!(method, "resources/subscribe" | "resources/unsubscribe");
        is_cacheable_resource_method || method.starts_with("prompts/") || method == "tools/list"
    }
}
/// A cached response payload plus the bookkeeping used for TTL expiry and LRU eviction.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached response payload.
    data: Value,
    /// Insertion time; the TTL is measured from here.
    created: Instant,
    /// Time of the most recent read (initially the insertion time).
    last_accessed: Instant,
    /// Number of reads served from this entry.
    access_count: u64,
}

impl CacheEntry {
    /// Wraps `data` in a fresh entry stamped with the current time and no recorded reads.
    fn new(data: Value) -> Self {
        let stamp = Instant::now();
        Self {
            data,
            created: stamp,
            last_accessed: stamp,
            access_count: 0,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since insertion.
    fn is_expired(&self, ttl: Duration) -> bool {
        ttl < self.created.elapsed()
    }

    /// Records a read (bumping the counter and last-access time) and returns the payload.
    fn access(&mut self) -> &Value {
        self.access_count += 1;
        self.last_accessed = Instant::now();
        &self.data
    }
}
/// Point-in-time snapshot of the cache counters, as produced by `Cache::stats`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly (e.g. in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Lookups that returned a live entry.
    pub hits: u64,
    /// Lookups that found no entry, or only an expired one.
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Entries removed because their TTL had elapsed.
    pub expirations: u64,
    /// Number of entries stored at the moment the snapshot was taken.
    pub current_entries: usize,
}
/// Thread-safe store of JSON response payloads with TTL expiry, LRU eviction and counters.
#[derive(Debug)]
pub struct Cache {
    /// Capacity, TTL and method-selection policy.
    config: CacheConfig,
    /// Stored entries, keyed by the string derived from each request.
    entries: RwLock<HashMap<String, CacheEntry>>,
    /// Lookups served from a live entry.
    hits: AtomicU64,
    /// Lookups that found nothing usable.
    misses: AtomicU64,
    /// Entries dropped to make room.
    evictions: AtomicU64,
    /// Entries dropped because their TTL elapsed.
    expirations: AtomicU64,
}
impl Cache {
    /// Creates an empty cache governed by `config`, with every counter at zero.
    #[must_use]
    pub fn new(config: CacheConfig) -> Self {
        Self {
            config,
            entries: RwLock::new(HashMap::new()),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            evictions: AtomicU64::new(0),
            expirations: AtomicU64::new(0),
        }
    }

    /// Derives the cache key for `req`: the method name, a colon, and a hex 64-bit hash of
    /// the method plus the serialized params.
    ///
    /// Requests whose params serialize to the same string share a key. Because only a
    /// 64-bit hash of the params is kept, distinct params could in principle collide.
    fn cache_key(req: &McpRequest) -> String {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        req.method().hash(&mut hasher);
        if let Some(params) = req.params() {
            params.to_string().hash(&mut hasher);
        }
        format!("{}:{:x}", req.method(), hasher.finish())
    }

    /// Reports whether responses to `method` are eligible for caching under this config.
    pub fn should_cache(&self, method: &str) -> bool {
        self.config.should_cache(method)
    }

    /// Looks up `key`, returning a clone of the cached payload if present and unexpired.
    ///
    /// A hit refreshes the entry's last-access time. An expired entry is removed on the spot
    /// and counted as both an expiration and a miss.
    pub fn get(&self, key: &str) -> Option<Value> {
        // Write lock even for reads: a hit mutates the entry's access bookkeeping.
        let mut entries = self.entries.write();
        if let Some(entry) = entries.get_mut(key) {
            if entry.is_expired(self.config.ttl) {
                entries.remove(key);
                self.expirations.fetch_add(1, Ordering::Relaxed);
                self.misses.fetch_add(1, Ordering::Relaxed);
                return None;
            }
            self.hits.fetch_add(1, Ordering::Relaxed);
            return Some(entry.access().clone());
        }
        self.misses.fetch_add(1, Ordering::Relaxed);
        None
    }

    /// Stores `value` under `key`, replacing any previous entry for that key.
    ///
    /// When a *new* key arrives at a full cache, expired entries are dropped first. Only if
    /// that frees no room are the least-recently-used ~10% (at least one) evicted.
    /// A `max_entries` of zero disables storage entirely.
    pub fn put(&self, key: String, value: Value) {
        let capacity = self.config.max_entries;
        if capacity == 0 {
            return;
        }
        let mut entries = self.entries.write();
        // Overwriting an existing key does not grow the map, so it never needs to make room.
        if !entries.contains_key(&key) && entries.len() >= capacity {
            self.purge_expired(&mut entries);
            if entries.len() >= capacity {
                self.evict_lru(&mut entries);
            }
        }
        entries.insert(key, CacheEntry::new(value));
    }

    /// Removes every entry whose TTL has elapsed, counting each removal as an expiration.
    fn purge_expired(&self, entries: &mut HashMap<String, CacheEntry>) {
        let ttl = self.config.ttl;
        let before = entries.len();
        entries.retain(|_, entry| !entry.is_expired(ttl));
        let removed = u64::try_from(before - entries.len()).unwrap_or(u64::MAX);
        self.expirations.fetch_add(removed, Ordering::Relaxed);
    }

    /// Evicts the least-recently-accessed ~10% of entries (at least one), counting each.
    fn evict_lru(&self, entries: &mut HashMap<String, CacheEntry>) {
        let evict_count = (entries.len() / 10).max(1);
        let mut by_recency: Vec<(Instant, String)> = entries
            .iter()
            .map(|(key, entry)| (entry.last_accessed, key.clone()))
            .collect();
        // Oldest access first; ordering among identical timestamps does not matter.
        by_recency.sort_unstable_by_key(|(accessed, _)| *accessed);
        for (_, key) in by_recency.into_iter().take(evict_count) {
            if entries.remove(&key).is_some() {
                self.evictions.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    /// Returns a snapshot of the counters and the current number of stored entries.
    #[must_use]
    pub fn stats(&self) -> CacheStats {
        CacheStats {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            evictions: self.evictions.load(Ordering::Relaxed),
            expirations: self.expirations.load(Ordering::Relaxed),
            current_entries: self.entries.read().len(),
        }
    }

    /// Removes all entries. The hit/miss/eviction/expiration counters are left untouched.
    pub fn clear(&self) {
        self.entries.write().clear();
    }

    /// Proactively removes all expired entries, counting each as an expiration.
    ///
    /// Without this, expired entries are only dropped lazily: on lookup, or when a full
    /// cache needs room.
    pub fn cleanup(&self) {
        let mut entries = self.entries.write();
        self.purge_expired(&mut entries);
    }
}
impl Default for Cache {
fn default() -> Self {
Self::new(CacheConfig::default())
}
}
/// Tower [`Layer`] that wraps services with response caching backed by a shared [`Cache`].
#[derive(Debug, Clone)]
pub struct CacheLayer {
    /// Handed (by `Arc` clone) to every service this layer builds, so they share entries.
    cache: Arc<Cache>,
}
impl CacheLayer {
#[must_use]
pub fn new(config: CacheConfig) -> Self {
Self {
cache: Arc::new(Cache::new(config)),
}
}
#[must_use]
pub fn with_cache(cache: Arc<Cache>) -> Self {
Self { cache }
}
#[must_use]
pub fn cache(&self) -> &Arc<Cache> {
&self.cache
}
}
impl Default for CacheLayer {
fn default() -> Self {
Self::new(CacheConfig::default())
}
}
impl<S> Layer<S> for CacheLayer {
    type Service = CacheService<S>;

    /// Wraps `inner`, giving the new service its own handle to this layer's shared cache.
    fn layer(&self, inner: S) -> Self::Service {
        let cache = Arc::clone(&self.cache);
        CacheService { cache, inner }
    }
}
/// Service produced by [`CacheLayer`].
///
/// Answers cacheable requests from the cache when possible, and forwards everything else
/// (including cache misses) to the wrapped service.
#[derive(Debug, Clone)]
pub struct CacheService<S> {
    /// The wrapped service that handles misses and uncacheable methods.
    inner: S,
    /// Cache shared with the layer that built this service.
    cache: Arc<Cache>,
}
impl<S> CacheService<S> {
    /// Borrows the wrapped service.
    pub fn inner(&self) -> &S {
        let wrapped = &self.inner;
        wrapped
    }

    /// Mutably borrows the wrapped service.
    pub fn inner_mut(&mut self) -> &mut S {
        let wrapped = &mut self.inner;
        wrapped
    }

    /// Borrows the shared cache handle used by this service.
    pub fn cache(&self) -> &Arc<Cache> {
        let shared = &self.cache;
        shared
    }
}
impl<S> Service<McpRequest> for CacheService<S>
where
    S: Service<McpRequest, Response = McpResponse> + Clone + Send + 'static,
    S::Future: Send,
    S::Error: Into<McpError>,
{
    type Response = McpResponse;
    type Error = McpError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(|err| err.into())
    }

    /// Handles a request in one of three ways:
    /// - uncacheable method: forward to the wrapped service untouched;
    /// - cache hit: answer immediately with `cache.hit = true` and zero duration;
    /// - cache miss: forward, store successful results, and tag the response `cache.hit = false`.
    fn call(&mut self, req: McpRequest) -> Self::Future {
        let method_name = req.method().to_string();
        let cache = Arc::clone(&self.cache);

        if !cache.should_cache(&method_name) {
            // Use the service that was driven to readiness; leave a clone for the next cycle.
            let fresh = self.inner.clone();
            let mut ready = std::mem::replace(&mut self.inner, fresh);
            return Box::pin(async move { ready.call(req).await.map_err(Into::into) });
        }

        let key = Cache::cache_key(&req);
        if let Some(payload) = cache.get(&key) {
            // Served from the cache: the wrapped service is neither swapped out nor called.
            let metadata = HashMap::from([("cache.hit".to_string(), serde_json::json!(true))]);
            return Box::pin(async move {
                Ok(McpResponse {
                    result: Some(payload),
                    error: None,
                    metadata,
                    duration: Duration::ZERO,
                })
            });
        }

        let fresh = self.inner.clone();
        let mut ready = std::mem::replace(&mut self.inner, fresh);
        Box::pin(async move {
            let started = Instant::now();
            let mut response = ready.call(req).await.map_err(Into::into)?;
            // Only successful responses that actually carry a result are worth remembering.
            if response.is_success() {
                if let Some(payload) = response.result.as_ref() {
                    cache.put(key, payload.clone());
                }
            }
            response.insert_metadata("cache.hit", serde_json::json!(false));
            response.duration = started.elapsed();
            Ok(response)
        })
    }
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use turbomcp_protocol::MessageId;
use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};
// Builds a request for `method` with a fixed id and fixed params `{"key": "value"}`, so two
// requests built for the same method produce the same cache key.
fn test_request(method: &str) -> McpRequest {
McpRequest::new(JsonRpcRequest {
jsonrpc: JsonRpcVersion,
id: MessageId::from("test-1"),
method: method.to_string(),
params: Some(json!({"key": "value"})),
})
}
// The default config caches list/read methods and refuses the excluded ones.
#[test]
fn test_cache_config_defaults() {
let config = CacheConfig::default();
assert!(config.should_cache("resources/list"));
assert!(config.should_cache("resources/read"));
assert!(config.should_cache("prompts/list"));
assert!(config.should_cache("tools/list"));
assert!(!config.should_cache("tools/call"));
assert!(!config.should_cache("sampling/createMessage"));
}
// A value stored with `put` is returned unchanged by `get` for the same key.
#[test]
fn test_cache_put_get() {
let cache = Cache::default();
let key = "test:123".to_string();
let value = json!({"result": "test"});
cache.put(key.clone(), value.clone());
let retrieved = cache.get(&key);
assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap(), value);
}
// Looking up an unknown key returns None and is counted as a miss, not a hit.
#[test]
fn test_cache_miss() {
let cache = Cache::default();
let retrieved = cache.get("nonexistent");
assert!(retrieved.is_none());
let stats = cache.stats();
assert_eq!(stats.misses, 1);
assert_eq!(stats.hits, 0);
}
// A 1 ms TTL followed by a 5 ms sleep guarantees the entry is stale when read back;
// the read must return None and record exactly one expiration.
#[test]
fn test_cache_expiration() {
let config = CacheConfig {
ttl: Duration::from_millis(1),
..Default::default()
};
let cache = Cache::new(config);
let key = "test:456".to_string();
cache.put(key.clone(), json!({"data": "test"}));
std::thread::sleep(Duration::from_millis(5));
let retrieved = cache.get(&key);
assert!(retrieved.is_none());
let stats = cache.stats();
assert_eq!(stats.expirations, 1);
}
// With room for two entries and a long TTL, inserting a third distinct key must evict
// something and keep the size within the limit.
#[test]
fn test_cache_eviction() {
let config = CacheConfig {
max_entries: 2,
ttl: Duration::from_secs(300),
..Default::default()
};
let cache = Cache::new(config);
cache.put("key1".to_string(), json!(1));
cache.put("key2".to_string(), json!(2));
cache.put("key3".to_string(), json!(3));
let stats = cache.stats();
assert!(stats.evictions > 0);
assert!(stats.current_entries <= 2);
}
// Identical method + params give identical keys; a different method gives a different key.
#[test]
fn test_cache_key_generation() {
let req1 = test_request("resources/read");
let req2 = test_request("resources/read");
let req3 = test_request("resources/list");
assert_eq!(Cache::cache_key(&req1), Cache::cache_key(&req2));
assert_ne!(Cache::cache_key(&req1), Cache::cache_key(&req3));
}
// End-to-end check of the middleware. The first call misses and reaches the counting mock.
// The second call goes through a fresh service sharing the same cache and must be answered
// from the cache; its inner service panics if it is ever invoked.
#[tokio::test]
async fn test_cache_service() {
use tower::ServiceExt;
let cache = Arc::new(Cache::default());
let call_count = Arc::new(AtomicU64::new(0));
let call_count_clone = Arc::clone(&call_count);
let mock_service = tower::service_fn(move |_req: McpRequest| {
let count = Arc::clone(&call_count_clone);
async move {
count.fetch_add(1, Ordering::Relaxed);
Ok::<_, McpError>(McpResponse::success(
json!({"result": "data"}),
Duration::from_millis(10),
))
}
});
let mut service = CacheLayer::with_cache(Arc::clone(&cache)).layer(mock_service);
let request = test_request("resources/list");
let response = service
.ready()
.await
.unwrap()
.call(request.clone())
.await
.unwrap();
assert!(response.is_success());
assert_eq!(call_count.load(Ordering::Relaxed), 1);
let mut service = CacheLayer::with_cache(Arc::clone(&cache)).layer(tower::service_fn(
|_req: McpRequest| async {
panic!("Inner service should not be called on cache hit");
// The Ok(...) below never runs; it exists only to give the async block a concrete
// Result return type, hence the unreachable_code allowance.
#[allow(unreachable_code)]
Ok::<_, McpError>(McpResponse::success(json!({}), Duration::ZERO))
},
));
let response = service.ready().await.unwrap().call(request).await.unwrap();
assert!(response.is_success());
assert_eq!(response.get_metadata("cache.hit"), Some(&json!(true)));
}
}