use lru::LruCache;
use rustc_hash::FxHashMap;
use std::num::NonZeroUsize;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tracing::error;
use super::McpToolInfo;
use super::tool_discovery::DetailLevel;
/// A simple Bloom filter over string keys.
///
/// `contains` may return false positives but never false negatives for
/// inserted items. Backed by a `Vec<bool>` (one byte per logical bit —
/// simple, not space-optimal).
#[derive(Clone)]
pub struct BloomFilter {
    /// Bit array; `true` means at least one inserted item hashed here.
    bits: Vec<bool>,
    /// Number of hash functions applied per item.
    num_hashes: usize,
    /// Length of `bits`; hashes are reduced modulo this value.
    size: usize,
}

impl BloomFilter {
    /// Creates a filter sized for `expected_items` at the target
    /// `false_positive_rate` (e.g. 0.01 for ~1%).
    ///
    /// Both the bit-array size and the hash count are clamped to at
    /// least 1 so a degenerate configuration (zero expected items)
    /// cannot cause a divide-by-zero in `insert`/`contains`.
    pub fn new(expected_items: usize, false_positive_rate: f64) -> Self {
        let size = Self::optimal_size(expected_items, false_positive_rate);
        let num_hashes = Self::optimal_num_hashes(size, expected_items);
        Self {
            bits: vec![false; size],
            num_hashes,
            size,
        }
    }

    /// Marks `item` as present by setting `num_hashes` bit positions.
    pub fn insert(&mut self, item: &str) {
        for i in 0..self.num_hashes {
            let index = self.hash(item, i) % self.size;
            self.bits[index] = true;
        }
    }

    /// Returns `false` if `item` was definitely never inserted, `true`
    /// if it *might* have been (subject to the false-positive rate).
    pub fn contains(&self, item: &str) -> bool {
        (0..self.num_hashes).all(|i| self.bits[self.hash(item, i) % self.size])
    }

    /// Resets the filter to empty without reallocating.
    pub fn clear(&mut self) {
        self.bits.fill(false);
    }

    /// Optimal bit count: m = -n * ln(p) / (ln 2)^2, clamped to >= 1 so
    /// the modulus in `insert`/`contains` is never zero (n == 0 would
    /// otherwise yield size 0 and a panic on first use).
    fn optimal_size(expected_items: usize, false_positive_rate: f64) -> usize {
        let size = -(expected_items as f64 * false_positive_rate.ln() / (2.0_f64.ln().powi(2)));
        (size.ceil() as usize).max(1)
    }

    /// Optimal hash count: k = (m / n) * ln 2, clamped to >= 1. The
    /// n == 0 case is handled explicitly; the division would otherwise
    /// produce an infinite ratio.
    fn optimal_num_hashes(size: usize, expected_items: usize) -> usize {
        if expected_items == 0 {
            return 1;
        }
        let num_hashes = (size as f64 / expected_items as f64) * 2.0_f64.ln();
        (num_hashes.ceil() as usize).max(1)
    }

    /// Derives the i-th hash for `item` by folding `seed` into a single
    /// `DefaultHasher` alongside the item (seeded double-hashing).
    fn hash(&self, item: &str, seed: usize) -> usize {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        item.hash(&mut hasher);
        seed.hash(&mut hasher);
        hasher.finish() as usize
    }
}
/// Key for one cached discovery query: a (provider, keyword, detail
/// level) triple. Distinct detail levels cache independently.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ToolDiscoveryCacheKey {
// Name of the provider the query was issued against.
provider_name: String,
// Search keyword exactly as supplied by the caller (not normalized here).
keyword: String,
// Requested verbosity of the results.
detail_level: DetailLevel,
}
/// LRU-cache payload: shared search results plus their insertion time,
/// used by `get_cached_discovery` to enforce `max_age` expiry.
#[derive(Clone)]
struct CachedToolDiscoveryEntry {
// Shared so cache hits hand out an `Arc` clone instead of copying.
results: Arc<Vec<ToolDiscoveryResult>>,
// When the entry was stored; compared against `CacheConfig::max_age`.
timestamp: Instant,
}
/// All mutable cache state, guarded together by a single `RwLock`.
struct DiscoveryCacheInner {
// Probabilistic index of known tool names (rebuilt on provider refresh).
bloom_filter: BloomFilter,
// LRU of per-query search results, keyed by provider/keyword/detail level.
detailed_cache: LruCache<ToolDiscoveryCacheKey, CachedToolDiscoveryEntry>,
// Full tool list per provider name.
all_tools_cache: FxHashMap<String, Vec<McpToolInfo>>,
// Last time each provider's full list was cached, for staleness checks.
last_refresh: FxHashMap<String, Instant>,
}
/// A single search hit: the matched tool, its relevance score, and the
/// detail level the search was performed at.
#[derive(Debug, Clone)]
pub struct ToolDiscoveryResult {
/// The matched tool's metadata.
pub tool: McpToolInfo,
/// Relevance score; `calculate_relevance` caps it at 1.0.
pub relevance_score: f64,
/// Detail level the search was requested at.
pub detail_level: DetailLevel,
}
/// Thread-safe cache for MCP tool discovery: per-query LRU results,
/// per-provider full tool lists, and a Bloom filter over tool names.
pub struct ToolDiscoveryCache {
// Shared, lock-guarded state.
inner: Arc<RwLock<DiscoveryCacheInner>>,
// Tuning knobs fixed at construction.
config: CacheConfig,
}
/// Tuning parameters fixed at cache construction (see
/// `ToolDiscoveryCache::new` for the defaults).
#[derive(Clone)]
struct CacheConfig {
// How long a cached detailed query result stays servable.
max_age: Duration,
// How long a provider's full tool list stays fresh.
provider_refresh_interval: Duration,
// Expected number of distinct tools, for Bloom filter sizing.
expected_tool_count: usize,
// Target Bloom filter false-positive rate.
false_positive_rate: f64,
}
impl ToolDiscoveryCache {
pub fn new(capacity: usize) -> Self {
let config = CacheConfig {
max_age: Duration::from_secs(300), provider_refresh_interval: Duration::from_secs(60), expected_tool_count: 1000,
false_positive_rate: 0.01, };
let bloom_filter = BloomFilter::new(config.expected_tool_count, config.false_positive_rate);
let cache_size = NonZeroUsize::new(capacity).or(NonZeroUsize::new(100));
Self {
inner: Arc::new(RwLock::new(DiscoveryCacheInner {
bloom_filter,
detailed_cache: LruCache::new(cache_size.unwrap_or(NonZeroUsize::MIN)),
all_tools_cache: FxHashMap::default(),
last_refresh: FxHashMap::default(),
})),
config,
}
}
pub fn might_have_tool(&self, tool_name: &str) -> bool {
match self.inner.read() {
Ok(inner) => inner.bloom_filter.contains(tool_name),
Err(_) => {
tracing::warn!("Bloom filter lock poisoned, assuming tool might exist");
true
}
}
}
pub fn get_cached_discovery(
&self,
provider_name: &str,
keyword: &str,
detail_level: DetailLevel,
) -> Option<Arc<Vec<ToolDiscoveryResult>>> {
let key = ToolDiscoveryCacheKey {
provider_name: provider_name.to_owned(),
keyword: keyword.to_owned(),
detail_level,
};
let mut inner = match self.inner.write() {
Ok(inner) => inner,
Err(e) => {
tracing::error!("Detailed cache lock poisoned: {}", e);
return None;
}
};
if let Some(cached) = inner.detailed_cache.get(&key) {
if cached.timestamp.elapsed() < self.config.max_age {
return Some(Arc::clone(&cached.results));
} else {
inner.detailed_cache.pop(&key);
}
}
None
}
pub fn cache_discovery(
&self,
provider_name: &str,
keyword: &str,
detail_level: DetailLevel,
results: Vec<ToolDiscoveryResult>,
) {
self.cache_discovery_shared(provider_name, keyword, detail_level, Arc::new(results));
}
fn cache_discovery_shared(
&self,
provider_name: &str,
keyword: &str,
detail_level: DetailLevel,
results: Arc<Vec<ToolDiscoveryResult>>,
) {
let key = ToolDiscoveryCacheKey {
provider_name: provider_name.to_owned(),
keyword: keyword.to_owned(),
detail_level,
};
let cached = CachedToolDiscoveryEntry {
results: Arc::clone(&results),
timestamp: Instant::now(),
};
let Ok(mut inner) = self.inner.write() else {
tracing::error!("Failed to acquire discovery cache lock for writing");
return;
};
inner.detailed_cache.put(key, cached);
for result in results.iter() {
inner.bloom_filter.insert(&result.tool.name);
}
}
pub fn get_all_tools(
&self,
provider_name: &str,
refresh_if_stale: bool,
) -> Option<Vec<McpToolInfo>> {
let inner = match self.inner.read() {
Ok(inner) => inner,
Err(e) => {
error!("Discovery cache lock poisoned: {}", e);
return None;
}
};
let should_refresh = if let Some(last) = inner.last_refresh.get(provider_name) {
last.elapsed() > self.config.provider_refresh_interval
} else {
true
};
if should_refresh && refresh_if_stale {
return None; }
inner.all_tools_cache.get(provider_name).cloned()
}
pub fn cache_all_tools(&self, provider_name: &str, tools: Vec<McpToolInfo>) {
let mut inner = match self.inner.write() {
Ok(inner) => inner,
Err(e) => {
tracing::error!("Discovery cache lock poisoned: {}", e);
return;
}
};
inner
.all_tools_cache
.insert(provider_name.to_owned(), tools.clone());
inner
.last_refresh
.insert(provider_name.to_owned(), Instant::now());
inner.bloom_filter.clear();
let all_tool_names: Vec<String> = inner
.all_tools_cache
.values()
.flat_map(|provider_tools| provider_tools.iter().map(|tool| tool.name.clone()))
.collect();
for tool_name in all_tool_names {
inner.bloom_filter.insert(&tool_name);
}
}
pub fn cache_tool_result(&self, _cache_key: String, _result: serde_json::Value) {
}
pub fn clear(&self) {
if let Ok(mut inner) = self.inner.write() {
inner.bloom_filter.clear();
inner.detailed_cache.clear();
inner.all_tools_cache.clear();
inner.last_refresh.clear();
}
}
pub fn stats(&self) -> ToolCacheStats {
let (detailed_entries, detailed_capacity, all_tools_entries, bf_size, bf_hashes) = self
.inner
.read()
.map(|inner| {
(
inner.detailed_cache.len(),
inner.detailed_cache.cap().get(),
inner.all_tools_cache.len(),
inner.bloom_filter.size,
inner.bloom_filter.num_hashes,
)
})
.unwrap_or((0, 0, 0, 0, 0));
ToolCacheStats {
detailed_cache_entries: detailed_entries,
detailed_cache_capacity: detailed_capacity,
all_tools_cache_entries: all_tools_entries,
bloom_filter_size: bf_size,
bloom_filter_hashes: bf_hashes,
}
}
}
/// Point-in-time snapshot of cache occupancy and Bloom filter shape,
/// as returned by `ToolDiscoveryCache::stats`.
#[derive(Debug, Clone)]
pub struct ToolCacheStats {
/// Number of detailed query entries currently cached.
pub detailed_cache_entries: usize,
/// Maximum number of detailed query entries (LRU capacity).
pub detailed_cache_capacity: usize,
/// Number of providers with a cached full tool list.
pub all_tools_cache_entries: usize,
/// Bloom filter bit-array length.
pub bloom_filter_size: usize,
/// Number of hash functions the Bloom filter applies per item.
pub bloom_filter_hashes: usize,
}
/// High-level search facade that owns a `ToolDiscoveryCache` and
/// performs keyword relevance scoring on cache misses.
pub struct CachedToolDiscovery {
// Shared so clones of this facade would observe the same cache.
cache: Arc<ToolDiscoveryCache>,
}
impl CachedToolDiscovery {
    /// Creates a discovery facade with a fresh cache of the given capacity.
    pub fn new(cache_capacity: usize) -> Self {
        Self {
            cache: Arc::new(ToolDiscoveryCache::new(cache_capacity)),
        }
    }

    /// Searches `all_tools` for `keyword`, serving repeated queries from
    /// the cache. Results are shared via `Arc`, so cache hits are
    /// clone-free.
    ///
    /// NOTE(review): a Bloom-filter pre-check on `keyword` was removed
    /// here. The filter stores *full* tool names, while `perform_search`
    /// matches on name fragments, descriptions, and schemas — so gating
    /// on a substring keyword (e.g. "search" vs tool "search_files")
    /// wrongly returned an empty result set for valid queries.
    pub fn search_tools(
        &self,
        provider_name: &str,
        keyword: &str,
        detail_level: DetailLevel,
        all_tools: Vec<McpToolInfo>,
    ) -> Arc<Vec<ToolDiscoveryResult>> {
        if let Some(cached) = self
            .cache
            .get_cached_discovery(provider_name, keyword, detail_level)
        {
            return cached;
        }
        let results = Arc::new(self.perform_search(&all_tools, keyword, detail_level));
        self.cache.cache_discovery_shared(
            provider_name,
            keyword,
            detail_level,
            Arc::clone(&results),
        );
        results
    }

    /// Returns the provider's full tool list, preferring a fresh cache
    /// entry; on a miss or stale entry, caches and returns `all_tools`.
    pub fn get_all_tools_cached(
        &self,
        provider_name: &str,
        all_tools: Vec<McpToolInfo>,
    ) -> Vec<McpToolInfo> {
        if let Some(cached) = self.cache.get_all_tools(provider_name, true) {
            return cached;
        }
        self.cache.cache_all_tools(provider_name, all_tools.clone());
        all_tools
    }

    /// Scores every tool against `keyword` (case-insensitively) and
    /// returns the positive-scoring matches sorted by descending
    /// relevance.
    fn perform_search(
        &self,
        tools: &[McpToolInfo],
        keyword: &str,
        detail_level: DetailLevel,
    ) -> Vec<ToolDiscoveryResult> {
        let keyword_lower = keyword.to_lowercase();
        let mut results: Vec<ToolDiscoveryResult> = tools
            .iter()
            .filter_map(|tool| {
                let relevance_score = self.calculate_relevance(tool, &keyword_lower);
                (relevance_score > 0.0).then(|| ToolDiscoveryResult {
                    tool: tool.clone(),
                    relevance_score,
                    detail_level,
                })
            })
            .collect();
        // Descending by score; scores are built from finite constants,
        // so treating an incomparable pair as Equal is a safe fallback.
        results.sort_by(|a, b| {
            b.relevance_score
                .partial_cmp(&a.relevance_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results
    }

    /// Relevance in [0, 1]: exact name 1.0, name prefix 0.8, name
    /// substring 0.6 (mutually exclusive); +0.3 for a description hit,
    /// +0.2 for a schema hit; capped at 1.0. Expects `keyword` already
    /// lowercased.
    fn calculate_relevance(&self, tool: &McpToolInfo, keyword: &str) -> f64 {
        let name_lower = tool.name.to_lowercase();
        let description_lower = tool.description.to_lowercase();
        let mut score: f64 = 0.0;
        if name_lower == keyword {
            score += 1.0;
        } else if name_lower.starts_with(keyword) {
            score += 0.8;
        } else if name_lower.contains(keyword) {
            score += 0.6;
        }
        if description_lower.contains(keyword) {
            score += 0.3;
        }
        // Serialize the schema once per tool to allow substring matches
        // against parameter names and descriptions.
        let schema_str = serde_json::to_string(&tool.input_schema)
            .unwrap_or_default()
            .to_lowercase();
        if schema_str.contains(keyword) {
            score += 0.2;
        }
        score.min(1.0)
    }

    /// See [`ToolDiscoveryCache::stats`].
    pub fn stats(&self) -> ToolCacheStats {
        self.cache.stats()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bloom_filter() {
        let mut filter = BloomFilter::new(100, 0.01);
        let inserted = ["tool1", "tool2", "tool3"];
        for name in inserted {
            filter.insert(name);
        }
        // No false negatives for inserted items.
        for name in inserted {
            assert!(filter.contains(name));
        }
        // An absent item is reported absent (no false positive here).
        assert!(!filter.contains("tool4"));
    }

    #[test]
    fn test_cache_key_equality() {
        // Two keys built from identical components must compare equal.
        let make_key = || ToolDiscoveryCacheKey {
            provider_name: "test".to_string(),
            keyword: "search".to_string(),
            detail_level: DetailLevel::Full,
        };
        assert_eq!(make_key(), make_key());
    }

    #[test]
    fn test_tool_discovery_cache() {
        let cache = ToolDiscoveryCache::new(10);
        let provider_name = "test_provider";
        let keyword = "search";
        let detail_level = DetailLevel::Full;
        // Cold cache: no entry yet.
        assert!(cache
            .get_cached_discovery(provider_name, keyword, detail_level)
            .is_none());
        let results = vec![ToolDiscoveryResult {
            tool: McpToolInfo {
                name: "search_files".to_string(),
                description: "Search for files".to_string(),
                provider: "test".to_string(),
                input_schema: serde_json::json!({}),
            },
            relevance_score: 0.9,
            detail_level,
        }];
        cache.cache_discovery(provider_name, keyword, detail_level, results.clone());
        // Warm cache: the stored entry comes back intact.
        let cached = cache.get_cached_discovery(provider_name, keyword, detail_level);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().len(), 1);
    }
}