aagt_core/skills/tool/
mod.rs1use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::error::Error;
11
12pub mod code_interpreter;
13pub mod cron;
14pub mod delegation;
15pub mod forge;
16pub mod memory;
17
18pub use cron::CronTool;
19pub use delegation::DelegateTool;
20pub use forge::ForgeSkill;
21pub use memory::{RememberThisTool, SearchHistoryTool, TieredSearchTool, FetchDocumentTool};
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ToolDefinition {
26 pub name: String,
28 pub description: String,
30 pub parameters: serde_json::Value,
32 pub parameters_ts: Option<String>,
34 #[serde(default)]
36 pub is_binary: bool,
37 #[serde(default)]
39 pub is_verified: bool,
40 pub usage_guidelines: Option<String>,
42}
43
44#[async_trait]
46pub trait Tool: Send + Sync {
47 fn name(&self) -> String;
50
51 async fn definition(&self) -> ToolDefinition;
53
54 async fn call(&self, arguments: &str) -> anyhow::Result<String>;
56}
57
58#[derive(Clone)]
59pub struct ToolSet {
60 tools: Arc<parking_lot::RwLock<HashMap<String, Arc<dyn Tool>>>>,
61 cached_definitions: Arc<parking_lot::RwLock<HashMap<String, ToolDefinition>>>,
63}
64
65impl Default for ToolSet {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl ToolSet {
72 pub fn new() -> Self {
74 Self {
75 tools: Arc::new(parking_lot::RwLock::new(HashMap::new())),
76 cached_definitions: Arc::new(parking_lot::RwLock::new(HashMap::new())),
77 }
78 }
79
80 pub fn add<T: Tool + 'static>(&self, tool: T) -> &Self {
82 self.tools.write().insert(tool.name().to_string(), Arc::new(tool));
83 self
84 }
85
86 pub fn add_shared(&self, tool: Arc<dyn Tool>) -> &Self {
88 self.tools.write().insert(tool.name().to_string(), tool);
89 self
90 }
91
92 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
94 self.tools.read().get(name).cloned()
95 }
96
97 pub fn contains(&self, name: &str) -> bool {
99 self.tools.read().contains_key(name)
100 }
101
102 pub async fn definitions(&self) -> Vec<ToolDefinition> {
104 let mut defs = Vec::new();
105 let tools_snapshot = self.iter();
106
107 for (name, tool) in tools_snapshot {
108 let cached = {
110 self.cached_definitions.read().get(&name).cloned()
111 };
112
113 if let Some(def) = cached {
114 defs.push(def);
115 } else {
116 let def = tool.definition().await;
117 self.cached_definitions.write().insert(name, def.clone());
118 defs.push(def);
119 }
120 }
121 defs
122 }
123
124 pub async fn call(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
126 let tool = {
127 self.tools.read().get(name).cloned()
128 }.ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
129
130 tool.call(arguments).await
131 }
132
133 pub fn len(&self) -> usize {
135 self.tools.read().len()
136 }
137
138 pub fn is_empty(&self) -> bool {
140 self.tools.read().is_empty()
141 }
142
143 pub fn iter(&self) -> Vec<(String, Arc<dyn Tool>)> {
145 self.tools.read().iter().map(|(k, v)| (k.clone(), Arc::clone(v))).collect()
146 }
147}
148
149#[async_trait::async_trait]
150impl crate::agent::context::ContextInjector for ToolSet {
151 async fn inject(&self) -> crate::error::Result<Vec<crate::agent::message::Message>> {
152 if self.is_empty() {
153 return Ok(Vec::new());
154 }
155
156 let mut content = String::from("## Tool Definitions (TypeScript)\n\n");
157 content.push_str("You have access to the following tools. Use them to fulfill the user's request.\n\n");
158
159 let mut sorted_tools: Vec<_> = self.iter();
160 sorted_tools.sort_by_key(|(k, _)| k.clone());
161
162 for (name, tool) in sorted_tools {
163 let cached_def = {
164 self.cached_definitions.read().get(&name).cloned()
165 };
166
167 let def = if let Some(d) = cached_def {
168 d
169 } else {
170 let d = tool.definition().await;
171 self.cached_definitions.write().insert(name.clone(), d.clone());
172 d
173 };
174
175 content.push_str(&format!("### {}\n{}\n", name, def.description));
176 if let Some(guidelines) = def.usage_guidelines {
177 content.push_str(&format!("*Usage Guidelines*: {}\n", guidelines));
178 }
179 if let Some(ts) = def.parameters_ts {
180 content.push_str("```typescript\n");
181 content.push_str(&ts);
182 if !ts.ends_with('\n') {
183 content.push('\n');
184 }
185 content.push_str("```\n\n");
186 } else {
187 content.push_str("```json\n");
189 content.push_str(&serde_json::to_string_pretty(&def.parameters).unwrap_or_default());
190 content.push_str("\n```\n\n");
191 }
192 }
193
194 Ok(vec![crate::agent::message::Message::system(content)])
195 }
196}
197
198pub struct ToolSetBuilder {
200 tools: Vec<Arc<dyn Tool>>,
201}
202
203impl Default for ToolSetBuilder {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
209impl ToolSetBuilder {
210 pub fn new() -> Self {
212 Self { tools: Vec::new() }
213 }
214
215 pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
217 self.tools.push(Arc::new(tool));
218 self
219 }
220
221 pub fn shared_tool(mut self, tool: Arc<dyn Tool>) -> Self {
223 self.tools.push(tool);
224 self
225 }
226
227 pub fn build(self) -> ToolSet {
229 let toolset = ToolSet::new();
230 for tool in self.tools {
231 toolset.add_shared(tool);
232 }
233 toolset
234 }
235}
236
/// Builds a value of an anonymous type implementing `Tool` from a name, a description,
/// a JSON parameter schema and an async handler.
///
/// The handler is evaluated on every call and invoked with the raw argument string. It must
/// return a future resolving to `anyhow::Result<String>`. The generated definition has no
/// usage guidelines, no TypeScript parameters, and `is_binary`/`is_verified` set to `false`.
/// A trailing comma after the `handler` entry is accepted.
///
/// ```ignore
/// let echo = simple_tool! {
///     name: "echo",
///     description: "Echo back the raw arguments",
///     parameters: serde_json::json!({ "type": "object" }),
///     handler: |args: &str| {
///         let out = args.to_string();
///         async move { Ok(out) }
///     },
/// };
/// ```
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        handler: $handler:expr $(,)?
    ) => {{
        // NOTE(review): these paths assume the crate root exposes this module as `tool`,
        // and that the calling crate depends on `async_trait` and `anyhow` directly.
        struct SimpleTool;

        #[async_trait::async_trait]
        impl $crate::tool::Tool for SimpleTool {
            fn name(&self) -> String {
                $name.to_string()
            }

            async fn definition(&self) -> $crate::tool::ToolDefinition {
                $crate::tool::ToolDefinition {
                    name: $name.to_string(),
                    description: $desc.to_string(),
                    parameters: $params,
                    usage_guidelines: None,
                    is_binary: false,
                    is_verified: false,
                    parameters_ts: None,
                }
            }

            async fn call(&self, arguments: &str) -> anyhow::Result<String> {
                let handler = $handler;
                handler(arguments).await
            }
        }

        SimpleTool
    }};
}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that returns the `message` field of its JSON arguments.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> String {
            "echo".to_string()
        }

        async fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echo back the input".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Message to echo"
                        }
                    },
                    "required": ["message"]
                }),
                parameters_ts: None,
                is_binary: false,
                is_verified: true,
                usage_guidelines: None,
            }
        }

        async fn call(&self, arguments: &str) -> anyhow::Result<String> {
            #[derive(Deserialize)]
            struct Args {
                message: String,
            }
            let args: Args = serde_json::from_str(arguments).map_err(|e| Error::ToolArguments {
                tool_name: "echo".to_string(),
                message: e.to_string(),
            })?;
            Ok(args.message)
        }
    }

    #[tokio::test]
    async fn test_toolset() {
        let toolset = ToolSet::new();
        toolset.add(EchoTool);

        assert!(toolset.contains("echo"));
        assert_eq!(toolset.len(), 1);

        let result = toolset
            .call("echo", r#"{"message": "hello"}"#)
            .await
            .expect("call should succeed");
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_empty_toolset() {
        let toolset = ToolSet::default();
        assert!(toolset.is_empty());
        assert_eq!(toolset.len(), 0);
        assert!(toolset.get("echo").is_none());
        assert!(toolset.iter().is_empty());
    }

    #[tokio::test]
    async fn test_call_unknown_tool_is_error() {
        let toolset = ToolSet::new();
        toolset.add(EchoTool);
        assert!(toolset.call("missing", "{}").await.is_err());
    }

    #[tokio::test]
    async fn test_call_rejects_malformed_arguments() {
        let toolset = ToolSet::new();
        toolset.add(EchoTool);
        assert!(toolset.call("echo", "not json").await.is_err());
    }

    #[tokio::test]
    async fn test_builder_and_definitions() {
        let toolset = ToolSetBuilder::new().tool(EchoTool).build();
        assert!(toolset.contains("echo"));

        // The first call fills the cache; the second is served from it.
        let first = toolset.definitions().await;
        let second = toolset.definitions().await;

        assert_eq!(first.len(), 1);
        assert_eq!(first[0].name, "echo");
        assert!(first[0].is_verified);
        assert_eq!(second.len(), 1);
        assert_eq!(second[0].name, first[0].name);
        assert_eq!(second[0].description, first[0].description);
    }
}