Skip to main content

aagt_core/skills/tool/
mod.rs

//! Tool system for AI agents
//!
//! Provides the core abstraction for defining tools that AI agents can call.

5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::error::Error;
11
12pub mod code_interpreter;
13pub mod cron;
14pub mod delegation;
15pub mod memory;
16
17pub use cron::CronTool;
18pub use delegation::DelegateTool;
19pub use memory::{RememberThisTool, SearchHistoryTool, TieredSearchTool, FetchDocumentTool};
20
21/// Definition of a tool that can be sent to the LLM
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ToolDefinition {
24    /// Name of the tool
25    pub name: String,
26    /// Description for the LLM
27    pub description: String,
28    /// JSON Schema for parameters (Legacy/API)
29    pub parameters: serde_json::Value,
30    /// TypeScript interface definition (Preferred for System Prompt)
31    pub parameters_ts: Option<String>,
32    /// Whether this is a binary tool (e.g. Wasm)
33    #[serde(default)]
34    pub is_binary: bool,
35    /// Whether the tool is verified/trusted
36    #[serde(default)]
37    pub is_verified: bool,
38}
39
40/// Trait for implementing tools that AI agents can call
41#[async_trait]
42pub trait Tool: Send + Sync {
43    /// The name of this tool
44    /// The name of this tool
45    fn name(&self) -> String;
46
47    /// Get the tool definition for the LLM
48    async fn definition(&self) -> ToolDefinition;
49
50    /// Execute the tool with the given arguments (JSON string)
51    async fn call(&self, arguments: &str) -> anyhow::Result<String>;
52}
53
54#[derive(Clone)]
55pub struct ToolSet {
56    tools: HashMap<String, Arc<dyn Tool>>,
57    /// Cached definitions to avoid async calls during prompt generation
58    cached_definitions: Arc<parking_lot::RwLock<HashMap<String, ToolDefinition>>>,
59}
60
61impl Default for ToolSet {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl ToolSet {
68    /// Create an empty toolset
69    pub fn new() -> Self {
70        Self {
71            tools: HashMap::new(),
72            cached_definitions: Arc::new(parking_lot::RwLock::new(HashMap::new())),
73        }
74    }
75
76    /// Add a tool to the set
77    pub fn add<T: Tool + 'static>(&mut self, tool: T) -> &mut Self {
78        self.tools.insert(tool.name().to_string(), Arc::new(tool));
79        self
80    }
81
82    /// Add a shared tool to the set
83    pub fn add_shared(&mut self, tool: Arc<dyn Tool>) -> &mut Self {
84        self.tools.insert(tool.name().to_string(), tool);
85        self
86    }
87
88    /// Get a tool by name
89    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
90        self.tools.get(name)
91    }
92
93    /// Check if a tool exists
94    pub fn contains(&self, name: &str) -> bool {
95        self.tools.contains_key(name)
96    }
97
98    /// Get all tool definitions
99    pub async fn definitions(&self) -> Vec<ToolDefinition> {
100        let mut defs = Vec::new();
101        for (name, tool) in &self.tools {
102            // Check cache in a small block to ensure guard is dropped
103            let cached = {
104                self.cached_definitions.read().get(name).cloned()
105            };
106
107            if let Some(def) = cached {
108                defs.push(def);
109            } else {
110                let def = tool.definition().await;
111                self.cached_definitions.write().insert(name.clone(), def.clone());
112                defs.push(def);
113            }
114        }
115        defs
116    }
117
118    /// Call a tool by name
119    pub async fn call(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
120        let tool = self
121            .tools
122            .get(name)
123            .ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
124
125        tool.call(arguments).await
126    }
127
128    /// Get the number of tools
129    pub fn len(&self) -> usize {
130        self.tools.len()
131    }
132
133    /// Check if empty
134    pub fn is_empty(&self) -> bool {
135        self.tools.is_empty()
136    }
137
138    /// Iterate over tools
139    pub fn iter(&self) -> impl Iterator<Item = (&String, &Arc<dyn Tool>)> {
140        self.tools.iter()
141    }
142}
143
144#[async_trait::async_trait]
145impl crate::agent::context::ContextInjector for ToolSet {
146    async fn inject(&self) -> crate::error::Result<Vec<crate::agent::message::Message>> {
147        if self.tools.is_empty() {
148            return Ok(Vec::new());
149        }
150
151        let mut content = String::from("## Tool Definitions (TypeScript)\n\n");
152        content.push_str("You have access to the following tools. Use them to fulfill the user's request.\n\n");
153
154        // Sort for determinism
155        let mut sorted_tools: Vec<_> = self.tools.iter().collect();
156        sorted_tools.sort_by_key(|(k, _)| *k);
157
158        for (name, tool) in sorted_tools {
159            let cached_def = {
160                self.cached_definitions.read().get(name).cloned()
161            };
162
163            let def = if let Some(d) = cached_def {
164                d
165            } else {
166                let d = tool.definition().await;
167                self.cached_definitions.write().insert(name.clone(), d.clone());
168                d
169            };
170            
171            content.push_str(&format!("### {}\n{}\n", name, def.description));
172            if let Some(ts) = def.parameters_ts {
173                content.push_str("```typescript\n");
174                content.push_str(&ts);
175                if !ts.ends_with('\n') {
176                    content.push('\n');
177                }
178                content.push_str("```\n\n");
179            } else {
180                // Fallback to JSON if TS missing
181                content.push_str("```json\n");
182                content.push_str(&serde_json::to_string_pretty(&def.parameters).unwrap_or_default());
183                content.push_str("\n```\n\n");
184            }
185        }
186
187        Ok(vec![crate::agent::message::Message::system(content)])
188    }
189}
190
191/// Builder for creating a ToolSet
192pub struct ToolSetBuilder {
193    tools: Vec<Arc<dyn Tool>>,
194}
195
196impl Default for ToolSetBuilder {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
202impl ToolSetBuilder {
203    /// Create a new builder
204    pub fn new() -> Self {
205        Self { tools: Vec::new() }
206    }
207
208    /// Add a tool
209    pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
210        self.tools.push(Arc::new(tool));
211        self
212    }
213
214    /// Add a shared tool
215    pub fn shared_tool(mut self, tool: Arc<dyn Tool>) -> Self {
216        self.tools.push(tool);
217        self
218    }
219
220    /// Build the ToolSet
221    pub fn build(self) -> ToolSet {
222        let mut toolset = ToolSet::new();
223        for tool in self.tools {
224            toolset.add_shared(tool);
225        }
226        toolset
227    }
228}
229
/// Helper macro for creating simple tools.
///
/// Expands to a block that defines a unit struct implementing `Tool` and
/// evaluates to an instance of it. `parameters` must evaluate to a
/// `serde_json::Value` (the JSON Schema). `handler` must be callable as
/// `handler(arguments: &str)` and return a future resolving to
/// `anyhow::Result<String>`.
///
/// The generated definition has no TypeScript signature, and `is_binary` and
/// `is_verified` are both `false`. These are the same values
/// `#[serde(default)]` gives those fields when deserializing a definition.
///
/// NOTE(review): the expansion uses `$crate::tool::…` paths, but this module
/// appears to live at `skills/tool`; confirm the crate re-exports it as `tool`.
///
/// # Example
/// ```ignore
/// let tool = simple_tool!(
///     name: "get_time",
///     description: "Get the current time",
///     parameters: serde_json::json!({ "type": "object", "properties": {} }),
///     handler: |_args| async {
///         Ok(chrono::Utc::now().to_rfc3339())
///     }
/// );
/// ```
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        handler: $handler:expr $(,)?
    ) => {{
        struct SimpleTool;

        #[async_trait::async_trait]
        impl $crate::tool::Tool for SimpleTool {
            fn name(&self) -> String {
                $name.to_string()
            }

            async fn definition(&self) -> $crate::tool::ToolDefinition {
                $crate::tool::ToolDefinition {
                    name: $name.to_string(),
                    description: $desc.to_string(),
                    parameters: $params,
                    // Every field must be listed for the struct literal to compile.
                    parameters_ts: None,
                    is_binary: false,
                    is_verified: false,
                }
            }

            async fn call(&self, arguments: &str) -> anyhow::Result<String> {
                let handler = $handler;
                handler(arguments).await
            }
        }

        SimpleTool
    }};
}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that returns the `message` field of its JSON arguments.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> String {
            "echo".to_string()
        }

        async fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echo back the input".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Message to echo"
                        }
                    },
                    "required": ["message"]
                }),
                parameters_ts: None,
                is_binary: false,
                is_verified: true, // Internal tools are verified
            }
        }

        async fn call(&self, arguments: &str) -> anyhow::Result<String> {
            #[derive(Deserialize)]
            struct Args {
                message: String,
            }
            let args: Args = serde_json::from_str(arguments)
                .map_err(|e| Error::ToolArguments {
                    tool_name: "echo".to_string(),
                    message: e.to_string(),
                })?;
            Ok(args.message)
        }
    }

    #[tokio::test]
    async fn test_toolset() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);

        assert!(toolset.contains("echo"));
        assert_eq!(toolset.len(), 1);

        let result = toolset
            .call("echo", r#"{"message": "hello"}"#)
            .await
            .expect("call should succeed");
        assert_eq!(result, "hello");
    }

    #[tokio::test]
    async fn call_unknown_tool_is_an_error() {
        let toolset = ToolSet::new();
        assert!(toolset.is_empty());
        assert!(toolset.call("missing", "{}").await.is_err());
    }

    #[tokio::test]
    async fn call_with_malformed_arguments_is_an_error() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);

        assert!(toolset.call("echo", "not json").await.is_err());
        assert!(toolset.call("echo", r#"{"msg": "wrong key"}"#).await.is_err());
    }

    #[tokio::test]
    async fn definitions_are_consistent_across_calls() {
        let toolset = ToolSetBuilder::new().tool(EchoTool).build();

        // The first call fills the cache; the second is served from it.
        let first = toolset.definitions().await;
        let second = toolset.definitions().await;

        assert_eq!(first.len(), 1);
        assert_eq!(first[0].name, "echo");
        assert_eq!(second.len(), 1);
        assert_eq!(second[0].name, first[0].name);
        assert_eq!(second[0].description, first[0].description);
    }
}