use std::collections::HashMap;
use std::time::Instant;
use serde_json::Value;
use tracing::{error, info, warn};
use crate::error::{Result, ZeptoError};
use crate::providers::ToolDefinition;
use super::{Tool, ToolContext};
/// Registry of tools that can be looked up and executed by name.
///
/// Tools are stored as boxed [`Tool`] trait objects so different tool
/// implementations can share one map. Registering a tool whose name is
/// already present replaces the earlier entry.
pub struct ToolRegistry {
    // Keyed by the string returned from `Tool::name()` when the tool was registered.
    tools: HashMap<String, Box<dyn Tool>>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
}
}
pub fn register(&mut self, tool: Box<dyn Tool>) {
let name = tool.name().to_string();
info!(tool = %name, "Registering tool");
self.tools.insert(name, tool);
}
pub fn get(&self, name: &str) -> Option<&dyn Tool> {
self.tools.get(name).map(|t| t.as_ref())
}
pub async fn execute(&self, name: &str, args: Value) -> Result<String> {
self.execute_with_context(name, args, &ToolContext::default())
.await
}
pub async fn execute_with_context(
&self,
name: &str,
args: Value,
ctx: &ToolContext,
) -> Result<String> {
let tool = self
.tools
.get(name)
.ok_or_else(|| ZeptoError::NotFound(format!("Tool not found: {}", name)))?;
let start = Instant::now();
match tool.execute(args, ctx).await {
Ok(result) => {
info!(
tool = name,
duration_ms = start.elapsed().as_millis() as u64,
"Tool executed successfully"
);
Ok(result)
}
Err(e) => {
error!(
tool = name,
error = %e,
duration_ms = start.elapsed().as_millis() as u64,
"Tool execution failed"
);
Err(e)
}
}
}
pub fn definitions(&self) -> Vec<ToolDefinition> {
self.tools
.values()
.map(|t| ToolDefinition {
name: t.name().to_string(),
description: t.description().to_string(),
parameters: t.parameters(),
})
.collect()
}
pub fn names(&self) -> Vec<&str> {
self.tools.keys().map(|s| s.as_str()).collect()
}
pub fn has(&self, name: &str) -> bool {
self.tools.contains_key(name)
}
pub fn len(&self) -> usize {
self.tools.len()
}
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::EchoTool;
    use serde_json::json;

    /// Fixture: a fresh registry holding exactly one `EchoTool`.
    fn registry_with_echo() -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool));
        registry
    }

    #[test]
    fn test_registry_new() {
        let registry = ToolRegistry::new();
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_registry_default() {
        assert!(ToolRegistry::default().is_empty());
    }

    #[test]
    fn test_registry_register() {
        let registry = registry_with_echo();
        assert!(registry.has("echo"));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_get() {
        let registry = registry_with_echo();
        match registry.get("echo") {
            Some(tool) => assert_eq!(tool.name(), "echo"),
            None => panic!("echo tool should be registered"),
        }
        assert!(registry.get("nonexistent").is_none());
    }

    #[tokio::test]
    async fn test_registry_register_and_execute() {
        let registry = registry_with_echo();
        assert!(registry.has("echo"));
        let output = registry
            .execute("echo", json!({"message": "hello"}))
            .await
            .expect("echo tool should succeed");
        assert_eq!(output, "hello");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context() {
        let registry = registry_with_echo();
        let ctx = ToolContext::new()
            .with_channel("telegram", "123456")
            .with_workspace("/tmp/test");
        let output = registry
            .execute_with_context("echo", json!({"message": "world"}), &ctx)
            .await
            .expect("echo tool should succeed with a context");
        assert_eq!(output, "world");
    }

    #[test]
    fn test_registry_definitions() {
        let definitions = registry_with_echo().definitions();
        assert_eq!(definitions.len(), 1);
        let echo = &definitions[0];
        assert_eq!(echo.name, "echo");
        assert_eq!(echo.description, "Echoes back the provided message");
        assert!(echo.parameters.is_object());
    }

    #[test]
    fn test_registry_names() {
        let registry = registry_with_echo();
        // A single registered tool means the name list is exactly ["echo"].
        assert_eq!(registry.names(), vec!["echo"]);
    }

    #[tokio::test]
    async fn test_tool_not_found() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute("nonexistent", json!({}))
            .await
            .expect_err("executing an unknown tool must fail");
        assert!(matches!(err, ZeptoError::NotFound(_)));
        assert!(err.to_string().contains("Tool not found: nonexistent"));
    }

    #[tokio::test]
    async fn test_registry_execute_missing_message() {
        let registry = registry_with_echo();
        let output = registry
            .execute("echo", json!({}))
            .await
            .expect("echo tool should tolerate a missing message");
        assert_eq!(output, "(no message)");
    }

    #[tokio::test]
    async fn test_registry_execute_null_message() {
        let registry = registry_with_echo();
        let output = registry
            .execute("echo", json!({"message": null}))
            .await
            .expect("echo tool should tolerate a null message");
        assert_eq!(output, "(no message)");
    }

    #[test]
    fn test_registry_replace_tool() {
        let mut registry = ToolRegistry::new();
        for _ in 0..2 {
            registry.register(Box::new(EchoTool));
        }
        assert_eq!(registry.len(), 1);
        assert!(registry.has("echo"));
    }
}