1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use crate::TypeError;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct AgentToolDefinition {
8 pub name: String,
9 pub description: String,
10 pub parameters: serde_json::Value, }
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolCall {
15 pub tool_name: String,
16 pub arguments: serde_json::Value,
17}
18
19#[derive(Debug)]
20pub struct ToolRegistry {
21 tools: HashMap<String, Box<dyn Tool + Send + Sync>>,
22}
23
24pub trait Tool: Send + Sync + std::fmt::Debug {
25 fn name(&self) -> &str;
26 fn description(&self) -> &str;
27 fn parameter_schema(&self) -> serde_json::Value;
28 fn execute(&self, args: serde_json::Value) -> Result<serde_json::Value, TypeError>;
29}
30
31impl ToolRegistry {
32 pub fn new() -> Self {
33 Self {
34 tools: HashMap::new(),
35 }
36 }
37
38 pub fn register_tool(&mut self, tool: Box<dyn Tool + Send + Sync>) {
39 self.tools.insert(tool.name().to_string(), tool);
40 }
41
42 pub fn execute(&self, call: &ToolCall) -> Result<serde_json::Value, TypeError> {
43 self.tools
44 .get(&call.tool_name)
45 .ok_or_else(|| TypeError::Error(format!("Tool {} not found", call.tool_name)))?
46 .execute(call.arguments.clone())
47 }
48
49 pub fn get_definitions(&self) -> Vec<AgentToolDefinition> {
50 self.tools
51 .values()
52 .map(|tool| AgentToolDefinition {
53 name: tool.name().to_string(),
54 description: tool.description().to_string(),
55 parameters: tool.parameter_schema(),
56 })
57 .collect()
58 }
59
60 pub fn len(&self) -> usize {
62 self.tools.len()
63 }
64
65 pub fn is_empty(&self) -> bool {
67 self.tools.is_empty()
68 }
69}
70
71impl Default for ToolRegistry {
72 fn default() -> Self {
73 Self::new()
74 }
75}