pub mod definitions;
pub use definitions::create_default_registry;
use crate::types::Tool;
use anyhow::Result;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, warn};
/// Classification of a tool inside the registry.
///
/// NOTE(review): not referenced elsewhere in this file; presumably mirrors the
/// `core_tools` / `extended_tools` split kept by `ToolRegistry` — confirm with callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolCategory {
    /// Counterpart of the registry's core tool list.
    Core,
    /// Counterpart of the registry's extended (load-on-demand) tool map.
    Extended,
}
/// Registry of tools, split into a fixed core set and an extended set that is
/// pulled into a session cache on demand.
///
/// Usage counts recorded by `load_tool` drive the ordering returned by
/// `get_loaded_tools`.
pub struct ToolRegistry {
    /// Tools that are always part of `get_loaded_tools`; no method removes them.
    core_tools: Vec<Tool>,
    /// Tools keyed by name that join `get_loaded_tools` only after `load_tool` caches them.
    extended_tools: HashMap<String, Tool>,
    /// Session cache: extended tools loaded so far plus custom tools from `add_tool`.
    session_loaded: Arc<RwLock<HashMap<String, Tool>>>,
    /// Number of successful `load_tool` resolutions per tool name.
    usage_count: Arc<RwLock<HashMap<String, usize>>>,
}
impl ToolRegistry {
    /// Builds a registry from an always-available core set and an on-demand extended set.
    ///
    /// Names present in both sets are logged as conflicts but tolerated; `load_tool`
    /// searches core tools first, so the core definition wins for such names.
    pub fn new(core_tools: Vec<Tool>, extended_tools: HashMap<String, Tool>) -> Self {
        info!(
            "ToolRegistry initialized: {} core tools, {} extended tools",
            core_tools.len(),
            extended_tools.len()
        );
        // Scoped so the borrowed name set is gone before `core_tools` moves into `Self`.
        {
            let core_names: std::collections::HashSet<&str> =
                core_tools.iter().map(|t| t.name.as_str()).collect();
            for extended_name in extended_tools.keys() {
                if core_names.contains(extended_name.as_str()) {
                    warn!(
                        "Tool name conflict: '{}' in both core and extended tools",
                        extended_name
                    );
                }
            }
        }
        Self {
            core_tools,
            extended_tools,
            session_loaded: Arc::new(RwLock::new(HashMap::new())),
            usage_count: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns a copy of the core tools in construction order.
    pub fn get_core_tools(&self) -> Vec<Tool> {
        self.core_tools.clone()
    }

    /// Lists every statically known tool name: core names first, then extended names.
    ///
    /// Custom tools registered with `add_tool` are not included. A name present in
    /// both sets appears twice, so the length always equals `total_tool_count()`.
    /// Extended names come from a `HashMap`, so their relative order is unspecified.
    pub fn list_tool_names(&self) -> Vec<String> {
        let mut names = Vec::with_capacity(self.total_tool_count());
        names.extend(self.core_tools.iter().map(|t| t.name.clone()));
        names.extend(self.extended_tools.keys().cloned());
        names
    }

    /// Returns the core tools plus everything in the session cache, ordered by
    /// descending usage count.
    ///
    /// The sort is stable, so ties keep core-first, then cache-iteration, order.
    pub fn get_loaded_tools(&self) -> Vec<Tool> {
        let mut tools = self.core_tools.clone();
        tools.extend(self.session_loaded.read().values().cloned());
        // One usage snapshot for the whole sort. Re-reading the counters per comparison
        // could observe concurrent increments mid-sort and hand the sort an inconsistent
        // ordering (newer std sort implementations may panic on that).
        let usage = self.usage_count.read();
        tools.sort_by_key(|t| std::cmp::Reverse(usage.get(&t.name).copied().unwrap_or(0)));
        tools
    }

    /// Resolves `name`, searching core tools, then the session cache, then extended tools.
    ///
    /// An extended hit is inserted into the session cache. Every successful lookup
    /// increments the tool's usage count. Returns `None` (and logs a warning) for
    /// unknown names. Declared `async` but never awaits; no lock is held across an
    /// await point.
    pub async fn load_tool(&self, name: &str) -> Option<Tool> {
        if let Some(tool) = self.core_tools.iter().find(|t| t.name == name) {
            self.track_usage(name);
            return Some(tool.clone());
        }
        // Clone out under a short-lived read guard so the session lock is released
        // before the usage counters are locked.
        let cached = self.session_loaded.read().get(name).cloned();
        if let Some(tool) = cached {
            self.track_usage(name);
            return Some(tool);
        }
        if let Some(tool) = self.extended_tools.get(name) {
            debug!("Loading extended tool: {}", name);
            self.session_loaded
                .write()
                .insert(name.to_string(), tool.clone());
            self.track_usage(name);
            return Some(tool.clone());
        }
        warn!("Tool not found: {}", name);
        None
    }

    /// Loads each name in order via `load_tool`, silently skipping unknown names.
    pub async fn load_tools(&self, names: &[String]) -> Vec<Tool> {
        let mut tools = Vec::with_capacity(names.len());
        for name in names {
            if let Some(tool) = self.load_tool(name).await {
                tools.push(tool);
            }
        }
        tools
    }

    /// Returns `true` if `name` is a core or extended tool.
    ///
    /// Custom tools added through `add_tool` are not considered here.
    pub fn tool_exists(&self, name: &str) -> bool {
        self.core_tools.iter().any(|t| t.name == name) || self.extended_tools.contains_key(name)
    }

    /// Number of core plus extended tools; duplicated names are counted once per set.
    pub fn total_tool_count(&self) -> usize {
        self.core_tools.len() + self.extended_tools.len()
    }

    /// Number of core tools plus the current size of the session cache.
    pub fn loaded_tool_count(&self) -> usize {
        self.core_tools.len() + self.session_loaded.read().len()
    }

    /// How many times `load_tool` has resolved `name`; `0` if never.
    pub fn get_tool_usage(&self, name: &str) -> usize {
        self.usage_count.read().get(name).copied().unwrap_or(0)
    }

    /// Snapshot of all usage counters.
    pub fn get_usage_stats(&self) -> HashMap<String, usize> {
        self.usage_count.read().clone()
    }

    /// Empties the session cache.
    ///
    /// This also discards custom tools registered with `add_tool`, because they are
    /// stored in the same cache. Usage counters are left untouched.
    pub fn clear_session_cache(&self) {
        let mut loaded = self.session_loaded.write();
        let cleared_count = loaded.len();
        loaded.clear();
        debug!("Cleared session cache ({} tools)", cleared_count);
    }

    /// Registers a custom tool in the session cache.
    ///
    /// # Errors
    /// Fails if a core tool, extended tool, or already-cached tool has the same name.
    pub fn add_tool(&self, tool: Tool) -> Result<()> {
        let name = tool.name.clone();
        // Hold the write lock across check-and-insert so a previously added custom tool
        // (which `tool_exists` does not see) is detected, and two concurrent callers
        // cannot both pass the check.
        let mut loaded = self.session_loaded.write();
        if self.tool_exists(&name) || loaded.contains_key(&name) {
            anyhow::bail!("Tool with name '{}' already exists", name);
        }
        info!("Adding custom tool: {}", name);
        loaded.insert(name, tool);
        Ok(())
    }

    /// Removes `name` from the session cache.
    ///
    /// For an extended tool this only unloads it. The definition stays in the extended
    /// set, so a later `load_tool` loads it again. Removing an extended tool that was
    /// never loaded still succeeds.
    ///
    /// # Errors
    /// Fails if `name` is a core tool, or is neither cached nor an extended tool.
    pub fn remove_tool(&self, name: &str) -> Result<()> {
        if self.core_tools.iter().any(|t| t.name == name) {
            anyhow::bail!("Cannot remove core tool '{}'", name);
        }
        let removed = self.session_loaded.write().remove(name).is_some();
        if !removed && !self.extended_tools.contains_key(name) {
            anyhow::bail!("Tool '{}' not found", name);
        }
        info!("Removed tool: {}", name);
        Ok(())
    }

    /// Increments the usage counter for `name`, creating it at zero if absent.
    fn track_usage(&self, name: &str) {
        *self
            .usage_count
            .write()
            .entry(name.to_string())
            .or_default() += 1;
    }
}
pub fn create_tool_registry() -> ToolRegistry {
create_default_registry()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a Tokio runtime to drive the registry's `async` loaders from sync tests.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Runtime::new().expect("failed to build Tokio runtime")
    }

    /// Name of an arbitrary extended tool from the given registry.
    fn any_extended_name(registry: &ToolRegistry) -> &str {
        registry
            .extended_tools
            .keys()
            .next()
            .expect("Should have extended tools")
    }

    #[test]
    fn test_registry_creation() {
        let registry = create_tool_registry();
        assert!(registry.total_tool_count() > 8);
        assert_eq!(registry.get_core_tools().len(), 8);
    }

    #[test]
    fn test_core_tools_always_available() {
        let registry = create_tool_registry();
        let core = registry.get_core_tools();
        assert!(core.iter().any(|t| t.name == "query_memory"));
        assert!(core.iter().any(|t| t.name == "health_check"));
    }

    #[test]
    fn test_load_extended_tool() {
        let registry = create_tool_registry();
        let core_count = registry.get_core_tools().len();
        let extended_name = any_extended_name(&registry);
        assert_eq!(registry.loaded_tool_count(), core_count);
        let tool = runtime().block_on(registry.load_tool(extended_name));
        assert!(tool.is_some());
        assert_eq!(registry.loaded_tool_count(), core_count + 1);
    }

    #[test]
    fn test_tool_usage_tracking() {
        let registry = create_tool_registry();
        assert_eq!(registry.get_tool_usage("query_memory"), 0);
        runtime().block_on(registry.load_tool("query_memory"));
        assert_eq!(registry.get_tool_usage("query_memory"), 1);
    }

    #[test]
    fn test_clear_session_cache() {
        let registry = create_tool_registry();
        let core_count = registry.get_core_tools().len();
        let extended_name = any_extended_name(&registry);
        runtime().block_on(registry.load_tool(extended_name));
        assert_eq!(registry.loaded_tool_count(), core_count + 1);
        registry.clear_session_cache();
        assert_eq!(registry.loaded_tool_count(), core_count);
    }

    #[test]
    fn test_list_tool_names() {
        let registry = create_tool_registry();
        let names = registry.list_tool_names();
        assert!(names.len() >= 8);
        assert!(names.contains(&"query_memory".to_string()));
        assert!(names.contains(&"health_check".to_string()));
    }

    #[test]
    fn test_list_tool_names_vs_full_schema() {
        let registry = create_tool_registry();
        let names = registry.list_tool_names();
        let total = registry.total_tool_count();
        assert_eq!(names.len(), total);
    }

    #[test]
    fn test_add_tool_rejects_existing_name() {
        let registry = create_tool_registry();
        let core_tool = registry
            .get_core_tools()
            .into_iter()
            .next()
            .expect("Should have core tools");
        assert!(registry.add_tool(core_tool).is_err());
    }

    #[test]
    fn test_remove_tool_errors() {
        let registry = create_tool_registry();
        assert!(registry.remove_tool("query_memory").is_err());
        assert!(registry.remove_tool("definitely_not_a_registered_tool").is_err());
    }

    #[test]
    fn test_remove_tool_unloads_extended_tool() {
        let registry = create_tool_registry();
        let core_count = registry.get_core_tools().len();
        let extended_name = any_extended_name(&registry);
        runtime().block_on(registry.load_tool(extended_name));
        assert_eq!(registry.loaded_tool_count(), core_count + 1);
        assert!(registry.remove_tool(extended_name).is_ok());
        assert_eq!(registry.loaded_tool_count(), core_count);
    }

    #[test]
    fn test_loaded_tools_sorted_by_usage() {
        let registry = create_tool_registry();
        let rt = runtime();
        rt.block_on(registry.load_tool("health_check"));
        rt.block_on(registry.load_tool("health_check"));
        let tools = registry.get_loaded_tools();
        assert_eq!(
            tools.first().map(|t| t.name.as_str()),
            Some("health_check")
        );
    }
}