use std::collections::HashMap;
use std::future::Future;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use everruns_core::ToolDefinition;
use tokio::sync::Mutex as AsyncMutex;
use uuid::Uuid;
/// Freshness window that [`McpDiscoveryCache::new`] applies to cached tool lists: one hour.
pub(crate) const TOOL_CACHE_TTL: Duration = Duration::from_secs(60 * 60);
/// Key under which a tool list is cached and under which its refresh lock is looked up.
// NOTE(review): the tuple looks like an owner/scope id plus a server name
// (tests use `(Uuid::nil(), "docs")`) — confirm against the callers.
type CacheKey = (Uuid, String);
/// A cached tool list together with the instant it was stored.
struct CacheEntry {
/// Tool definitions written by the most recent store for this key.
tools: Vec<ToolDefinition>,
/// When the entry was stored; its age is compared against the cache TTL.
cached_at: Instant,
}
/// Outcome of looking a key up in the cache, relative to the cache TTL.
enum Freshness {
/// An entry exists and its age is below the TTL; carries a clone of its tools.
Fresh(Vec<ToolDefinition>),
/// An entry exists but its age has reached the TTL; carries a clone of its tools.
Stale(Vec<ToolDefinition>),
/// No entry exists for the key.
Cold,
}
/// Table of per-key async locks, handed out by `lock_for`.
#[derive(Default)]
struct KeyedLocks {
/// One shared async mutex per key. The table itself sits behind a std mutex,
/// which `lock_for` holds only while it looks up or inserts a lock.
map: Mutex<HashMap<CacheKey, Arc<AsyncMutex<()>>>>,
}
impl KeyedLocks {
    /// Returns the shared async lock for `key`, creating it on first use.
    ///
    /// Every call also prunes the table: locks belonging to *other* keys are
    /// dropped when the table holds the only reference to them (strong count
    /// of 1), so the table does not keep growing with keys nobody is using.
    /// A poisoned table mutex is recovered instead of panicking.
    fn lock_for(&self, key: &CacheKey) -> Arc<AsyncMutex<()>> {
        let mut table = match self.map.lock() {
            Ok(table) => table,
            Err(poisoned) => poisoned.into_inner(),
        };
        // Keep the requested key unconditionally; keep any other key only
        // while something outside the table still references its lock.
        table.retain(|candidate, shared| candidate == key || Arc::strong_count(shared) > 1);
        let shared = table.entry(key.clone()).or_default();
        Arc::clone(shared)
    }
}
/// TTL-based cache of tool definitions keyed by `CacheKey`, with a per-key
/// lock so that refreshes of the same key do not run concurrently.
pub(crate) struct McpDiscoveryCache {
/// Cached tool lists and the instants at which they were stored.
entries: Mutex<HashMap<CacheKey, CacheEntry>>,
/// Per-key async locks held for the duration of a refresh.
locks: KeyedLocks,
/// Age below which an entry is classified as fresh.
ttl: Duration,
}
impl McpDiscoveryCache {
    /// Creates an empty cache whose entries stay fresh for [`TOOL_CACHE_TTL`].
    pub(crate) fn new() -> Self {
        Self::with_ttl(TOOL_CACHE_TTL)
    }

    /// Creates an empty cache with an explicit freshness window.
    fn with_ttl(ttl: Duration) -> Self {
        Self {
            entries: Mutex::new(HashMap::new()),
            locks: KeyedLocks::default(),
            ttl,
        }
    }

    /// Locks the entry table.
    ///
    /// A poisoned mutex is recovered rather than propagated, so a panic in a
    /// previous holder does not make the cache unusable.
    fn entries(&self) -> std::sync::MutexGuard<'_, HashMap<CacheKey, CacheEntry>> {
        self.entries.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Reports whether `key` has an entry younger than the TTL at `now`.
    ///
    /// Unlike [`Self::classify`] this does not clone the cached tools.
    fn is_fresh(&self, key: &CacheKey, now: Instant) -> bool {
        self.entries()
            .get(key)
            .is_some_and(|entry| now.duration_since(entry.cached_at) < self.ttl)
    }

    /// Looks `key` up and classifies it against the TTL as of `now`.
    ///
    /// Returns `Cold` when there is no entry, `Fresh` when the entry's age is
    /// strictly below the TTL, and `Stale` otherwise. The cached tools are
    /// cloned so the table lock is released before the caller uses them.
    fn classify(&self, key: &CacheKey, now: Instant) -> Freshness {
        match self.entries().get(key) {
            None => Freshness::Cold,
            Some(entry) if now.duration_since(entry.cached_at) < self.ttl => {
                Freshness::Fresh(entry.tools.clone())
            }
            Some(entry) => Freshness::Stale(entry.tools.clone()),
        }
    }

    /// Inserts or replaces the entry for `key`, stamping it with `now`.
    fn store(&self, key: CacheKey, tools: Vec<ToolDefinition>, now: Instant) {
        self.entries().insert(
            key,
            CacheEntry {
                tools,
                cached_at: now,
            },
        );
    }

    /// Returns the tools for `key`, refreshing them as needed.
    ///
    /// * Fresh entry: returned as-is; `refresh` is not called.
    /// * Stale entry: returned immediately, and a background task is spawned
    ///   to refresh it.
    /// * No entry: `refresh` is awaited inline, serialized per key so
    ///   concurrent cold callers share one fetch.
    ///
    /// `refresh` yields `Some(tools)` on success and `None` on failure. A
    /// failed cold refresh returns an empty list and stores nothing.
    pub(crate) async fn resolve<F, Fut>(
        self: &Arc<Self>,
        key: CacheKey,
        refresh: F,
    ) -> Vec<ToolDefinition>
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Option<Vec<ToolDefinition>>> + Send + 'static,
    {
        match self.classify(&key, Instant::now()) {
            Freshness::Fresh(tools) => tools,
            Freshness::Stale(tools) => {
                self.spawn_background_refresh(key, refresh);
                tools
            }
            Freshness::Cold => self.refresh_coalesced(key, refresh).await,
        }
    }

    /// Refreshes `key` inline while holding its per-key lock.
    ///
    /// Callers queue on the lock. Once a caller acquires it, the entry is
    /// re-checked so that callers who waited behind a successful refresh
    /// return the now-fresh value without fetching again.
    ///
    /// When `refresh` yields `None`, whatever entry exists (fresh or stale) is
    /// returned, or an empty list if there is none.
    async fn refresh_coalesced<F, Fut>(
        self: &Arc<Self>,
        key: CacheKey,
        refresh: F,
    ) -> Vec<ToolDefinition>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Option<Vec<ToolDefinition>>>,
    {
        let lock = self.locks.lock_for(&key);
        // Held until this function returns, i.e. across the fetch and the store.
        let _guard = lock.lock_owned().await;
        if let Freshness::Fresh(tools) = self.classify(&key, Instant::now()) {
            return tools;
        }
        match refresh().await {
            Some(tools) => {
                self.store(key, tools.clone(), Instant::now());
                tools
            }
            None => match self.classify(&key, Instant::now()) {
                Freshness::Fresh(tools) | Freshness::Stale(tools) => tools,
                Freshness::Cold => Vec::new(),
            },
        }
    }

    /// Spawns a Tokio task that refreshes `key` and stores the result.
    ///
    /// Does nothing when the key's lock is already held, because another
    /// refresh of the same key is then in flight. A failed refresh (`None`)
    /// leaves the existing entry untouched.
    ///
    /// Uses `tokio::spawn`, so it must be called from within a Tokio runtime.
    fn spawn_background_refresh<F, Fut>(self: &Arc<Self>, key: CacheKey, refresh: F)
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Option<Vec<ToolDefinition>>> + Send + 'static,
    {
        let lock = self.locks.lock_for(&key);
        let Ok(guard) = lock.try_lock_owned() else {
            return;
        };
        // The caller classified the entry as stale *before* taking the lock. A
        // refresh that finished in between may already have stored a fresh
        // entry; re-check under the lock (as `refresh_coalesced` does) so that
        // window does not trigger a redundant fetch.
        if self.is_fresh(&key, Instant::now()) {
            return;
        }
        let cache = Arc::clone(self);
        tokio::spawn(async move {
            // Keep the per-key lock for the whole fetch-and-store sequence.
            let _guard = guard;
            if let Some(tools) = refresh().await {
                cache.store(key, tools, Instant::now());
            }
        });
    }
}
// Unit tests covering the cold, failed-cold, concurrent-cold and stale paths of `resolve`.
#[cfg(test)]
mod tests {
use super::*;
use everruns_core::tool_types::{BuiltinTool, DeferrablePolicy, ToolHints, ToolPolicy};
use std::sync::atomic::{AtomicUsize, Ordering};
/// Builds a builtin tool definition named `name`; every other field is empty or defaulted.
fn def(name: &str) -> ToolDefinition {
ToolDefinition::Builtin(BuiltinTool {
name: name.to_string(),
display_name: None,
description: String::new(),
parameters: serde_json::json!({}),
policy: ToolPolicy::default(),
category: None,
deferrable: DeferrablePolicy::default(),
hints: ToolHints::default(),
full_parameters: None,
})
}
/// Collects the tool names in order, for compact assertions.
fn names(tools: &[ToolDefinition]) -> Vec<String> {
tools.iter().map(|t| t.name().to_string()).collect()
}
/// The single cache key shared by every test in this module.
fn key() -> CacheKey {
(Uuid::nil(), "docs".to_string())
}
/// A cold lookup runs the refresh once; the following lookup (default TTL)
/// is served from the cache without calling the refresh again.
#[tokio::test]
async fn cold_fetches_then_serves_from_cache() {
let cache = Arc::new(McpDiscoveryCache::new());
// Counts how many times the refresh future actually ran.
let calls = Arc::new(AtomicUsize::new(0));
let refresh = {
let calls = calls.clone();
move || {
let calls = calls.clone();
async move {
calls.fetch_add(1, Ordering::SeqCst);
Some(vec![def("docs__search")])
}
}
};
let first = cache.resolve(key(), refresh.clone()).await;
assert_eq!(names(&first), ["docs__search"]);
assert_eq!(calls.load(Ordering::SeqCst), 1);
let second = cache.resolve(key(), refresh).await;
assert_eq!(names(&second), ["docs__search"]);
// Still 1: the second lookup hit the fresh entry.
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
/// A cold refresh that yields `None` returns an empty list and leaves the key cold.
#[tokio::test]
async fn cold_failure_returns_empty_and_does_not_cache() {
let cache = Arc::new(McpDiscoveryCache::new());
let out = cache
.resolve(key(), || async { None::<Vec<ToolDefinition>> })
.await;
assert!(out.is_empty());
assert!(matches!(
cache.classify(&key(), Instant::now()),
Freshness::Cold
));
}
/// Sixteen concurrent cold lookups of the same key must run the refresh exactly once.
#[tokio::test]
async fn concurrent_cold_callers_are_single_flight() {
let cache = Arc::new(McpDiscoveryCache::new());
let calls = Arc::new(AtomicUsize::new(0));
let mut handles = Vec::new();
for _ in 0..16 {
let cache = cache.clone();
let calls = calls.clone();
handles.push(tokio::spawn(async move {
cache
.resolve(key(), move || {
let calls = calls.clone();
async move {
// Yield mid-refresh so the other spawned callers get a chance to run
// and queue up while this fetch is still in progress.
tokio::task::yield_now().await;
calls.fetch_add(1, Ordering::SeqCst);
Some(vec![def("docs__search")])
}
})
.await
}));
}
// Every caller must observe the fetched tools, whether it fetched or waited.
for h in handles {
assert_eq!(names(&h.await.unwrap()), ["docs__search"]);
}
assert_eq!(
calls.load(Ordering::SeqCst),
1,
"single-flight must fetch once"
);
}
/// A stale entry is returned immediately while a background task replaces it.
#[tokio::test]
async fn stale_serves_cached_then_revalidates_in_background() {
// Zero TTL: no age is strictly below it, so every stored entry classifies as stale.
let cache = Arc::new(McpDiscoveryCache::with_ttl(Duration::ZERO));
cache.store(key(), vec![def("v1")], Instant::now());
let calls = Arc::new(AtomicUsize::new(0));
let stale = cache
.resolve(key(), {
let calls = calls.clone();
move || {
let calls = calls.clone();
async move {
calls.fetch_add(1, Ordering::SeqCst);
Some(vec![def("v2")])
}
}
})
.await;
// The caller gets the old value; the refresh result only lands in the cache.
assert_eq!(names(&stale), ["v1"]);
let mut updated = false;
// Poll, yielding so the spawned refresh task can run. The loop is bounded so a
// missing update fails the assertion below instead of hanging the test.
for _ in 0..200 {
tokio::task::yield_now().await;
if let Freshness::Stale(tools) | Freshness::Fresh(tools) =
cache.classify(&key(), Instant::now())
&& names(&tools) == ["v2"]
{
updated = true;
break;
}
}
assert!(updated, "background refresh must replace the stale entry");
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
}