1use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::error::{Error, Result};
11
12pub mod code_interpreter;
13pub mod memory;
14
15pub use memory::{RememberThisTool, SearchHistoryTool};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ToolDefinition {
20 pub name: String,
22 pub description: String,
24 pub parameters: serde_json::Value,
26}
27
28#[async_trait]
30pub trait Tool: Send + Sync {
31 fn name(&self) -> String;
34
35 async fn definition(&self) -> ToolDefinition;
37
38 async fn call(&self, arguments: &str) -> anyhow::Result<String>;
40}
41
42pub struct ToolSet {
44 tools: HashMap<String, Arc<dyn Tool>>,
45}
46
47impl Default for ToolSet {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl ToolSet {
54 pub fn new() -> Self {
56 Self {
57 tools: HashMap::new(),
58 }
59 }
60
61 pub fn add<T: Tool + 'static>(&mut self, tool: T) -> &mut Self {
63 self.tools.insert(tool.name().to_string(), Arc::new(tool));
64 self
65 }
66
67 pub fn add_shared(&mut self, tool: Arc<dyn Tool>) -> &mut Self {
69 self.tools.insert(tool.name().to_string(), tool);
70 self
71 }
72
73 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
75 self.tools.get(name)
76 }
77
78 pub fn contains(&self, name: &str) -> bool {
80 self.tools.contains_key(name)
81 }
82
83 pub async fn definitions(&self) -> Vec<ToolDefinition> {
85 let mut defs = Vec::new();
86 for tool in self.tools.values() {
87 defs.push(tool.definition().await);
88 }
89 defs
90 }
91
92 pub async fn call(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
94 let tool = self
95 .tools
96 .get(name)
97 .ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
98
99 tool.call(arguments).await
100 }
101
102 pub fn len(&self) -> usize {
104 self.tools.len()
105 }
106
107 pub fn is_empty(&self) -> bool {
109 self.tools.is_empty()
110 }
111
112 pub fn iter(&self) -> impl Iterator<Item = (&String, &Arc<dyn Tool>)> {
114 self.tools.iter()
115 }
116}
117
118pub struct ToolSetBuilder {
120 tools: Vec<Arc<dyn Tool>>,
121}
122
123impl Default for ToolSetBuilder {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl ToolSetBuilder {
130 pub fn new() -> Self {
132 Self { tools: Vec::new() }
133 }
134
135 pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
137 self.tools.push(Arc::new(tool));
138 self
139 }
140
141 pub fn shared_tool(mut self, tool: Arc<dyn Tool>) -> Self {
143 self.tools.push(tool);
144 self
145 }
146
147 pub fn build(self) -> ToolSet {
149 let mut toolset = ToolSet::new();
150 for tool in self.tools {
151 toolset.add_shared(tool);
152 }
153 toolset
154 }
155}
156
/// Builds a one-off `Tool` value from a name, description, parameter
/// description, and async handler.
///
/// How the arguments are used:
///
/// * `handler` must evaluate to something callable as `handler(&str)` that
///   returns a future resolving to `anyhow::Result<String>`. Because the
///   expansion uses `#[async_trait::async_trait]` (which requires `Send`
///   futures by default), that future must be `Send`.
/// * The `handler` expression is re-evaluated on every `call`.
/// * The `parameters` expression is re-evaluated on every `definition()`.
/// * A trailing comma after `handler` is accepted.
///
/// The expansion names `async_trait::async_trait` and `anyhow::Result` by
/// absolute path. Crates invoking this macro therefore need `async_trait`
/// and `anyhow` as direct dependencies.
///
/// # Example
///
/// ```ignore
/// let tool = simple_tool! {
///     name: "upper",
///     description: "Upper-case the raw arguments",
///     parameters: serde_json::json!({ "type": "object" }),
///     handler: |args: &str| {
///         let out = args.to_uppercase();
///         async move { anyhow::Ok(out) }
///     },
/// };
/// ```
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        handler: $handler:expr $(,)?
    ) => {{
        struct SimpleTool;

        #[async_trait::async_trait]
        impl $crate::tool::Tool for SimpleTool {
            fn name(&self) -> String {
                $name.to_string()
            }

            async fn definition(&self) -> $crate::tool::ToolDefinition {
                $crate::tool::ToolDefinition {
                    name: $name.to_string(),
                    description: $desc.to_string(),
                    parameters: $params,
                }
            }

            async fn call(&self, arguments: &str) -> anyhow::Result<String> {
                let handler = $handler;
                handler(arguments).await
            }
        }

        SimpleTool
    }};
}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Test tool that returns the `message` field of its JSON arguments.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> String {
            "echo".to_string()
        }

        async fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echo back the input".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Message to echo"
                        }
                    },
                    "required": ["message"]
                }),
            }
        }

        async fn call(&self, arguments: &str) -> anyhow::Result<String> {
            #[derive(Deserialize)]
            struct Args {
                message: String,
            }
            let args: Args = serde_json::from_str(arguments).map_err(|e| Error::ToolArguments {
                tool_name: "echo".to_string(),
                message: e.to_string(),
            })?;
            Ok(args.message)
        }
    }

    #[tokio::test]
    async fn test_toolset() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);

        assert!(toolset.contains("echo"));
        assert_eq!(toolset.len(), 1);

        let result = toolset
            .call("echo", r#"{"message": "hello"}"#)
            .await
            .expect("call should succeed");
        assert_eq!(result, "hello");
    }

    #[tokio::test]
    async fn test_call_unknown_tool_fails() {
        let toolset = ToolSet::new();

        let err = toolset
            .call("missing", "{}")
            .await
            .expect_err("calling an unregistered tool should fail");
        assert!(matches!(
            err.downcast_ref::<Error>(),
            Some(Error::ToolNotFound(name)) if name == "missing"
        ));
    }

    #[tokio::test]
    async fn test_call_with_invalid_arguments_fails() {
        let toolset = ToolSetBuilder::new().tool(EchoTool).build();

        let err = toolset
            .call("echo", "not json")
            .await
            .expect_err("malformed arguments should be rejected");
        assert!(matches!(
            err.downcast_ref::<Error>(),
            Some(Error::ToolArguments { tool_name, .. }) if tool_name == "echo"
        ));
    }

    #[tokio::test]
    async fn test_builder_and_definitions() {
        let toolset = ToolSetBuilder::new().tool(EchoTool).build();

        assert!(!toolset.is_empty());
        assert!(toolset.get("echo").is_some());

        let defs = toolset.definitions().await;
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
    }
}