use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::compression::engine::Tool;
use crate::Error;
pub trait ToolBackend: Send + Sync {
fn list_tools(&self) -> impl std::future::Future<Output = Result<Vec<Tool>, Error>> + Send;
}
/// Caches the filtered tool list returned by a [`ToolBackend`].
///
/// The list is fetched on first use and reused until [`ToolCache::invalidate`] bumps
/// `generation`. A cached entry is served only while its recorded generation equals the
/// current one.
pub struct ToolCache<B: ToolBackend> {
    /// Backend that supplies the (unfiltered) tool list.
    backend: B,
    /// Last stored, already-filtered list, tagged with the generation it was stored for.
    cache: Arc<RwLock<Option<CachedTools>>>,
    /// Lock-free flag read by `is_populated`; cleared by `invalidate`.
    populated: Arc<AtomicBool>,
    /// Incremented by `invalidate`; compared against `CachedTools::generation`.
    generation: Arc<AtomicU64>,
    /// If `Some`, only tools whose names are listed are kept.
    include: Option<Vec<String>>,
    /// If `Some`, tools whose names are listed are dropped.
    exclude: Option<Vec<String>>,
}
/// A filtered tool list together with the cache generation it belongs to.
#[derive(Debug, Clone)]
struct CachedTools {
    /// Generation this entry is valid for; it is served only while `ToolCache::generation`
    /// still equals this value.
    generation: u64,
    /// Tools remaining after include/exclude filtering.
    tools: Vec<Tool>,
}
impl<B: ToolBackend> ToolCache<B> {
pub fn new(
backend: B,
include: Option<Vec<String>>,
exclude: Option<Vec<String>>,
) -> Self {
Self {
backend,
cache: Arc::new(RwLock::new(None)),
populated: Arc::new(AtomicBool::new(false)),
generation: Arc::new(AtomicU64::new(0)),
include,
exclude,
}
}
pub fn is_populated(&self) -> bool {
self.populated.load(Ordering::SeqCst)
}
pub async fn get_all(&self) -> Result<Vec<Tool>, Error> {
let current_generation = self.generation.load(Ordering::SeqCst);
if let Some(cached) = self.cache.read().await.as_ref() {
if cached.generation == current_generation {
return Ok(cached.tools.clone());
}
}
let mut cache = self.cache.write().await;
let current_generation = self.generation.load(Ordering::SeqCst);
if let Some(cached) = cache.as_ref() {
if cached.generation == current_generation {
return Ok(cached.tools.clone());
}
}
let tools = self.fetch_filtered().await?;
*cache = Some(CachedTools {
generation: current_generation,
tools: tools.clone(),
});
self.populated.store(true, Ordering::SeqCst);
Ok(tools)
}
pub async fn get(&self, name: &str) -> Result<Option<Tool>, Error> {
Ok(self.get_all().await?.into_iter().find(|tool| tool.name == name))
}
pub async fn refresh(&self) -> Result<(), Error> {
let tools = self.fetch_filtered().await?;
let generation = self.generation.load(Ordering::SeqCst);
*self.cache.write().await = Some(CachedTools { generation, tools });
self.populated.store(true, Ordering::SeqCst);
Ok(())
}
pub fn invalidate(&self) {
self.generation.fetch_add(1, Ordering::SeqCst);
self.populated.store(false, Ordering::SeqCst);
}
async fn fetch_filtered(&self) -> Result<Vec<Tool>, Error> {
Ok(apply_filters(
self.backend.list_tools().await?,
self.include.as_deref(),
self.exclude.as_deref(),
))
}
}
/// Filters `tools` by name using an optional allow-list (`include`) and deny-list (`exclude`).
///
/// A `None` list disables that filter. `Some` with an empty `include` keeps nothing. A tool
/// survives only if it is allowed and not denied. Surviving tools keep their original
/// relative order.
fn apply_filters(
    tools: Vec<Tool>,
    include: Option<&[String]>,
    exclude: Option<&[String]>,
) -> Vec<Tool> {
    fn name_set(names: &[String]) -> HashSet<&str> {
        names.iter().map(String::as_str).collect()
    }

    let allowed = include.map(name_set);
    let denied = exclude.map(name_set);

    let mut kept = tools;
    kept.retain(|tool| {
        let name: &str = &tool.name;
        let admitted = match &allowed {
            Some(set) => set.contains(name),
            None => true,
        };
        let blocked = match &denied {
            Some(set) => set.contains(name),
            None => false,
        };
        admitted && !blocked
    });
    kept
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Backend that returns a fixed tool list and counts `list_tools` calls.
    ///
    /// Clones share the counter, so a test can keep a clone as a probe after moving the
    /// original into the cache.
    #[derive(Clone)]
    struct MockBackend {
        tools: Vec<Tool>,
        call_count: Arc<AtomicU32>,
    }

    impl MockBackend {
        fn new(tools: Vec<Tool>) -> Self {
            Self { tools, call_count: Arc::new(AtomicU32::new(0)) }
        }

        /// Number of `list_tools` calls made on this backend or any of its clones.
        fn call_count(&self) -> u32 {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl ToolBackend for MockBackend {
        async fn list_tools(&self) -> Result<Vec<Tool>, Error> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(self.tools.clone())
        }
    }

    fn make_tool(name: &str) -> Tool {
        Tool::new(name, None::<String>, json!({ "type": "object", "properties": {} }))
    }

    /// Builds a cache over a `MockBackend` and returns it with a probe that shares the
    /// backend's call counter.
    fn cache_with(
        tools: Vec<Tool>,
        include: Option<Vec<String>>,
        exclude: Option<Vec<String>>,
    ) -> (ToolCache<MockBackend>, MockBackend) {
        let backend = MockBackend::new(tools);
        let probe = backend.clone();
        (ToolCache::new(backend, include, exclude), probe)
    }

    #[tokio::test]
    async fn new_cache_is_not_populated() {
        let (cache, _) = cache_with(vec![], None, None);
        assert!(!cache.is_populated());
    }

    #[tokio::test]
    async fn get_all_fetches_from_backend_on_first_call() {
        let (cache, probe) = cache_with(vec![make_tool("fetch")], None, None);
        cache.get_all().await.unwrap();
        assert_eq!(probe.call_count(), 1);
    }

    #[tokio::test]
    async fn get_all_returns_expected_tools() {
        let (cache, _) = cache_with(vec![make_tool("fetch"), make_tool("search")], None, None);
        let tools = cache.get_all().await.unwrap();
        assert_eq!(tools.len(), 2);
        let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"fetch"));
        assert!(names.contains(&"search"));
    }

    #[tokio::test]
    async fn cache_is_populated_after_first_get_all() {
        let (cache, _) = cache_with(vec![make_tool("fetch")], None, None);
        cache.get_all().await.unwrap();
        assert!(cache.is_populated());
    }

    #[tokio::test]
    async fn get_all_uses_cache_on_subsequent_calls() {
        let (cache, probe) = cache_with(vec![make_tool("fetch")], None, None);
        cache.get_all().await.unwrap();
        cache.get_all().await.unwrap();
        cache.get_all().await.unwrap();
        assert_eq!(probe.call_count(), 1, "backend called more than once");
    }

    #[tokio::test]
    async fn get_returns_some_for_known_tool() {
        let (cache, _) = cache_with(vec![make_tool("fetch")], None, None);
        let tool = cache.get("fetch").await.unwrap();
        assert!(tool.is_some());
        assert_eq!(tool.unwrap().name, "fetch");
    }

    #[tokio::test]
    async fn get_returns_none_for_unknown_tool() {
        let (cache, _) = cache_with(vec![make_tool("fetch")], None, None);
        let tool = cache.get("nonexistent").await.unwrap();
        assert!(tool.is_none());
    }

    #[tokio::test]
    async fn refresh_forces_re_fetch() {
        let (cache, probe) = cache_with(vec![make_tool("fetch")], None, None);
        cache.get_all().await.unwrap();
        cache.refresh().await.unwrap();
        assert_eq!(probe.call_count(), 2, "expected 2 backend calls after refresh");
    }

    #[tokio::test]
    async fn refresh_on_empty_cache_populates_it() {
        let (cache, probe) = cache_with(vec![make_tool("fetch")], None, None);
        cache.refresh().await.unwrap();
        assert!(cache.is_populated());
        cache.get_all().await.unwrap();
        assert_eq!(probe.call_count(), 1, "get_all should reuse the refreshed list");
    }

    #[tokio::test]
    async fn invalidate_clears_cache() {
        let (cache, probe) = cache_with(vec![make_tool("fetch")], None, None);
        cache.get_all().await.unwrap();
        cache.invalidate();
        assert!(!cache.is_populated());
        cache.get_all().await.unwrap();
        assert_eq!(probe.call_count(), 2);
    }

    #[tokio::test]
    async fn refresh_after_invalidate_repopulates() {
        let (cache, probe) = cache_with(vec![make_tool("fetch")], None, None);
        cache.get_all().await.unwrap();
        cache.invalidate();
        cache.refresh().await.unwrap();
        assert!(cache.is_populated());
        cache.get_all().await.unwrap();
        assert_eq!(probe.call_count(), 2, "get_all should reuse the refreshed list");
    }

    #[tokio::test]
    async fn include_filter_keeps_only_named_tools() {
        let (cache, _) = cache_with(
            vec![make_tool("fetch"), make_tool("search"), make_tool("upload")],
            Some(vec!["fetch".into()]),
            None,
        );
        let tools = cache.get_all().await.unwrap();
        assert_eq!(tools.len(), 1, "expected only 'fetch'");
        assert_eq!(tools[0].name, "fetch");
    }

    #[tokio::test]
    async fn exclude_filter_removes_named_tools() {
        let (cache, _) = cache_with(
            vec![make_tool("fetch"), make_tool("search"), make_tool("upload")],
            None,
            Some(vec!["search".into()]),
        );
        let tools = cache.get_all().await.unwrap();
        assert_eq!(tools.len(), 2, "expected 'fetch' and 'upload'");
        assert!(tools.iter().all(|t| t.name != "search"));
    }

    #[tokio::test]
    async fn include_then_exclude_applied_in_order() {
        let (cache, _) = cache_with(
            vec![make_tool("fetch"), make_tool("search"), make_tool("upload")],
            Some(vec!["fetch".into(), "search".into()]),
            Some(vec!["search".into()]),
        );
        let tools = cache.get_all().await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "fetch");
    }

    #[tokio::test]
    async fn include_filter_no_matches_yields_empty() {
        let (cache, _) =
            cache_with(vec![make_tool("fetch")], Some(vec!["nonexistent".into()]), None);
        let tools = cache.get_all().await.unwrap();
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn exclude_filter_all_tools_yields_empty() {
        let (cache, _) = cache_with(
            vec![make_tool("fetch"), make_tool("search")],
            None,
            Some(vec!["fetch".into(), "search".into()]),
        );
        let tools = cache.get_all().await.unwrap();
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn empty_backend_yields_empty_list() {
        let (cache, _) = cache_with(vec![], None, None);
        let tools = cache.get_all().await.unwrap();
        assert!(tools.is_empty());
    }
}