use crate::result::ToolResult;
use crate::tool::Tool;
use erio_core::ToolError;
use serde_json::Value;
use std::collections::HashMap;
/// A name-keyed collection of boxed [`Tool`] trait objects.
///
/// Tools are stored under the name they report via [`Tool::name`] at
/// registration time and can later be looked up or executed by that name.
pub struct ToolRegistry {
    /// Registered tools, keyed by tool name.
    tools: HashMap<String, Box<dyn Tool>>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
        }
    }

    /// Registers `tool` under the name returned by its [`Tool::name`].
    ///
    /// If a tool with the same name is already registered, it is replaced and
    /// the previously registered tool is dropped.
    pub fn register<T: Tool + 'static>(&mut self, tool: T) {
        let name = tool.name().to_string();
        // `insert` hands back any previous tool under this name; it is
        // intentionally discarded (last registration wins).
        self.tools.insert(name, Box::new(tool));
    }

    /// Returns the number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` if no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Looks up a tool by name, returning a borrowed trait object if present.
    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
        self.tools.get(name).map(AsRef::as_ref)
    }

    /// Returns `true` if a tool is registered under `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.tools.contains_key(name)
    }

    /// Returns the names of all registered tools, sorted lexicographically.
    pub fn list(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.tools.keys().map(String::as_str).collect();
        // `HashMap` iteration order is unspecified and differs between runs;
        // sort so callers always observe a stable, reproducible order.
        names.sort_unstable();
        names
    }

    /// Executes the tool registered under `name` with the given `params`.
    ///
    /// # Errors
    ///
    /// Returns [`ToolError::NotFound`] carrying `name` if no tool is
    /// registered under that name. Otherwise returns whatever the tool's own
    /// `execute` returns, unchanged.
    pub async fn execute(&self, name: &str, params: Value) -> Result<ToolResult, ToolError> {
        let tool = self
            .get(name)
            .ok_or_else(|| ToolError::NotFound(name.into()))?;
        tool.execute(params).await
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}