1use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::error::{Error, Result};
11
12pub mod memory;
13
14pub use memory::{RememberThisTool, SearchHistoryTool};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ToolDefinition {
19 pub name: String,
21 pub description: String,
23 pub parameters: serde_json::Value,
25}
26
/// A capability that can be registered in a [`ToolSet`] and invoked by name.
///
/// The `Send + Sync` bound lets tools be stored as `Arc<dyn Tool>` and shared
/// across threads.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Name under which the tool is registered. [`ToolSet`] uses it as the
    /// lookup key, so it should be unique within a set.
    fn name(&self) -> String;

    /// Returns the tool's name, description and parameter description.
    async fn definition(&self) -> ToolDefinition;

    /// Invokes the tool with its raw, unparsed argument string and returns its
    /// textual output.
    ///
    /// NOTE(review): `arguments` is presumably a JSON object matching
    /// `definition().parameters` (the in-file example parses it as JSON);
    /// confirm with the caller.
    async fn call(&self, arguments: &str) -> Result<String>;
}
40
41pub struct ToolSet {
43 tools: HashMap<String, Arc<dyn Tool>>,
44}
45
46impl Default for ToolSet {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl ToolSet {
53 pub fn new() -> Self {
55 Self {
56 tools: HashMap::new(),
57 }
58 }
59
60 pub fn add<T: Tool + 'static>(&mut self, tool: T) -> &mut Self {
62 self.tools.insert(tool.name().to_string(), Arc::new(tool));
63 self
64 }
65
66 pub fn add_shared(&mut self, tool: Arc<dyn Tool>) -> &mut Self {
68 self.tools.insert(tool.name().to_string(), tool);
69 self
70 }
71
72 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
74 self.tools.get(name)
75 }
76
77 pub fn contains(&self, name: &str) -> bool {
79 self.tools.contains_key(name)
80 }
81
82 pub async fn definitions(&self) -> Vec<ToolDefinition> {
84 let mut defs = Vec::new();
85 for tool in self.tools.values() {
86 defs.push(tool.definition().await);
87 }
88 defs
89 }
90
91 pub async fn call(&self, name: &str, arguments: &str) -> Result<String> {
93 let tool = self
94 .tools
95 .get(name)
96 .ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
97
98 tool.call(arguments).await
99 }
100
101 pub fn len(&self) -> usize {
103 self.tools.len()
104 }
105
106 pub fn is_empty(&self) -> bool {
108 self.tools.is_empty()
109 }
110
111 pub fn iter(&self) -> impl Iterator<Item = (&String, &Arc<dyn Tool>)> {
113 self.tools.iter()
114 }
115}
116
/// Fluent builder that collects tools and produces a [`ToolSet`].
pub struct ToolSetBuilder {
    /// Tools in the order they were added. `build` registers them in this
    /// order, so a later tool with the same name replaces an earlier one.
    tools: Vec<Arc<dyn Tool>>,
}
121
122impl Default for ToolSetBuilder {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128impl ToolSetBuilder {
129 pub fn new() -> Self {
131 Self { tools: Vec::new() }
132 }
133
134 pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
136 self.tools.push(Arc::new(tool));
137 self
138 }
139
140 pub fn shared_tool(mut self, tool: Arc<dyn Tool>) -> Self {
142 self.tools.push(tool);
143 self
144 }
145
146 pub fn build(self) -> ToolSet {
148 let mut toolset = ToolSet::new();
149 for tool in self.tools {
150 toolset.add_shared(tool);
151 }
152 toolset
153 }
154}
155
/// Defines a tool from a name, a description, a parameter value and an async
/// handler, and evaluates to an instance of that tool.
///
/// Expansion details:
/// - `name` and `description` are converted with `.to_string()` every time
///   `name()` or `definition()` is called. `parameters` is re-evaluated on
///   every `definition()` call.
/// - `handler` is re-evaluated on every `call()`. It must be callable as
///   `handler(arguments)` with `arguments: &str`, and must return a future whose
///   output is `$crate::error::Result<String>`.
/// - All four expressions are expanded inside methods of a locally defined
///   `impl`, so they cannot capture local variables from the call site.
/// - The expansion names `async_trait::async_trait` directly, without a
///   `$crate` prefix, so the calling crate must be able to resolve that path.
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        handler: $handler:expr
    ) => {{
        // Local item type; each invocation sits in its own block, so repeated
        // uses of the macro do not clash.
        struct SimpleTool;

        #[async_trait::async_trait]
        impl $crate::tool::Tool for SimpleTool {
            fn name(&self) -> String {
                $name.to_string()
            }

            async fn definition(&self) -> $crate::tool::ToolDefinition {
                $crate::tool::ToolDefinition {
                    name: $name.to_string(),
                    description: $desc.to_string(),
                    parameters: $params,
                }
            }

            async fn call(&self, arguments: &str) -> $crate::error::Result<String> {
                // Bind first so that `$handler` is evaluated once per call
                // before being invoked.
                let handler = $handler;
                handler(arguments).await
            }
        }

        SimpleTool
    }};
}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that returns the `message` field of its JSON arguments.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> String {
            "echo".to_string()
        }

        async fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echo back the input".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Message to echo"
                        }
                    },
                    "required": ["message"]
                }),
            }
        }

        async fn call(&self, arguments: &str) -> Result<String> {
            #[derive(Deserialize)]
            struct Args {
                message: String,
            }
            let args: Args = serde_json::from_str(arguments).map_err(|e| Error::ToolArguments {
                tool_name: "echo".to_string(),
                message: e.to_string(),
            })?;
            Ok(args.message)
        }
    }

    #[tokio::test]
    async fn test_toolset() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);

        assert!(toolset.contains("echo"));
        assert_eq!(toolset.len(), 1);

        let result = toolset
            .call("echo", r#"{"message": "hello"}"#)
            .await
            .expect("call should succeed");
        assert_eq!(result, "hello");
    }

    #[tokio::test]
    async fn test_call_unknown_tool_returns_not_found() {
        let toolset = ToolSet::new();
        assert!(toolset.is_empty());

        let err = toolset
            .call("missing", "{}")
            .await
            .expect_err("unknown tool should fail");
        assert!(matches!(err, Error::ToolNotFound(ref name) if name == "missing"));
    }

    #[tokio::test]
    async fn test_call_with_invalid_arguments() {
        let toolset = ToolSetBuilder::new().tool(EchoTool).build();

        let err = toolset
            .call("echo", "not json")
            .await
            .expect_err("malformed arguments should fail");
        assert!(matches!(
            err,
            Error::ToolArguments { ref tool_name, .. } if tool_name == "echo"
        ));
    }

    #[tokio::test]
    async fn test_builder_and_definitions() {
        let toolset = ToolSetBuilder::new()
            .tool(EchoTool)
            .shared_tool(Arc::new(EchoTool))
            .build();

        // Both tools report the name "echo", so the later one replaces the first.
        assert_eq!(toolset.len(), 1);
        assert!(!toolset.is_empty());
        assert!(toolset.get("echo").is_some());

        let defs = toolset.definitions().await;
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
        assert_eq!(defs[0].description, "Echo back the input");
    }
}