Skip to main content

aagt_core/tool/
mod.rs

1//! Tool system for AI agents
2//!
3//! Provides the core abstraction for defining tools that AI agents can call.
4
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::error::{Error, Result};
11
12pub mod memory;
13
14pub use memory::{RememberThisTool, SearchHistoryTool};
15
16/// Definition of a tool that can be sent to the LLM
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ToolDefinition {
19    /// Name of the tool
20    pub name: String,
21    /// Description for the LLM
22    pub description: String,
23    /// JSON Schema for parameters
24    pub parameters: serde_json::Value,
25}
26
27/// Trait for implementing tools that AI agents can call
28#[async_trait]
29pub trait Tool: Send + Sync {
30    /// The name of this tool
31    /// The name of this tool
32    fn name(&self) -> String;
33
34    /// Get the tool definition for the LLM
35    async fn definition(&self) -> ToolDefinition;
36
37    /// Execute the tool with the given arguments (JSON string)
38    async fn call(&self, arguments: &str) -> Result<String>;
39}
40
41/// A collection of tools available to an agent
42pub struct ToolSet {
43    tools: HashMap<String, Arc<dyn Tool>>,
44}
45
46impl Default for ToolSet {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl ToolSet {
53    /// Create an empty toolset
54    pub fn new() -> Self {
55        Self {
56            tools: HashMap::new(),
57        }
58    }
59
60    /// Add a tool to the set
61    pub fn add<T: Tool + 'static>(&mut self, tool: T) -> &mut Self {
62        self.tools.insert(tool.name().to_string(), Arc::new(tool));
63        self
64    }
65
66    /// Add a shared tool to the set
67    pub fn add_shared(&mut self, tool: Arc<dyn Tool>) -> &mut Self {
68        self.tools.insert(tool.name().to_string(), tool);
69        self
70    }
71
72    /// Get a tool by name
73    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
74        self.tools.get(name)
75    }
76
77    /// Check if a tool exists
78    pub fn contains(&self, name: &str) -> bool {
79        self.tools.contains_key(name)
80    }
81
82    /// Get all tool definitions
83    pub async fn definitions(&self) -> Vec<ToolDefinition> {
84        let mut defs = Vec::new();
85        for tool in self.tools.values() {
86            defs.push(tool.definition().await);
87        }
88        defs
89    }
90
91    /// Call a tool by name
92    pub async fn call(&self, name: &str, arguments: &str) -> Result<String> {
93        let tool = self
94            .tools
95            .get(name)
96            .ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
97
98        tool.call(arguments).await
99    }
100
101    /// Get the number of tools
102    pub fn len(&self) -> usize {
103        self.tools.len()
104    }
105
106    /// Check if empty
107    pub fn is_empty(&self) -> bool {
108        self.tools.is_empty()
109    }
110
111    /// Iterate over tools
112    pub fn iter(&self) -> impl Iterator<Item = (&String, &Arc<dyn Tool>)> {
113        self.tools.iter()
114    }
115}
116
117/// Builder for creating a ToolSet
118pub struct ToolSetBuilder {
119    tools: Vec<Arc<dyn Tool>>,
120}
121
122impl Default for ToolSetBuilder {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128impl ToolSetBuilder {
129    /// Create a new builder
130    pub fn new() -> Self {
131        Self { tools: Vec::new() }
132    }
133
134    /// Add a tool
135    pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
136        self.tools.push(Arc::new(tool));
137        self
138    }
139
140    /// Add a shared tool
141    pub fn shared_tool(mut self, tool: Arc<dyn Tool>) -> Self {
142        self.tools.push(tool);
143        self
144    }
145
146    /// Build the ToolSet
147    pub fn build(self) -> ToolSet {
148        let mut toolset = ToolSet::new();
149        for tool in self.tools {
150            toolset.add_shared(tool);
151        }
152        toolset
153    }
154}
155
/// Helper macro for creating simple tools.
///
/// Expands to an expression whose value is a freshly defined unit struct implementing
/// `tool::Tool`. Behaviour of the arguments:
///
/// - `name`, `description` and `parameters` are re-evaluated each time the corresponding
///   trait method runs.
/// - `handler` is evaluated on every `call`, is invoked as `handler(arguments)` with the
///   raw `&str` arguments, and must return a future resolving to this crate's
///   `Result<String>`.
///
/// The expansion uses `#[async_trait::async_trait]`, so the calling crate must depend on
/// `async_trait`. That attribute (without `?Send`) also requires the handler's future to
/// be `Send`. A trailing comma after `handler` is accepted.
///
/// # Example
/// ```ignore
/// let tool = simple_tool!(
///     name: "get_time",
///     description: "Get the current time",
///     parameters: serde_json::json!({ "type": "object", "properties": {} }),
///     handler: |_args| async {
///         Ok(chrono::Utc::now().to_rfc3339())
///     },
/// );
/// ```
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        handler: $handler:expr
        $(,)?
    ) => {{
        // Defined inside a block expression, so every invocation gets its own distinct
        // `SimpleTool` type and multiple invocations in one scope do not clash.
        struct SimpleTool;

        #[async_trait::async_trait]
        impl $crate::tool::Tool for SimpleTool {
            fn name(&self) -> String {
                $name.to_string()
            }

            async fn definition(&self) -> $crate::tool::ToolDefinition {
                $crate::tool::ToolDefinition {
                    name: $name.to_string(),
                    description: $desc.to_string(),
                    parameters: $params,
                }
            }

            async fn call(&self, arguments: &str) -> $crate::error::Result<String> {
                let handler = $handler;
                handler(arguments).await
            }
        }

        SimpleTool
    }};
}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that returns the `message` field of its JSON arguments.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> String {
            "echo".to_string()
        }

        async fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echo back the input".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Message to echo"
                        }
                    },
                    "required": ["message"]
                }),
            }
        }

        async fn call(&self, arguments: &str) -> Result<String> {
            #[derive(Deserialize)]
            struct Args {
                message: String,
            }
            let args: Args = serde_json::from_str(arguments)
                .map_err(|e| Error::ToolArguments {
                    tool_name: "echo".to_string(),
                    message: e.to_string(),
                })?;
            Ok(args.message)
        }
    }

    #[tokio::test]
    async fn test_toolset() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);

        assert!(toolset.contains("echo"));
        assert_eq!(toolset.len(), 1);

        let result = toolset
            .call("echo", r#"{"message": "hello"}"#)
            .await
            .expect("call should succeed");
        assert_eq!(result, "hello");
    }

    #[tokio::test]
    async fn test_unknown_tool_returns_not_found() {
        let toolset = ToolSet::new();
        assert!(toolset.is_empty());

        let err = toolset
            .call("missing", "{}")
            .await
            .expect_err("calling an unregistered tool should fail");
        assert!(matches!(&err, Error::ToolNotFound(name) if name == "missing"));
    }

    #[tokio::test]
    async fn test_invalid_arguments_are_reported() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);

        let err = toolset
            .call("echo", "not json")
            .await
            .expect_err("malformed arguments should be rejected");
        assert!(matches!(&err, Error::ToolArguments { tool_name, .. } if tool_name == "echo"));
    }

    #[tokio::test]
    async fn test_builder_and_definitions() {
        let toolset = ToolSetBuilder::new().tool(EchoTool).build();
        assert_eq!(toolset.len(), 1);
        assert!(toolset.get("echo").is_some());

        let defs = toolset.definitions().await;
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
    }
}