Skip to main content

aagt_core/skills/tool/mod.rs

1//! Tool system for AI agents
2//!
3//! Provides the core abstraction for defining tools that AI agents can call.
4
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::error::Error;
11
12pub mod code_interpreter;
13pub mod cron;
14pub mod delegation;
15pub mod forge;
16pub mod memory;
17
18pub use cron::CronTool;
19pub use delegation::DelegateTool;
20pub use forge::ForgeSkill;
21pub use memory::{RememberThisTool, SearchHistoryTool, TieredSearchTool, FetchDocumentTool};
22
23/// Definition of a tool that can be sent to the LLM
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ToolDefinition {
26    /// Name of the tool
27    pub name: String,
28    /// Description for the LLM
29    pub description: String,
30    /// JSON Schema for parameters (Legacy/API)
31    pub parameters: serde_json::Value,
32    /// TypeScript interface definition (Preferred for System Prompt)
33    pub parameters_ts: Option<String>,
34    /// Whether this is a binary tool (e.g. Wasm)
35    #[serde(default)]
36    pub is_binary: bool,
37    /// Whether the tool is verified/trusted
38    #[serde(default)]
39    pub is_verified: bool,
40    /// Usage guidelines for the LLM (When to use, When NOT to use)
41    pub usage_guidelines: Option<String>,
42}
43
44/// Trait for implementing tools that AI agents can call
45#[async_trait]
46pub trait Tool: Send + Sync {
47    /// The name of this tool
48    /// The name of this tool
49    fn name(&self) -> String;
50
51    /// Get the tool definition for the LLM
52    async fn definition(&self) -> ToolDefinition;
53
54    /// Execute the tool with the given arguments (JSON string)
55    async fn call(&self, arguments: &str) -> anyhow::Result<String>;
56}
57
58#[derive(Clone)]
59pub struct ToolSet {
60    tools: Arc<parking_lot::RwLock<HashMap<String, Arc<dyn Tool>>>>,
61    /// Cached definitions to avoid async calls during prompt generation
62    cached_definitions: Arc<parking_lot::RwLock<HashMap<String, ToolDefinition>>>,
63}
64
65impl Default for ToolSet {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl ToolSet {
72    /// Create an empty toolset
73    pub fn new() -> Self {
74        Self {
75            tools: Arc::new(parking_lot::RwLock::new(HashMap::new())),
76            cached_definitions: Arc::new(parking_lot::RwLock::new(HashMap::new())),
77        }
78    }
79
80    /// Add a tool to the set
81    pub fn add<T: Tool + 'static>(&self, tool: T) -> &Self {
82        self.tools.write().insert(tool.name().to_string(), Arc::new(tool));
83        self
84    }
85
86    /// Add a shared tool to the set
87    pub fn add_shared(&self, tool: Arc<dyn Tool>) -> &Self {
88        self.tools.write().insert(tool.name().to_string(), tool);
89        self
90    }
91
92    /// Get a tool by name
93    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
94        self.tools.read().get(name).cloned()
95    }
96
97    /// Check if a tool exists
98    pub fn contains(&self, name: &str) -> bool {
99        self.tools.read().contains_key(name)
100    }
101
102    /// Get all tool definitions
103    pub async fn definitions(&self) -> Vec<ToolDefinition> {
104        let mut defs = Vec::new();
105        let tools_snapshot = self.iter();
106        
107        for (name, tool) in tools_snapshot {
108            // Check cache in a small block to ensure guard is dropped
109            let cached = {
110                self.cached_definitions.read().get(&name).cloned()
111            };
112
113            if let Some(def) = cached {
114                defs.push(def);
115            } else {
116                let def = tool.definition().await;
117                self.cached_definitions.write().insert(name, def.clone());
118                defs.push(def);
119            }
120        }
121        defs
122    }
123
124    /// Call a tool by name
125    pub async fn call(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
126        let tool = {
127            self.tools.read().get(name).cloned()
128        }.ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
129
130        tool.call(arguments).await
131    }
132
133    /// Get the number of tools
134    pub fn len(&self) -> usize {
135        self.tools.read().len()
136    }
137
138    /// Check if empty
139    pub fn is_empty(&self) -> bool {
140        self.tools.read().is_empty()
141    }
142
143    /// Iterate over tools
144    pub fn iter(&self) -> Vec<(String, Arc<dyn Tool>)> {
145        self.tools.read().iter().map(|(k, v)| (k.clone(), Arc::clone(v))).collect()
146    }
147}
148
149#[async_trait::async_trait]
150impl crate::agent::context::ContextInjector for ToolSet {
151    async fn inject(&self) -> crate::error::Result<Vec<crate::agent::message::Message>> {
152        if self.is_empty() {
153            return Ok(Vec::new());
154        }
155
156        let mut content = String::from("## Tool Definitions (TypeScript)\n\n");
157        content.push_str("You have access to the following tools. Use them to fulfill the user's request.\n\n");
158
159        let mut sorted_tools: Vec<_> = self.iter();
160        sorted_tools.sort_by_key(|(k, _)| k.clone());
161
162        for (name, tool) in sorted_tools {
163            let cached_def = {
164                self.cached_definitions.read().get(&name).cloned()
165            };
166
167            let def = if let Some(d) = cached_def {
168                d
169            } else {
170                let d = tool.definition().await;
171                self.cached_definitions.write().insert(name.clone(), d.clone());
172                d
173            };
174            
175            content.push_str(&format!("### {}\n{}\n", name, def.description));
176            if let Some(guidelines) = def.usage_guidelines {
177                content.push_str(&format!("*Usage Guidelines*: {}\n", guidelines));
178            }
179            if let Some(ts) = def.parameters_ts {
180                content.push_str("```typescript\n");
181                content.push_str(&ts);
182                if !ts.ends_with('\n') {
183                    content.push('\n');
184                }
185                content.push_str("```\n\n");
186            } else {
187                // Fallback to JSON if TS missing
188                content.push_str("```json\n");
189                content.push_str(&serde_json::to_string_pretty(&def.parameters).unwrap_or_default());
190                content.push_str("\n```\n\n");
191            }
192        }
193
194        Ok(vec![crate::agent::message::Message::system(content)])
195    }
196}
197
198/// Builder for creating a ToolSet
199pub struct ToolSetBuilder {
200    tools: Vec<Arc<dyn Tool>>,
201}
202
203impl Default for ToolSetBuilder {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
209impl ToolSetBuilder {
210    /// Create a new builder
211    pub fn new() -> Self {
212        Self { tools: Vec::new() }
213    }
214
215    /// Add a tool
216    pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
217        self.tools.push(Arc::new(tool));
218        self
219    }
220
221    /// Add a shared tool
222    pub fn shared_tool(mut self, tool: Arc<dyn Tool>) -> Self {
223        self.tools.push(tool);
224        self
225    }
226
227    /// Build the ToolSet
228    pub fn build(self) -> ToolSet {
229        let toolset = ToolSet::new();
230        for tool in self.tools {
231            toolset.add_shared(tool);
232        }
233        toolset
234    }
235}
236
/// Helper macro for creating simple tools.
///
/// Evaluates to a value of a new unit struct that implements the `Tool` trait:
/// - `name` and `description` are converted with `to_string()`;
/// - `parameters` is the JSON Schema (`serde_json::Value`) reported by
///   `definition()`; the expression is re-evaluated on every `definition()` call;
/// - `handler` must be callable with the raw `&str` arguments and return a
///   future resolving to `anyhow::Result<String>`; the expression is
///   re-evaluated on every `call`.
///
/// The generated definition has no TypeScript parameters or usage guidelines
/// and is marked neither binary nor verified. A trailing comma is accepted.
///
/// The expansion names the `async_trait` and `anyhow` crates directly, so the
/// invoking crate must depend on both.
///
/// # Example
/// ```ignore
/// let get_time = simple_tool!(
///     name: "get_time",
///     description: "Get the current time",
///     parameters: serde_json::json!({ "type": "object", "properties": {} }),
///     handler: |_args: &str| async {
///         Ok(chrono::Utc::now().to_rfc3339())
///     },
/// );
/// ```
// NOTE(review): the expansion refers to `$crate::tool::{Tool, ToolDefinition}`,
// while this module appears to live at `skills/tool`; confirm the crate root
// re-exports it as `tool`.
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        handler: $handler:expr $(,)?
    ) => {{
        struct SimpleTool;

        #[async_trait::async_trait]
        impl $crate::tool::Tool for SimpleTool {
            fn name(&self) -> String {
                $name.to_string()
            }

            async fn definition(&self) -> $crate::tool::ToolDefinition {
                $crate::tool::ToolDefinition {
                    name: $name.to_string(),
                    description: $desc.to_string(),
                    parameters: $params,
                    usage_guidelines: None,
                    is_binary: false,
                    is_verified: false,
                    parameters_ts: None,
                }
            }

            async fn call(&self, arguments: &str) -> anyhow::Result<String> {
                // The handler expression is evaluated fresh for every call.
                let handler = $handler;
                handler(arguments).await
            }
        }

        SimpleTool
    }};
}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that returns its `message` argument unchanged.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> String {
            String::from("echo")
        }

        async fn definition(&self) -> ToolDefinition {
            let schema = serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "Message to echo"
                    }
                },
                "required": ["message"]
            });

            ToolDefinition {
                name: self.name(),
                description: String::from("Echo back the input"),
                parameters: schema,
                parameters_ts: None,
                usage_guidelines: None,
                is_binary: false,
                // Internal tools are verified.
                is_verified: true,
            }
        }

        async fn call(&self, arguments: &str) -> anyhow::Result<String> {
            #[derive(Deserialize)]
            struct EchoArgs {
                message: String,
            }

            match serde_json::from_str::<EchoArgs>(arguments) {
                Ok(EchoArgs { message }) => Ok(message),
                Err(err) => Err(Error::ToolArguments {
                    tool_name: "echo".to_string(),
                    message: err.to_string(),
                }
                .into()),
            }
        }
    }

    #[tokio::test]
    async fn test_toolset() {
        let toolset = ToolSet::new();
        toolset.add(EchoTool);

        assert_eq!(toolset.len(), 1);
        assert!(toolset.contains("echo"));

        let echoed = toolset
            .call("echo", r#"{"message": "hello"}"#)
            .await
            .expect("call should succeed");
        assert_eq!(echoed, "hello");
    }
}