use crate::core::tool_loader::ToolResource;
use anyhow::{anyhow, Result};
use serde_json::Value;
use std::collections::HashMap;
use tracing::{debug, warn};
/// In-memory collection of [`ToolResource`]s, keyed by each resource's slug.
#[derive(Clone, Debug, Default)]
pub struct ToolRegistry {
    /// Registered resources; the key is the resource's `slug`.
    tools: HashMap<String, ToolResource>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
        }
    }

    /// Registers `tool` under its `slug`.
    ///
    /// If a resource with the same slug is already registered it is replaced and a
    /// warning is logged.
    pub fn register(&mut self, tool: ToolResource) {
        let slug = tool.slug.clone();
        // `insert` hands back the displaced value, so one lookup both stores the new
        // resource and tells us whether an existing one was overwritten.
        if self.tools.insert(slug.clone(), tool).is_some() {
            warn!("Tool resource '{}' already registered, overwriting", slug);
        }
        debug!("Registered tool resource: {}", slug);
    }

    /// Registers every resource in `tools`, in order, via [`ToolRegistry::register`].
    ///
    /// Later resources with a duplicate slug overwrite earlier ones.
    pub fn register_all(&mut self, tools: Vec<ToolResource>) {
        for tool in tools {
            self.register(tool);
        }
    }

    /// Returns the resource registered under `slug`, if any.
    pub fn get(&self, slug: &str) -> Option<&ToolResource> {
        self.tools.get(slug)
    }

    /// Returns `true` if a resource is registered under `slug`.
    pub fn _contains(&self, slug: &str) -> bool {
        self.tools.contains_key(slug)
    }

    /// Returns all registered slugs, sorted so the output is deterministic.
    pub fn _list_slugs(&self) -> Vec<String> {
        let mut slugs: Vec<String> = self.tools.keys().cloned().collect();
        slugs.sort_unstable();
        slugs
    }

    /// Resolves the resources named by `slugs` into one flat list of tool definitions.
    ///
    /// The result is ordered deterministically. Slugs are walked in order, and each
    /// resource's `tools` are walked in order; a tool appears at the position where
    /// it was first seen.
    ///
    /// Tools are de-duplicated by their `function.name`. If a later slug supplies a
    /// tool with a name already seen, that later definition replaces the earlier one
    /// in place. Tools without a string `function.name` are never de-duplicated and
    /// are always kept.
    ///
    /// # Errors
    ///
    /// Returns an error if any slug in `slugs` is not registered.
    pub fn resolve_tools(&self, slugs: &[String]) -> Result<Vec<Value>> {
        let mut resolved: Vec<Value> = Vec::new();
        // Position in `resolved` of each named function, used to replace duplicates
        // in place. A separate index (rather than synthetic keys such as `tool_0` for
        // unnamed tools) means an unnamed tool can never collide with a real name.
        let mut index_by_name: HashMap<&str, usize> = HashMap::new();

        for slug in slugs {
            let tool_resource = self
                .tools
                .get(slug)
                .ok_or_else(|| anyhow!("Tool resource '{}' not found", slug))?;
            debug!(
                "Resolving tools from '{}': {} tool(s)",
                slug,
                tool_resource.tools.len()
            );

            for tool_value in &tool_resource.tools {
                let function_name = tool_value
                    .get("function")
                    .and_then(|f| f.get("name"))
                    .and_then(|n| n.as_str());

                let Some(name) = function_name else {
                    // Anonymous tool: nothing to de-duplicate against, keep it.
                    resolved.push(tool_value.clone());
                    continue;
                };

                match index_by_name.get(name).copied() {
                    Some(idx) => {
                        debug!(
                            "Tool function '{}' already exists, overwriting with version from '{}'",
                            name, slug
                        );
                        resolved[idx] = tool_value.clone();
                    }
                    None => {
                        index_by_name.insert(name, resolved.len());
                        resolved.push(tool_value.clone());
                    }
                }
            }
        }

        debug!(
            "Resolved {} tool(s) from {} slug(s)",
            resolved.len(),
            slugs.len()
        );
        Ok(resolved)
    }

    /// Returns the number of registered resources.
    pub fn count(&self) -> usize {
        self.tools.len()
    }

    /// Removes every registered resource.
    pub fn _clear(&mut self) {
        self.tools.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a resource containing a single named function tool.
    fn create_test_tool(slug: &str, tool_name: &str) -> ToolResource {
        ToolResource {
            slug: slug.to_string(),
            name: format!("{} Tools", slug),
            description: Some(format!("Test tools for {}", slug)),
            tools: vec![json!({
                "type": "function",
                "function": {
                    "name": tool_name,
                    "description": format!("Test function {}", tool_name)
                }
            })],
        }
    }

    /// Builds a resource whose tool list is exactly `tools`.
    fn create_raw_tool(slug: &str, name: &str, tools: Vec<Value>) -> ToolResource {
        ToolResource {
            slug: slug.to_string(),
            name: name.to_string(),
            description: None,
            tools,
        }
    }

    #[test]
    fn test_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(create_test_tool("web-tools", "fetch_url"));
        assert!(registry._contains("web-tools"));
        assert_eq!(registry.get("web-tools").unwrap().slug, "web-tools");
        assert_eq!(registry.count(), 1);
    }

    #[test]
    fn test_register_overwrites_duplicate_slug() {
        let mut registry = ToolRegistry::new();
        registry.register(create_test_tool("web-tools", "fetch_url"));
        registry.register(create_test_tool("web-tools", "fetch_page"));
        assert_eq!(registry.count(), 1);
        let stored = registry.get("web-tools").unwrap();
        assert_eq!(stored.tools[0]["function"]["name"].as_str(), Some("fetch_page"));
    }

    #[test]
    fn test_register_all_and_clear() {
        let mut registry = ToolRegistry::new();
        registry.register_all(vec![
            create_test_tool("web-tools", "fetch_url"),
            create_test_tool("data-tools", "calculate"),
        ]);
        assert_eq!(registry.count(), 2);
        let mut slugs = registry._list_slugs();
        slugs.sort();
        assert_eq!(slugs, vec!["data-tools".to_string(), "web-tools".to_string()]);
        registry._clear();
        assert_eq!(registry.count(), 0);
        assert!(!registry._contains("web-tools"));
    }

    #[test]
    fn test_resolve_single_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(create_test_tool("web-tools", "fetch_url"));
        let resolved = registry.resolve_tools(&["web-tools".to_string()]).unwrap();
        assert_eq!(resolved.len(), 1);
    }

    #[test]
    fn test_resolve_multiple_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(create_test_tool("web-tools", "fetch_url"));
        registry.register(create_test_tool("data-tools", "calculate"));
        let resolved = registry
            .resolve_tools(&["web-tools".to_string(), "data-tools".to_string()])
            .unwrap();
        assert_eq!(resolved.len(), 2);
    }

    #[test]
    fn test_resolve_empty_slug_list() {
        let mut registry = ToolRegistry::new();
        registry.register(create_test_tool("web-tools", "fetch_url"));
        let resolved = registry.resolve_tools(&[]).unwrap();
        assert!(resolved.is_empty());
    }

    #[test]
    fn test_resolve_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let result = registry.resolve_tools(&["nonexistent".to_string()]);
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_fails_if_any_slug_missing() {
        let mut registry = ToolRegistry::new();
        registry.register(create_test_tool("web-tools", "fetch_url"));
        let result =
            registry.resolve_tools(&["web-tools".to_string(), "missing".to_string()]);
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_keeps_unnamed_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(create_raw_tool(
            "mixed-tools",
            "Mixed Tools",
            vec![
                json!({ "type": "function", "function": { "description": "anonymous one" } }),
                json!({ "type": "function", "function": { "name": "fetch_url" } }),
                json!({ "type": "retrieval" }),
            ],
        ));
        let resolved = registry.resolve_tools(&["mixed-tools".to_string()]).unwrap();
        assert_eq!(resolved.len(), 3);
    }

    #[test]
    fn test_deduplication() {
        let mut registry = ToolRegistry::new();
        registry.register(create_raw_tool(
            "web-tools-v1",
            "Web Tools V1",
            vec![json!({
                "type": "function",
                "function": {
                    "name": "fetch_url",
                    "description": "Old version"
                }
            })],
        ));
        registry.register(create_raw_tool(
            "web-tools-v2",
            "Web Tools V2",
            vec![json!({
                "type": "function",
                "function": {
                    "name": "fetch_url",
                    "description": "New version"
                }
            })],
        ));
        let resolved = registry
            .resolve_tools(&["web-tools-v1".to_string(), "web-tools-v2".to_string()])
            .unwrap();
        assert_eq!(resolved.len(), 1);
        assert_eq!(
            resolved[0]["function"]["description"].as_str().unwrap(),
            "New version"
        );
    }
}