use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::{ChatResponse, ToolCall, ToolDefinition};
use super::erased_tool::{ErasedTool, ToolWrapper};
use super::tool_error::{ToolError, ToolResult};
use super::tool_trait::Tool;
/// A registry of tools, keyed by the name each tool reports.
///
/// The map lives behind an `Arc<RwLock<..>>`, so cloning a `ToolRegistry` is
/// cheap and every clone shares the same set of tools: a tool registered
/// through one clone is visible through all of them.
///
/// `Default` yields an empty registry, exactly like [`ToolRegistry::new`].
#[derive(Clone, Default)]
pub struct ToolRegistry {
    /// Type-erased tools, keyed by tool name.
    tools: Arc<RwLock<HashMap<String, Arc<dyn ErasedTool>>>>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn register<T: Tool + 'static>(&mut self, tool: T) {
let name = tool.name().to_string();
let wrapper = Arc::new(ToolWrapper::new(tool));
self.tools.write().unwrap().insert(name, wrapper);
}
pub fn len(&self) -> usize {
self.tools.read().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.tools.read().unwrap().is_empty()
}
pub fn contains(&self, name: &str) -> bool {
self.tools.read().unwrap().contains_key(name)
}
pub fn definitions(&self) -> Vec<ToolDefinition> {
self.tools
.read()
.unwrap()
.values()
.map(|t| t.definition())
.collect()
}
pub async fn execute(&self, call: &ToolCall) -> ToolResult<serde_json::Value> {
let func_name = call.function_name().ok_or(ToolError::InvalidToolCall)?;
let args = call
.arguments()
.cloned()
.unwrap_or_else(|| serde_json::json!({}));
let tool = {
let tools = self.tools.read().unwrap();
tools
.get(func_name)
.cloned()
.ok_or_else(|| ToolError::NotFound(func_name.to_string()))?
};
tool.execute_erased(args).await
}
pub async fn execute_all(&self, response: &ChatResponse) -> Vec<ToolResult<serde_json::Value>> {
let Some(calls) = response.tool_calls() else {
return Vec::new();
};
let mut results = Vec::with_capacity(calls.len());
for call in calls {
results.push(self.execute(call).await);
}
results
}
pub fn execute_blocking(&self, call: &ToolCall) -> ToolResult<serde_json::Value> {
let func_name = call.function_name().ok_or(ToolError::InvalidToolCall)?;
let args = call
.arguments()
.cloned()
.unwrap_or_else(|| serde_json::json!({}));
let tool = {
let tools = self.tools.read().unwrap();
tools
.get(func_name)
.cloned()
.ok_or_else(|| ToolError::NotFound(func_name.to_string()))?
};
tool.execute_erased_blocking(args)
}
pub fn execute_all_blocking(
&self,
response: &ChatResponse,
) -> Vec<ToolResult<serde_json::Value>> {
let Some(calls) = response.tool_calls() else {
return Vec::new();
};
calls
.iter()
.map(|call| self.execute_blocking(call))
.collect()
}
}
impl std::fmt::Debug for ToolRegistry {
    /// Formats as `ToolRegistry { tools: [..] }`, listing the registered tool
    /// names in sorted order so the output is deterministic.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Debug output is often produced while already reporting a failure, so
        // formatting must not panic on a poisoned lock. The map itself is only
        // read here, so recovering the guard is safe.
        let tools = self
            .tools
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let mut names: Vec<&str> = tools.keys().map(String::as_str).collect();
        names.sort_unstable();
        f.debug_struct("ToolRegistry")
            .field("tools", &names)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::tool_trait::Tool;
    use crate::{ResponseMessage, ToolCallFunction};
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Deserialize, JsonSchema)]
    struct MultiplyParams {
        a: i32,
        b: i32,
    }

    #[derive(Serialize)]
    struct MultiplyResult {
        product: i32,
    }

    /// Test tool: multiplies two integers.
    struct MultiplyTool;

    impl Tool for MultiplyTool {
        type Params = MultiplyParams;
        type Output = MultiplyResult;

        fn name(&self) -> &'static str {
            "multiply"
        }

        fn description(&self) -> &'static str {
            "Multiply two numbers"
        }

        async fn execute(&self, params: Self::Params) -> ToolResult<Self::Output> {
            Ok(MultiplyResult {
                product: params.a * params.b,
            })
        }
    }

    #[derive(Debug, Deserialize, JsonSchema)]
    struct GreetParams {
        name: String,
    }

    #[derive(Serialize)]
    struct GreetResult {
        greeting: String,
    }

    /// Test tool: builds a greeting for a name.
    struct GreetTool;

    impl Tool for GreetTool {
        type Params = GreetParams;
        type Output = GreetResult;

        fn name(&self) -> &'static str {
            "greet"
        }

        fn description(&self) -> &'static str {
            "Greet someone"
        }

        async fn execute(&self, params: Self::Params) -> ToolResult<Self::Output> {
            Ok(GreetResult {
                greeting: format!("Hello, {}!", params.name),
            })
        }
    }

    #[test]
    fn test_registry_new() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_registry_default_is_empty() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_registry_register() {
        let mut registry = ToolRegistry::new();
        registry.register(MultiplyTool);
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("multiply"));
    }

    #[test]
    fn test_registry_register_multiple() {
        let mut registry = ToolRegistry::new();
        registry.register(MultiplyTool);
        registry.register(GreetTool);
        assert_eq!(registry.len(), 2);
        assert!(registry.contains("multiply"));
        assert!(registry.contains("greet"));
    }

    #[test]
    fn test_registry_register_same_name_replaces() {
        let mut registry = ToolRegistry::new();
        registry.register(MultiplyTool);
        registry.register(MultiplyTool);
        // Re-registering under an existing name replaces the entry rather than
        // adding a second one.
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("multiply"));
    }

    #[test]
    fn test_registry_contains_unknown() {
        let mut registry = ToolRegistry::new();
        registry.register(MultiplyTool);
        assert!(!registry.contains("greet"));
    }

    #[test]
    fn test_registry_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(MultiplyTool);
        registry.register(GreetTool);
        let defs = registry.definitions();
        assert_eq!(defs.len(), 2);
        let names: Vec<_> = defs.iter().map(|d| d.name()).collect();
        assert!(names.contains(&"multiply"));
        assert!(names.contains(&"greet"));
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let mut registry = ToolRegistry::new();
        registry.register(MultiplyTool);
        let call = ToolCall::new(ToolCallFunction::with_arguments(
            "multiply",
            serde_json::json!({"a": 3, "b": 4}),
        ));
        let result = registry.execute(&call).await.unwrap();
        assert_eq!(result["product"], 12);
    }

    #[tokio::test]
    async fn test_registry_execute_not_found() {
        let registry = ToolRegistry::new();
        let call = ToolCall::new(ToolCallFunction::new("unknown"));
        let result = registry.execute(&call).await;
        assert!(matches!(result, Err(ToolError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_registry_execute_missing_arguments() {
        let mut registry = ToolRegistry::new();
        registry.register(MultiplyTool);
        // The call carries no usable `a`/`b` arguments, so parameter decoding
        // must fail and surface as an error rather than a panic.
        let call = ToolCall::new(ToolCallFunction::new("multiply"));
        assert!(registry.execute(&call).await.is_err());
    }

    #[tokio::test]
    async fn test_registry_execute_all() {
        let mut registry = ToolRegistry::new();
        registry.register(MultiplyTool);
        registry.register(GreetTool);
        let response = ChatResponse {
            message: Some(ResponseMessage {
                tool_calls: Some(vec![
                    ToolCall::new(ToolCallFunction::with_arguments(
                        "multiply",
                        serde_json::json!({"a": 2, "b": 5}),
                    )),
                    ToolCall::new(ToolCallFunction::with_arguments(
                        "greet",
                        serde_json::json!({"name": "Alice"}),
                    )),
                ]),
                ..Default::default()
            }),
            ..Default::default()
        };
        let results = registry.execute_all(&response).await;
        assert_eq!(results.len(), 2);
        let multiply_result = results[0].as_ref().unwrap();
        assert_eq!(multiply_result["product"], 10);
        let greet_result = results[1].as_ref().unwrap();
        assert_eq!(greet_result["greeting"], "Hello, Alice!");
    }

    #[tokio::test]
    async fn test_registry_execute_all_no_tool_calls() {
        let registry = ToolRegistry::new();
        let response = ChatResponse {
            message: Some(ResponseMessage {
                content: Some("Hello!".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        let results = registry.execute_all(&response).await;
        assert!(results.is_empty());
    }

    #[test]
    fn test_registry_execute_all_blocking_no_tool_calls() {
        let registry = ToolRegistry::new();
        let response = ChatResponse {
            message: Some(ResponseMessage {
                content: Some("Hello!".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(registry.execute_all_blocking(&response).is_empty());
    }

    #[test]
    fn test_registry_clone() {
        let mut registry = ToolRegistry::new();
        registry.register(MultiplyTool);
        let cloned = registry.clone();
        assert_eq!(cloned.len(), 1);
        assert!(cloned.contains("multiply"));
    }

    #[test]
    fn test_registry_clone_shares_tools() {
        let mut registry = ToolRegistry::new();
        let cloned = registry.clone();
        // Clones share one underlying map, so a tool registered after cloning
        // is visible through the clone as well.
        registry.register(GreetTool);
        assert_eq!(cloned.len(), 1);
        assert!(cloned.contains("greet"));
    }

    #[test]
    fn test_registry_debug() {
        let mut registry = ToolRegistry::new();
        registry.register(MultiplyTool);
        let debug = format!("{:?}", registry);
        assert!(debug.contains("ToolRegistry"));
        assert!(debug.contains("multiply"));
    }
}