1use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::error::{Error, Result};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolDefinition {
15 pub name: String,
17 pub description: String,
19 pub parameters: serde_json::Value,
21}
22
23#[async_trait]
25pub trait Tool: Send + Sync {
26 fn name(&self) -> String;
29
30 async fn definition(&self) -> ToolDefinition;
32
33 async fn call(&self, arguments: &str) -> Result<String>;
35}
36
37pub struct ToolSet {
39 tools: HashMap<String, Arc<dyn Tool>>,
40}
41
42impl Default for ToolSet {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl ToolSet {
49 pub fn new() -> Self {
51 Self {
52 tools: HashMap::new(),
53 }
54 }
55
56 pub fn add<T: Tool + 'static>(&mut self, tool: T) -> &mut Self {
58 self.tools.insert(tool.name().to_string(), Arc::new(tool));
59 self
60 }
61
62 pub fn add_shared(&mut self, tool: Arc<dyn Tool>) -> &mut Self {
64 self.tools.insert(tool.name().to_string(), tool);
65 self
66 }
67
68 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
70 self.tools.get(name)
71 }
72
73 pub fn contains(&self, name: &str) -> bool {
75 self.tools.contains_key(name)
76 }
77
78 pub async fn definitions(&self) -> Vec<ToolDefinition> {
80 let mut defs = Vec::new();
81 for tool in self.tools.values() {
82 defs.push(tool.definition().await);
83 }
84 defs
85 }
86
87 pub async fn call(&self, name: &str, arguments: &str) -> Result<String> {
89 let tool = self
90 .tools
91 .get(name)
92 .ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
93
94 tool.call(arguments).await
95 }
96
97 pub fn len(&self) -> usize {
99 self.tools.len()
100 }
101
102 pub fn is_empty(&self) -> bool {
104 self.tools.is_empty()
105 }
106
107 pub fn iter(&self) -> impl Iterator<Item = (&String, &Arc<dyn Tool>)> {
109 self.tools.iter()
110 }
111}
112
113pub struct ToolSetBuilder {
115 tools: Vec<Arc<dyn Tool>>,
116}
117
118impl Default for ToolSetBuilder {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124impl ToolSetBuilder {
125 pub fn new() -> Self {
127 Self { tools: Vec::new() }
128 }
129
130 pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
132 self.tools.push(Arc::new(tool));
133 self
134 }
135
136 pub fn shared_tool(mut self, tool: Arc<dyn Tool>) -> Self {
138 self.tools.push(tool);
139 self
140 }
141
142 pub fn build(self) -> ToolSet {
144 let mut toolset = ToolSet::new();
145 for tool in self.tools {
146 toolset.add_shared(tool);
147 }
148 toolset
149 }
150}
151
/// Builds an anonymous unit-struct implementing `Tool` from four pieces:
/// a name, a description, a parameters value and an async handler.
///
/// The macro evaluates to a value of a freshly declared local type, ready to
/// be handed to a tool set. A trailing comma after `handler:` is accepted.
///
/// * `name` / `description` — expressions on which `.to_string()` is called.
/// * `parameters` — expression used as the definition's `parameters` field.
/// * `handler` — expression callable with the `&str` arguments and returning a
///   future that resolves to `$crate::error::Result<String>`.
///
/// Notes:
///
/// * Every operand is re-evaluated each time the generated method that uses it
///   runs, so operands should be cheap and side-effect free.
/// * The operands are expanded inside methods of a nested `impl`, so they
///   cannot capture local variables from the call site.
/// * The expansion names `async_trait::async_trait` directly, so the crate
///   invoking the macro needs `async_trait` resolvable under that path.
///
/// # Examples
///
/// ```ignore
/// let tool = simple_tool! {
///     name: "echo",
///     description: "Echo back the input",
///     parameters: serde_json::json!({ "type": "object" }),
///     handler: |args: &str| {
///         let owned = args.to_string();
///         async move { Ok(owned) }
///     },
/// };
/// ```
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        handler: $handler:expr $(,)?
    ) => {{
        struct SimpleTool;

        #[async_trait::async_trait]
        impl $crate::tool::Tool for SimpleTool {
            fn name(&self) -> String {
                $name.to_string()
            }

            async fn definition(&self) -> $crate::tool::ToolDefinition {
                $crate::tool::ToolDefinition {
                    name: $name.to_string(),
                    description: $desc.to_string(),
                    parameters: $params,
                }
            }

            async fn call(&self, arguments: &str) -> $crate::error::Result<String> {
                // Bind first so closure expressions can be invoked with call syntax.
                let handler = $handler;
                handler(arguments).await
            }
        }

        SimpleTool
    }};
}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool for exercising `ToolSet`: returns the `message` field of
    /// its JSON arguments.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> String {
            String::from("echo")
        }

        async fn definition(&self) -> ToolDefinition {
            let parameters = serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "Message to echo"
                    }
                },
                "required": ["message"]
            });

            ToolDefinition {
                name: String::from("echo"),
                description: String::from("Echo back the input"),
                parameters,
            }
        }

        async fn call(&self, arguments: &str) -> Result<String> {
            #[derive(Deserialize)]
            struct Args {
                message: String,
            }

            // Malformed arguments surface as a `ToolArguments` error naming this tool.
            match serde_json::from_str::<Args>(arguments) {
                Ok(parsed) => Ok(parsed.message),
                Err(e) => Err(Error::ToolArguments {
                    tool_name: String::from("echo"),
                    message: e.to_string(),
                }),
            }
        }
    }

    /// Registering a tool makes it discoverable and callable by name.
    #[tokio::test]
    async fn test_toolset() {
        let mut registry = ToolSet::new();
        registry.add(EchoTool);

        assert!(registry.contains("echo"));
        assert_eq!(registry.len(), 1);

        let payload = r#"{"message": "hello"}"#;
        let output = registry
            .call("echo", payload)
            .await
            .expect("call should succeed");
        assert_eq!(output, "hello");
    }
}