Skip to main content

aagt_core/tool/
mod.rs

1//! Tool system for AI agents
2//!
3//! Provides the core abstraction for defining tools that AI agents can call.
4
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::error::{Error, Result};
11
12pub mod code_interpreter;
13pub mod memory;
14
15pub use memory::{RememberThisTool, SearchHistoryTool};
16
17/// Definition of a tool that can be sent to the LLM
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ToolDefinition {
20    /// Name of the tool
21    pub name: String,
22    /// Description for the LLM
23    pub description: String,
24    /// JSON Schema for parameters
25    pub parameters: serde_json::Value,
26}
27
28/// Trait for implementing tools that AI agents can call
29#[async_trait]
30pub trait Tool: Send + Sync {
31    /// The name of this tool
32    /// The name of this tool
33    fn name(&self) -> String;
34
35    /// Get the tool definition for the LLM
36    async fn definition(&self) -> ToolDefinition;
37
38    /// Execute the tool with the given arguments (JSON string)
39    async fn call(&self, arguments: &str) -> anyhow::Result<String>;
40}
41
42/// A collection of tools available to an agent
43pub struct ToolSet {
44    tools: HashMap<String, Arc<dyn Tool>>,
45}
46
47impl Default for ToolSet {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl ToolSet {
54    /// Create an empty toolset
55    pub fn new() -> Self {
56        Self {
57            tools: HashMap::new(),
58        }
59    }
60
61    /// Add a tool to the set
62    pub fn add<T: Tool + 'static>(&mut self, tool: T) -> &mut Self {
63        self.tools.insert(tool.name().to_string(), Arc::new(tool));
64        self
65    }
66
67    /// Add a shared tool to the set
68    pub fn add_shared(&mut self, tool: Arc<dyn Tool>) -> &mut Self {
69        self.tools.insert(tool.name().to_string(), tool);
70        self
71    }
72
73    /// Get a tool by name
74    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
75        self.tools.get(name)
76    }
77
78    /// Check if a tool exists
79    pub fn contains(&self, name: &str) -> bool {
80        self.tools.contains_key(name)
81    }
82
83    /// Get all tool definitions
84    pub async fn definitions(&self) -> Vec<ToolDefinition> {
85        let mut defs = Vec::new();
86        for tool in self.tools.values() {
87            defs.push(tool.definition().await);
88        }
89        defs
90    }
91
92    /// Call a tool by name
93    pub async fn call(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
94        let tool = self
95            .tools
96            .get(name)
97            .ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
98
99        tool.call(arguments).await
100    }
101
102    /// Get the number of tools
103    pub fn len(&self) -> usize {
104        self.tools.len()
105    }
106
107    /// Check if empty
108    pub fn is_empty(&self) -> bool {
109        self.tools.is_empty()
110    }
111
112    /// Iterate over tools
113    pub fn iter(&self) -> impl Iterator<Item = (&String, &Arc<dyn Tool>)> {
114        self.tools.iter()
115    }
116}
117
118/// Builder for creating a ToolSet
119pub struct ToolSetBuilder {
120    tools: Vec<Arc<dyn Tool>>,
121}
122
123impl Default for ToolSetBuilder {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129impl ToolSetBuilder {
130    /// Create a new builder
131    pub fn new() -> Self {
132        Self { tools: Vec::new() }
133    }
134
135    /// Add a tool
136    pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
137        self.tools.push(Arc::new(tool));
138        self
139    }
140
141    /// Add a shared tool
142    pub fn shared_tool(mut self, tool: Arc<dyn Tool>) -> Self {
143        self.tools.push(tool);
144        self
145    }
146
147    /// Build the ToolSet
148    pub fn build(self) -> ToolSet {
149        let mut toolset = ToolSet::new();
150        for tool in self.tools {
151            toolset.add_shared(tool);
152        }
153        toolset
154    }
155}
156
/// Helper macro for creating simple, stateless tools.
///
/// Expands to a block that defines a unit struct implementing `Tool` and
/// evaluates to an instance of it.
///
/// - `name` and `description`: any expressions with a `to_string()` method.
/// - `parameters`: must evaluate to a `serde_json::Value` (the JSON Schema).
/// - `handler`: must be callable with the raw `&str` arguments and return a
///   future that resolves to `anyhow::Result<String>`.
/// - A trailing comma after `handler` is accepted.
///
/// Every argument expression is pasted into a method body. As a result:
/// - each one is re-evaluated whenever that method runs;
/// - none of them can refer to local variables at the call site.
///
/// The expansion names `async_trait::async_trait` and `anyhow::Result` by
/// plain paths, so the calling crate must depend on `async_trait` and `anyhow`.
///
/// # Example
/// ```ignore
/// let tool = simple_tool!(
///     name: "get_time",
///     description: "Get the current time",
///     parameters: serde_json::json!({ "type": "object", "properties": {} }),
///     handler: |_args| async {
///         Ok(chrono::Utc::now().to_rfc3339())
///     },
/// );
/// ```
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        handler: $handler:expr $(,)?
    ) => {{
        struct SimpleTool;

        #[async_trait::async_trait]
        impl $crate::tool::Tool for SimpleTool {
            fn name(&self) -> String {
                $name.to_string()
            }

            async fn definition(&self) -> $crate::tool::ToolDefinition {
                $crate::tool::ToolDefinition {
                    name: $name.to_string(),
                    description: $desc.to_string(),
                    parameters: $params,
                }
            }

            async fn call(&self, arguments: &str) -> anyhow::Result<String> {
                let handler = $handler;
                handler(arguments).await
            }
        }

        SimpleTool
    }};
}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that returns the `message` field of its JSON arguments.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> String {
            "echo".to_string()
        }

        async fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echo back the input".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Message to echo"
                        }
                    },
                    "required": ["message"]
                }),
            }
        }

        async fn call(&self, arguments: &str) -> anyhow::Result<String> {
            #[derive(Deserialize)]
            struct Args {
                message: String,
            }
            let args: Args = serde_json::from_str(arguments)
                .map_err(|e| Error::ToolArguments {
                    tool_name: "echo".to_string(),
                    message: e.to_string(),
                })?;
            Ok(args.message)
        }
    }

    #[tokio::test]
    async fn test_toolset() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);

        assert!(toolset.contains("echo"));
        assert_eq!(toolset.len(), 1);

        let result = toolset
            .call("echo", r#"{"message": "hello"}"#)
            .await
            .expect("call should succeed");
        assert_eq!(result, "hello");
    }

    #[tokio::test]
    async fn unknown_tool_is_reported() {
        let toolset = ToolSet::new();
        assert!(toolset.is_empty());

        let err = toolset
            .call("missing", "{}")
            .await
            .expect_err("calling an unregistered tool must fail");
        assert!(matches!(
            err.downcast_ref::<Error>(),
            Some(Error::ToolNotFound(_))
        ));
    }

    #[tokio::test]
    async fn malformed_arguments_are_rejected() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);

        assert!(toolset.call("echo", "not json").await.is_err());
        assert!(toolset.call("echo", r#"{"wrong": 1}"#).await.is_err());
    }

    #[tokio::test]
    async fn builder_deduplicates_and_exposes_definitions() {
        let toolset = ToolSetBuilder::new()
            .tool(EchoTool)
            .shared_tool(Arc::new(EchoTool))
            .build();

        // Both tools are named "echo", so the later registration replaces the first.
        assert_eq!(toolset.len(), 1);

        let defs = toolset.definitions().await;
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
    }
}