use std::sync::Arc;
use bamboo_agent_core::{RegistryError, SharedTool, Tool, ToolSchema};
use dashmap::DashMap;
use crate::guide::{ToolGuide, ToolGuideSpec};
/// Tool registry that wraps the core `bamboo_agent_core::ToolRegistry` and
/// additionally stores an optional usage guide per tool.
pub struct ToolRegistry {
/// Underlying registry holding the tools; all tool operations delegate to it.
tools: bamboo_agent_core::ToolRegistry,
/// Guides, keyed by the name of the tool they were registered for.
guides: DashMap<String, Arc<dyn ToolGuide>>,
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolRegistry {
    /// Creates a registry backed by a fresh core tool registry and an empty
    /// guide map.
    pub fn new() -> Self {
        let tools = bamboo_agent_core::ToolRegistry::new();
        let guides = DashMap::new();
        Self { tools, guides }
    }

    /// Registers `tool` in the underlying tool registry without attaching a guide.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the underlying registry returns.
    pub fn register<T: Tool + 'static>(&self, tool: T) -> Result<(), RegistryError> {
        self.tools.register(tool)
    }

    /// Registers `tool` and, if that succeeds, stores `guide` under the tool's name.
    ///
    /// The guide is keyed by `tool.name()`, not by the guide's own `tool_name()`.
    ///
    /// # Errors
    ///
    /// Propagates the underlying registry's error; in that case no guide is stored.
    pub fn register_with_guide<T: Tool + 'static, G: ToolGuide + 'static>(
        &self,
        tool: T,
        guide: G,
    ) -> Result<(), RegistryError> {
        // Capture the key first: `tool` is moved into the inner registry below.
        let key = tool.name().to_string();
        self.tools.register(tool)?;
        self.guides.insert(key, Arc::new(guide));
        Ok(())
    }

    /// Attaches `guide` to the already-registered tool called `tool_name`,
    /// replacing any guide previously stored under that name.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::InvalidTool`] when the underlying registry does
    /// not contain `tool_name`.
    pub fn register_guide<G: ToolGuide + 'static>(
        &self,
        tool_name: &str,
        guide: G,
    ) -> Result<(), RegistryError> {
        if self.tools.contains(tool_name) {
            self.guides.insert(tool_name.to_owned(), Arc::new(guide));
            Ok(())
        } else {
            Err(RegistryError::InvalidTool(format!(
                "tool '{}' not found, register tool before adding guide",
                tool_name
            )))
        }
    }

    /// Parses `json_spec` into a [`ToolGuideSpec`] and attaches it to `tool_name`
    /// via [`ToolRegistry::register_guide`].
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::InvalidTool`] when parsing fails, or the error
    /// from [`ToolRegistry::register_guide`] when the tool is not registered.
    pub fn register_guide_from_json(
        &self,
        tool_name: &str,
        json_spec: &str,
    ) -> Result<(), RegistryError> {
        ToolGuideSpec::from_json_str(json_spec)
            .map_err(|err| RegistryError::InvalidTool(format!("invalid guide JSON: {}", err)))
            .and_then(|spec| self.register_guide(tool_name, spec))
    }

    /// Parses `yaml_spec` into a [`ToolGuideSpec`] and attaches it to `tool_name`
    /// via [`ToolRegistry::register_guide`].
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::InvalidTool`] when parsing fails, or the error
    /// from [`ToolRegistry::register_guide`] when the tool is not registered.
    pub fn register_guide_from_yaml(
        &self,
        tool_name: &str,
        yaml_spec: &str,
    ) -> Result<(), RegistryError> {
        ToolGuideSpec::from_yaml_str(yaml_spec)
            .map_err(|err| RegistryError::InvalidTool(format!("invalid guide YAML: {}", err)))
            .and_then(|spec| self.register_guide(tool_name, spec))
    }

    /// Looks up the tool called `name` in the underlying tool registry.
    pub fn get(&self, name: &str) -> Option<SharedTool> {
        self.tools.get(name)
    }

    /// Returns a shared handle to the guide stored under `name`, or `None`
    /// when no guide is stored for that name.
    pub fn get_guide(&self, name: &str) -> Option<Arc<dyn ToolGuide>> {
        // Clone the `Arc` so the map entry guard is released before returning.
        self.guides
            .get(name)
            .map(|slot| Arc::clone(slot.value()))
    }

    /// Reports whether the underlying tool registry contains `name`.
    /// Guides are not consulted.
    pub fn contains(&self, name: &str) -> bool {
        self.tools.contains(name)
    }

    /// Returns the tool schemas as listed by the underlying tool registry.
    pub fn list_tools(&self) -> Vec<ToolSchema> {
        self.tools.list_tools()
    }

    /// Returns the tool names as listed by the underlying tool registry.
    pub fn list_tool_names(&self) -> Vec<String> {
        self.tools.list_tool_names()
    }

    /// Drops any guide stored under `name`, then unregisters the tool from the
    /// underlying registry and returns that call's result.
    ///
    /// NOTE(review): the returned flag presumably means "a tool was removed" —
    /// confirm against `bamboo_agent_core::ToolRegistry::unregister`.
    pub fn unregister(&self, name: &str) -> bool {
        // The guide goes first, matching the established removal order.
        self.guides.remove(name);
        self.tools.unregister(name)
    }

    /// Returns the size reported by the underlying tool registry; guides are
    /// not counted.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns the emptiness flag reported by the underlying tool registry.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Removes every guide, then clears the underlying tool registry.
    pub fn clear(&self) {
        self.guides.clear();
        self.tools.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::guide::{ToolCategory, ToolExample};
    use crate::tools::ReadTool;

    /// Name the tests use to look up the registered `ReadTool`.
    const READ: &str = "Read";

    /// Minimal guide with fixed answers, used to exercise guide storage.
    ///
    /// Its `tool_name` deliberately differs from `READ`: the registry keys
    /// guides by the tool's name, not by what the guide reports.
    struct MockGuide;

    impl ToolGuide for MockGuide {
        fn tool_name(&self) -> &str {
            "mock_tool"
        }

        fn category(&self) -> ToolCategory {
            ToolCategory::FileReading
        }

        fn when_to_use(&self) -> &str {
            "when you need to mock"
        }

        fn when_not_to_use(&self) -> &str {
            "in production"
        }

        fn examples(&self) -> Vec<ToolExample> {
            Vec::new()
        }

        fn related_tools(&self) -> Vec<&str> {
            Vec::new()
        }
    }

    /// Registry holding only `ReadTool`, with no guide attached.
    fn registry_with_read() -> ToolRegistry {
        let reg = ToolRegistry::new();
        reg.register(ReadTool::new()).unwrap();
        reg
    }

    /// Registry holding `ReadTool` together with a `MockGuide`.
    fn registry_with_guided_read() -> ToolRegistry {
        let reg = ToolRegistry::new();
        reg.register_with_guide(ReadTool::new(), MockGuide).unwrap();
        reg
    }

    #[test]
    fn register_tool_without_guide() {
        let reg = registry_with_read();

        assert!(reg.contains(READ));
        assert!(reg.get_guide(READ).is_none());
    }

    #[test]
    fn register_tool_with_guide() {
        let reg = registry_with_guided_read();

        assert!(reg.contains(READ));
        assert!(reg.get_guide(READ).is_some());
    }

    #[test]
    fn register_guide_from_json() {
        let reg = registry_with_read();

        // The raw-string body stays flush-left so the literal is unchanged.
        let json_spec = r#"{
"tool_name": "Read",
"when_to_use": "Read small files",
"when_not_to_use": "Don't read large files",
"examples": [],
"related_tools": [],
"category": "FileReading"
}"#;
        reg.register_guide_from_json(READ, json_spec).unwrap();

        let stored = reg.get_guide(READ).unwrap();
        assert_eq!(stored.when_to_use(), "Read small files");
    }

    #[test]
    fn guide_removed_when_tool_unregistered() {
        let reg = registry_with_guided_read();

        reg.unregister(READ);

        assert!(!reg.contains(READ));
        assert!(reg.get_guide(READ).is_none());
    }
}