aagt_core/skills/tool/
mod.rs1use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::error::Error;
11
12pub mod code_interpreter;
13pub mod cron;
14pub mod delegation;
15pub mod memory;
16
17pub use cron::CronTool;
18pub use delegation::DelegateTool;
19pub use memory::{RememberThisTool, SearchHistoryTool, TieredSearchTool, FetchDocumentTool};
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ToolDefinition {
24 pub name: String,
26 pub description: String,
28 pub parameters: serde_json::Value,
30 pub parameters_ts: Option<String>,
32 #[serde(default)]
34 pub is_binary: bool,
35 #[serde(default)]
37 pub is_verified: bool,
38}
39
40#[async_trait]
42pub trait Tool: Send + Sync {
43 fn name(&self) -> String;
46
47 async fn definition(&self) -> ToolDefinition;
49
50 async fn call(&self, arguments: &str) -> anyhow::Result<String>;
52}
53
54#[derive(Clone)]
55pub struct ToolSet {
56 tools: HashMap<String, Arc<dyn Tool>>,
57 cached_definitions: Arc<parking_lot::RwLock<HashMap<String, ToolDefinition>>>,
59}
60
61impl Default for ToolSet {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl ToolSet {
68 pub fn new() -> Self {
70 Self {
71 tools: HashMap::new(),
72 cached_definitions: Arc::new(parking_lot::RwLock::new(HashMap::new())),
73 }
74 }
75
76 pub fn add<T: Tool + 'static>(&mut self, tool: T) -> &mut Self {
78 self.tools.insert(tool.name().to_string(), Arc::new(tool));
79 self
80 }
81
82 pub fn add_shared(&mut self, tool: Arc<dyn Tool>) -> &mut Self {
84 self.tools.insert(tool.name().to_string(), tool);
85 self
86 }
87
88 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
90 self.tools.get(name)
91 }
92
93 pub fn contains(&self, name: &str) -> bool {
95 self.tools.contains_key(name)
96 }
97
98 pub async fn definitions(&self) -> Vec<ToolDefinition> {
100 let mut defs = Vec::new();
101 for (name, tool) in &self.tools {
102 let cached = {
104 self.cached_definitions.read().get(name).cloned()
105 };
106
107 if let Some(def) = cached {
108 defs.push(def);
109 } else {
110 let def = tool.definition().await;
111 self.cached_definitions.write().insert(name.clone(), def.clone());
112 defs.push(def);
113 }
114 }
115 defs
116 }
117
118 pub async fn call(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
120 let tool = self
121 .tools
122 .get(name)
123 .ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
124
125 tool.call(arguments).await
126 }
127
128 pub fn len(&self) -> usize {
130 self.tools.len()
131 }
132
133 pub fn is_empty(&self) -> bool {
135 self.tools.is_empty()
136 }
137
138 pub fn iter(&self) -> impl Iterator<Item = (&String, &Arc<dyn Tool>)> {
140 self.tools.iter()
141 }
142}
143
144#[async_trait::async_trait]
145impl crate::agent::context::ContextInjector for ToolSet {
146 async fn inject(&self) -> crate::error::Result<Vec<crate::agent::message::Message>> {
147 if self.tools.is_empty() {
148 return Ok(Vec::new());
149 }
150
151 let mut content = String::from("## Tool Definitions (TypeScript)\n\n");
152 content.push_str("You have access to the following tools. Use them to fulfill the user's request.\n\n");
153
154 let mut sorted_tools: Vec<_> = self.tools.iter().collect();
156 sorted_tools.sort_by_key(|(k, _)| *k);
157
158 for (name, tool) in sorted_tools {
159 let cached_def = {
160 self.cached_definitions.read().get(name).cloned()
161 };
162
163 let def = if let Some(d) = cached_def {
164 d
165 } else {
166 let d = tool.definition().await;
167 self.cached_definitions.write().insert(name.clone(), d.clone());
168 d
169 };
170
171 content.push_str(&format!("### {}\n{}\n", name, def.description));
172 if let Some(ts) = def.parameters_ts {
173 content.push_str("```typescript\n");
174 content.push_str(&ts);
175 if !ts.ends_with('\n') {
176 content.push('\n');
177 }
178 content.push_str("```\n\n");
179 } else {
180 content.push_str("```json\n");
182 content.push_str(&serde_json::to_string_pretty(&def.parameters).unwrap_or_default());
183 content.push_str("\n```\n\n");
184 }
185 }
186
187 Ok(vec![crate::agent::message::Message::system(content)])
188 }
189}
190
191pub struct ToolSetBuilder {
193 tools: Vec<Arc<dyn Tool>>,
194}
195
196impl Default for ToolSetBuilder {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
202impl ToolSetBuilder {
203 pub fn new() -> Self {
205 Self { tools: Vec::new() }
206 }
207
208 pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
210 self.tools.push(Arc::new(tool));
211 self
212 }
213
214 pub fn shared_tool(mut self, tool: Arc<dyn Tool>) -> Self {
216 self.tools.push(tool);
217 self
218 }
219
220 pub fn build(self) -> ToolSet {
222 let mut toolset = ToolSet::new();
223 for tool in self.tools {
224 toolset.add_shared(tool);
225 }
226 toolset
227 }
228}
229
/// Builds a tool value from a name, a description, a JSON parameter schema and an async
/// handler, without declaring a struct by hand.
///
/// `handler` must evaluate to something callable as `handler(arguments)` (with
/// `arguments: &str`) whose future resolves to `anyhow::Result<String>`. The handler
/// expression is re-evaluated on every call. The generated definition has no TypeScript
/// rendering and both flags set to `false`.
///
/// ```ignore
/// let tool = simple_tool! {
///     name: "ping",
///     description: "Reply with pong",
///     parameters: serde_json::json!({ "type": "object", "properties": {} }),
///     handler: |_args: &str| async { Ok::<_, anyhow::Error>("pong".to_string()) }
/// };
/// ```
///
/// NOTE(review): the expansion names `async_trait` and `anyhow` directly, so calling crates
/// must depend on both. It also uses `$crate::tool::...` paths; confirm those resolve if
/// this module is not exported as `tool` at the crate root.
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        handler: $handler:expr
    ) => {{
        struct SimpleTool;

        #[async_trait::async_trait]
        impl $crate::tool::Tool for SimpleTool {
            fn name(&self) -> String {
                $name.to_string()
            }

            async fn definition(&self) -> $crate::tool::ToolDefinition {
                $crate::tool::ToolDefinition {
                    name: $name.to_string(),
                    description: $desc.to_string(),
                    parameters: $params,
                    // Every field must be listed. Omitting these made each macro
                    // invocation a compile error.
                    parameters_ts: ::core::option::Option::None,
                    is_binary: false,
                    is_verified: false,
                }
            }

            async fn call(&self, arguments: &str) -> anyhow::Result<String> {
                let handler = $handler;
                handler(arguments).await
            }
        }

        SimpleTool
    }};
}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that returns the `message` field of its JSON arguments.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> String {
            "echo".to_string()
        }

        async fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echo back the input".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Message to echo"
                        }
                    },
                    "required": ["message"]
                }),
                parameters_ts: None,
                is_binary: false,
                is_verified: true,
            }
        }

        async fn call(&self, arguments: &str) -> anyhow::Result<String> {
            #[derive(Deserialize)]
            struct Args {
                message: String,
            }
            let args: Args = serde_json::from_str(arguments).map_err(|e| Error::ToolArguments {
                tool_name: "echo".to_string(),
                message: e.to_string(),
            })?;
            Ok(args.message)
        }
    }

    #[tokio::test]
    async fn test_toolset() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);

        assert!(toolset.contains("echo"));
        assert_eq!(toolset.len(), 1);

        let result = toolset
            .call("echo", r#"{"message": "hello"}"#)
            .await
            .expect("call should succeed");
        assert_eq!(result, "hello");
    }

    #[tokio::test]
    async fn test_call_unknown_tool_returns_not_found() {
        let toolset = ToolSet::new();
        let err = toolset
            .call("missing", "{}")
            .await
            .expect_err("calling an unregistered tool should fail");
        assert!(matches!(
            err.downcast_ref::<Error>(),
            Some(Error::ToolNotFound(_))
        ));
    }

    #[tokio::test]
    async fn test_call_rejects_malformed_arguments() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);
        let err = toolset
            .call("echo", "not json")
            .await
            .expect_err("malformed arguments should fail");
        assert!(matches!(
            err.downcast_ref::<Error>(),
            Some(Error::ToolArguments { .. })
        ));
    }

    #[tokio::test]
    async fn test_definitions_are_stable_across_calls() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);

        let first = toolset.definitions().await;
        let second = toolset.definitions().await;

        assert_eq!(first.len(), 1);
        assert_eq!(first[0].name, "echo");
        assert_eq!(first[0].description, second[0].description);
        assert_eq!(first[0].parameters, second[0].parameters);
    }

    #[tokio::test]
    async fn test_builder_registers_tools() {
        let toolset = ToolSetBuilder::new().tool(EchoTool).build();

        assert!(toolset.contains("echo"));
        assert_eq!(toolset.len(), 1);

        let result = toolset
            .call("echo", r#"{"message": "hi"}"#)
            .await
            .expect("call should succeed");
        assert_eq!(result, "hi");
    }
}