use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::RwLock;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// Default value for `ToolCacheConfig::max_entries`.
pub const DEFAULT_TOOL_CACHE_MAX_ENTRIES: usize = 1_000;
/// Default value for `ToolCacheConfig::default_ttl_secs` (30 minutes).
pub const DEFAULT_TOOL_CACHE_DEFAULT_TTL_SECS: u64 = 60 * 30;
/// Default value for `ToolCacheConfig::max_ttl_secs` (30 minutes).
pub const DEFAULT_TOOL_CACHE_MAX_TTL_SECS: u64 = 60 * 30;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolCacheConfig {
pub max_entries: usize,
pub default_ttl_secs: u64,
pub max_ttl_secs: u64,
}
impl Default for ToolCacheConfig {
fn default() -> Self {
Self {
max_entries: DEFAULT_TOOL_CACHE_MAX_ENTRIES,
default_ttl_secs: DEFAULT_TOOL_CACHE_DEFAULT_TTL_SECS,
max_ttl_secs: DEFAULT_TOOL_CACHE_MAX_TTL_SECS,
}
}
}
/// One cached value together with its owner and lifetime metadata.
#[derive(Clone, Debug)]
struct ToolCacheEntry {
    /// Tool that created the entry; `get`/`delete` must present the same name.
    tool_name: String,
    /// Cached payload; lookups hand out a clone of it.
    value: Value,
    /// Sequence number drawn from the cache counter; the smallest value is
    /// evicted first when the cache is over capacity.
    created_seq: u64,
    /// The entry counts as expired once `now >= expires_at`.
    expires_at: Instant,
}
/// Interior of [`SharedToolCache`] that lives behind its `RwLock`.
#[derive(Default)]
struct ToolCacheStore {
    /// Cache id -> entry.
    entries: HashMap<String, ToolCacheEntry>,
}
/// In-memory cache of JSON values (`serde_json::Value`) scoped by tool name,
/// with per-entry TTLs and a bounded entry count.
///
/// The entry map sits behind an `RwLock` and sequence numbers come from an
/// atomic counter, so one instance can be shared by reference across threads.
pub struct SharedToolCache {
    /// Entry map, guarded by a read/write lock.
    store: RwLock<ToolCacheStore>,
    /// Capacity and TTL limits.
    config: ToolCacheConfig,
    /// Sequence source for cache ids and eviction order; starts at 1.
    counter: AtomicU64,
}
impl SharedToolCache {
    /// Creates an empty cache that enforces the limits in `config`.
    pub fn new(config: ToolCacheConfig) -> Self {
        Self {
            store: RwLock::new(ToolCacheStore::default()),
            config,
            counter: AtomicU64::new(1),
        }
    }

    /// Stores `value` for `tool_name` and returns the newly generated cache id.
    ///
    /// `ttl_secs` defaults to `config.default_ttl_secs` and is clamped to
    /// `1..=max(config.max_ttl_secs, 1)` seconds. Expired entries are purged
    /// before inserting. Afterwards, the oldest entries are evicted until at
    /// most `max(config.max_entries, 1)` remain.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn create(&self, tool_name: &str, value: Value, ttl_secs: Option<u64>) -> String {
        let now = Instant::now();
        let ttl = self.resolve_ttl(ttl_secs);
        // One sequence number feeds both the id and the eviction order, so each
        // create advances the counter exactly once. It is drawn before taking
        // the lock, so under concurrency eviction order only approximates
        // insertion order.
        let seq = self.counter.fetch_add(1, Ordering::Relaxed);
        let cache_id = Self::format_cache_id(seq);
        let entry = ToolCacheEntry {
            tool_name: tool_name.to_string(),
            value,
            created_seq: seq,
            expires_at: Self::expiry_after(now, ttl),
        };
        let mut store = self.store.write().expect("tool cache poisoned");
        self.cleanup_expired_locked(&mut store, now);
        store.entries.insert(cache_id.clone(), entry);
        self.enforce_capacity_locked(&mut store);
        cache_id
    }

    /// Returns a clone of the value stored under `cache_id`. Returns `None` if
    /// the id is unknown, has expired, or belongs to a different tool.
    ///
    /// Lookups run under the read lock. The write lock is taken only when the
    /// entry has expired, and then all expired entries are purged.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get(&self, tool_name: &str, cache_id: &str) -> Option<Value> {
        let now = Instant::now();
        {
            let store = self.store.read().expect("tool cache poisoned");
            match store.entries.get(cache_id) {
                // An unknown id leaves nothing to purge, so don't contend on
                // the write lock.
                None => return None,
                Some(entry) if entry.expires_at > now => {
                    return (entry.tool_name == tool_name).then(|| entry.value.clone());
                }
                // Expired: release the read lock and purge below.
                Some(_) => {}
            }
        }
        let mut store = self.store.write().expect("tool cache poisoned");
        self.cleanup_expired_locked(&mut store, now);
        None
    }

    /// Removes the entry under `cache_id` if it belongs to `tool_name`.
    /// Returns whether an entry was removed.
    ///
    /// Expired entries are purged first, so deleting an expired id returns
    /// `false`.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn delete(&self, tool_name: &str, cache_id: &str) -> bool {
        let mut store = self.store.write().expect("tool cache poisoned");
        self.cleanup_expired_locked(&mut store, Instant::now());
        let owned_by_tool = store
            .entries
            .get(cache_id)
            .is_some_and(|entry| entry.tool_name == tool_name);
        if owned_by_tool {
            store.entries.remove(cache_id);
        }
        owned_by_tool
    }

    /// Resolves the requested TTL, or the configured default, to a duration.
    /// The result is at least 1 second and at most `max(config.max_ttl_secs, 1)`
    /// seconds.
    fn resolve_ttl(&self, ttl_secs: Option<u64>) -> Duration {
        let requested = ttl_secs.unwrap_or(self.config.default_ttl_secs);
        let upper = self.config.max_ttl_secs.max(1);
        Duration::from_secs(requested.clamp(1, upper))
    }

    /// Returns `now + ttl`, halving `ttl` until the sum is representable.
    ///
    /// `Instant + Duration` panics on overflow. Because the config is
    /// deserializable, an oversized `max_ttl_secs` could otherwise trigger that
    /// panic. The loop terminates because `now + 0` always succeeds.
    fn expiry_after(now: Instant, ttl: Duration) -> Instant {
        let mut ttl = ttl;
        loop {
            match now.checked_add(ttl) {
                Some(expires_at) => return expires_at,
                None => ttl /= 2,
            }
        }
    }

    /// Builds a cache id of the form `tc-<unix millis>-<seq>`.
    fn format_cache_id(seq: u64) -> String {
        let unix_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis();
        format!("tc-{}-{}", unix_ms, seq)
    }

    /// Drops every entry whose expiry is at or before `now`.
    /// The caller must already hold the write guard (`&mut ToolCacheStore`).
    fn cleanup_expired_locked(&self, store: &mut ToolCacheStore, now: Instant) {
        store.entries.retain(|_, entry| entry.expires_at > now);
    }

    /// Evicts entries with the smallest `created_seq` until the store holds at
    /// most `max(config.max_entries, 1)` entries.
    /// The caller must already hold the write guard.
    fn enforce_capacity_locked(&self, store: &mut ToolCacheStore) {
        let limit = self.config.max_entries.max(1);
        while store.entries.len() > limit {
            let Some(oldest_id) = store
                .entries
                .iter()
                .min_by_key(|(_, entry)| entry.created_seq)
                .map(|(cache_id, _)| cache_id.clone())
            else {
                break;
            };
            store.entries.remove(&oldest_id);
        }
    }
}
/// Process-wide cache instance. It is initialized by whichever runs first:
/// [`configure_global_tool_cache`] (with the given config) or
/// [`global_tool_cache`] (with `ToolCacheConfig::default()`).
static GLOBAL_TOOL_CACHE: OnceLock<Arc<SharedToolCache>> = OnceLock::new();
/// Installs the process-wide tool cache built from `config`.
///
/// Only the first initialization takes effect. If the global cache already
/// exists (from an earlier call here or from [`global_tool_cache`]), `config`
/// is discarded and the existing instance is kept.
pub fn configure_global_tool_cache(config: ToolCacheConfig) {
    let cache = Arc::new(SharedToolCache::new(config));
    // `set` fails when the cell is already populated. In that case the rejected
    // instance is simply dropped and the existing one stays in place.
    if let Err(rejected) = GLOBAL_TOOL_CACHE.set(cache) {
        drop(rejected);
    }
}
/// Returns a shared handle to the process-wide tool cache.
///
/// If [`configure_global_tool_cache`] has not run yet, the cache is created
/// here with `ToolCacheConfig::default()`.
pub fn global_tool_cache() -> Arc<SharedToolCache> {
    let cache = GLOBAL_TOOL_CACHE.get_or_init(|| {
        let config = ToolCacheConfig::default();
        Arc::new(SharedToolCache::new(config))
    });
    Arc::clone(cache)
}
#[cfg(test)]
mod tests {
    use super::{SharedToolCache, ToolCacheConfig};
    use serde_json::json;
    use std::thread;
    use std::time::Duration;

    /// Builds a config with explicit limits for a single test.
    fn test_cache_config(
        max_entries: usize,
        default_ttl_secs: u64,
        max_ttl_secs: u64,
    ) -> ToolCacheConfig {
        ToolCacheConfig {
            max_entries,
            default_ttl_secs,
            max_ttl_secs,
        }
    }

    #[test]
    fn cache_entries_are_isolated_by_tool_name() {
        let cache = SharedToolCache::new(test_cache_config(10, 5, 5));
        let cache_id = cache.create("skill-a", json!({"value": 1}), None);
        assert_eq!(cache.get("skill-a", &cache_id), Some(json!({"value": 1})));
        assert_eq!(cache.get("skill-b", &cache_id), None);
    }

    #[test]
    fn cache_entries_expire_after_default_ttl() {
        let cache = SharedToolCache::new(test_cache_config(10, 1, 1));
        let cache_id = cache.create("skill-a", json!({"value": 1}), None);
        thread::sleep(Duration::from_millis(1100));
        assert_eq!(cache.get("skill-a", &cache_id), None);
    }

    #[test]
    fn cache_requested_ttl_is_clamped_to_maximum() {
        let cache = SharedToolCache::new(test_cache_config(10, 5, 1));
        let cache_id = cache.create("skill-a", json!({"value": 1}), Some(60));
        thread::sleep(Duration::from_millis(1100));
        assert_eq!(cache.get("skill-a", &cache_id), None);
    }

    #[test]
    fn cache_requested_ttl_of_zero_is_raised_to_one_second() {
        let cache = SharedToolCache::new(test_cache_config(10, 5, 5));
        let cache_id = cache.create("skill-a", json!({"value": 1}), Some(0));
        assert_eq!(cache.get("skill-a", &cache_id), Some(json!({"value": 1})));
    }

    #[test]
    fn cache_evicts_oldest_entry_when_capacity_is_exceeded() {
        let cache = SharedToolCache::new(test_cache_config(2, 5, 5));
        let first_id = cache.create("skill-a", json!({"value": 1}), None);
        let second_id = cache.create("skill-a", json!({"value": 2}), None);
        let third_id = cache.create("skill-a", json!({"value": 3}), None);
        assert_eq!(cache.get("skill-a", &first_id), None);
        assert_eq!(cache.get("skill-a", &second_id), Some(json!({"value": 2})));
        assert_eq!(cache.get("skill-a", &third_id), Some(json!({"value": 3})));
    }

    #[test]
    fn cache_with_zero_capacity_still_keeps_latest_entry() {
        let cache = SharedToolCache::new(test_cache_config(0, 5, 5));
        let first_id = cache.create("skill-a", json!({"value": 1}), None);
        let second_id = cache.create("skill-a", json!({"value": 2}), None);
        assert_eq!(cache.get("skill-a", &first_id), None);
        assert_eq!(cache.get("skill-a", &second_id), Some(json!({"value": 2})));
    }

    #[test]
    fn cache_ids_are_unique_per_create() {
        let cache = SharedToolCache::new(test_cache_config(10, 5, 5));
        let first_id = cache.create("skill-a", json!(1), None);
        let second_id = cache.create("skill-a", json!(1), None);
        assert_ne!(first_id, second_id);
        assert!(first_id.starts_with("tc-"));
    }

    #[test]
    fn cache_get_with_unknown_id_returns_none() {
        let cache = SharedToolCache::new(test_cache_config(10, 5, 5));
        assert_eq!(cache.get("skill-a", "tc-0-0"), None);
    }

    #[test]
    fn cache_delete_requires_matching_tool_name() {
        let cache = SharedToolCache::new(test_cache_config(10, 5, 5));
        let cache_id = cache.create("skill-a", json!({"value": 1}), None);
        assert!(!cache.delete("skill-b", &cache_id));
        assert_eq!(cache.get("skill-a", &cache_id), Some(json!({"value": 1})));
        assert!(cache.delete("skill-a", &cache_id));
        assert_eq!(cache.get("skill-a", &cache_id), None);
        assert!(!cache.delete("skill-a", &cache_id));
    }
}