Skip to main content

aagt_core/
tool.rs

//! Tool system for AI agents
//!
//! Provides the core abstraction for defining tools that AI agents can call.
4
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::error::{Error, Result};
11
12/// Definition of a tool that can be sent to the LLM
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolDefinition {
15    /// Name of the tool
16    pub name: String,
17    /// Description for the LLM
18    pub description: String,
19    /// JSON Schema for parameters
20    pub parameters: serde_json::Value,
21}
22
23/// Trait for implementing tools that AI agents can call
24#[async_trait]
25pub trait Tool: Send + Sync {
26    /// The name of this tool
27    /// The name of this tool
28    fn name(&self) -> String;
29
30    /// Get the tool definition for the LLM
31    async fn definition(&self) -> ToolDefinition;
32
33    /// Execute the tool with the given arguments (JSON string)
34    async fn call(&self, arguments: &str) -> Result<String>;
35}
36
37/// A collection of tools available to an agent
38pub struct ToolSet {
39    tools: HashMap<String, Arc<dyn Tool>>,
40}
41
42impl Default for ToolSet {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl ToolSet {
49    /// Create an empty toolset
50    pub fn new() -> Self {
51        Self {
52            tools: HashMap::new(),
53        }
54    }
55
56    /// Add a tool to the set
57    pub fn add<T: Tool + 'static>(&mut self, tool: T) -> &mut Self {
58        self.tools.insert(tool.name().to_string(), Arc::new(tool));
59        self
60    }
61
62    /// Add a shared tool to the set
63    pub fn add_shared(&mut self, tool: Arc<dyn Tool>) -> &mut Self {
64        self.tools.insert(tool.name().to_string(), tool);
65        self
66    }
67
68    /// Get a tool by name
69    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
70        self.tools.get(name)
71    }
72
73    /// Check if a tool exists
74    pub fn contains(&self, name: &str) -> bool {
75        self.tools.contains_key(name)
76    }
77
78    /// Get all tool definitions
79    pub async fn definitions(&self) -> Vec<ToolDefinition> {
80        let mut defs = Vec::new();
81        for tool in self.tools.values() {
82            defs.push(tool.definition().await);
83        }
84        defs
85    }
86
87    /// Call a tool by name
88    pub async fn call(&self, name: &str, arguments: &str) -> Result<String> {
89        let tool = self
90            .tools
91            .get(name)
92            .ok_or_else(|| Error::ToolNotFound(name.to_string()))?;
93
94        tool.call(arguments).await
95    }
96
97    /// Get the number of tools
98    pub fn len(&self) -> usize {
99        self.tools.len()
100    }
101
102    /// Check if empty
103    pub fn is_empty(&self) -> bool {
104        self.tools.is_empty()
105    }
106
107    /// Iterate over tools
108    pub fn iter(&self) -> impl Iterator<Item = (&String, &Arc<dyn Tool>)> {
109        self.tools.iter()
110    }
111}
112
113/// Builder for creating a ToolSet
114pub struct ToolSetBuilder {
115    tools: Vec<Arc<dyn Tool>>,
116}
117
118impl Default for ToolSetBuilder {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124impl ToolSetBuilder {
125    /// Create a new builder
126    pub fn new() -> Self {
127        Self { tools: Vec::new() }
128    }
129
130    /// Add a tool
131    pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
132        self.tools.push(Arc::new(tool));
133        self
134    }
135
136    /// Add a shared tool
137    pub fn shared_tool(mut self, tool: Arc<dyn Tool>) -> Self {
138        self.tools.push(tool);
139        self
140    }
141
142    /// Build the ToolSet
143    pub fn build(self) -> ToolSet {
144        let mut toolset = ToolSet::new();
145        for tool in self.tools {
146            toolset.add_shared(tool);
147        }
148        toolset
149    }
150}
151
/// Helper macro for creating simple tools without writing a full `impl Tool`.
///
/// Expands to an expression whose value is a fresh unit struct implementing the
/// `Tool` trait from this crate's `tool` module. A trailing comma after
/// `handler` is accepted.
///
/// The `name`, `description`, `parameters`, and `handler` expressions are
/// expanded inside the generated trait methods, so:
/// - they are re-evaluated on every `name()` / `definition()` / `call()`;
/// - they cannot capture local variables from the invocation site (use
///   literals, non-capturing closures, or items/statics).
///
/// `handler` must be callable as `handler(arguments)` with `arguments: &str`
/// and return a future resolving to this crate's `error::Result<String>`.
///
/// The expansion names `async_trait::async_trait` directly, so the invoking
/// crate presumably needs `async_trait` as a dependency — TODO confirm.
///
/// # Example
/// ```ignore
/// let tool = simple_tool!(
///     name: "get_time",
///     description: "Get the current time",
///     parameters: serde_json::json!({ "type": "object", "properties": {} }),
///     handler: |_args: &str| async {
///         Ok(chrono::Utc::now().to_rfc3339())
///     },
/// );
/// ```
#[macro_export]
macro_rules! simple_tool {
    (
        name: $name:expr,
        description: $desc:expr,
        parameters: $params:expr,
        handler: $handler:expr $(,)?
    ) => {{
        struct SimpleTool;

        #[async_trait::async_trait]
        impl $crate::tool::Tool for SimpleTool {
            fn name(&self) -> ::std::string::String {
                $name.to_string()
            }

            async fn definition(&self) -> $crate::tool::ToolDefinition {
                $crate::tool::ToolDefinition {
                    name: $name.to_string(),
                    description: $desc.to_string(),
                    parameters: $params,
                }
            }

            async fn call(
                &self,
                arguments: &str,
            ) -> $crate::error::Result<::std::string::String> {
                let handler = $handler;
                handler(arguments).await
            }
        }

        SimpleTool
    }};
}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that returns the `message` field of its JSON arguments.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> String {
            "echo".to_string()
        }

        async fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echo back the input".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Message to echo"
                        }
                    },
                    "required": ["message"]
                }),
            }
        }

        async fn call(&self, arguments: &str) -> Result<String> {
            #[derive(Deserialize)]
            struct Args {
                message: String,
            }
            let args: Args = serde_json::from_str(arguments)
                .map_err(|e| Error::ToolArguments {
                    tool_name: "echo".to_string(),
                    message: e.to_string(),
                })?;
            Ok(args.message)
        }
    }

    #[tokio::test]
    async fn test_toolset() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);

        assert!(toolset.contains("echo"));
        assert_eq!(toolset.len(), 1);

        let result = toolset
            .call("echo", r#"{"message": "hello"}"#)
            .await
            .expect("call should succeed");
        assert_eq!(result, "hello");
    }

    #[tokio::test]
    async fn unknown_tool_is_reported() {
        let toolset = ToolSet::new();
        let err = toolset
            .call("missing", "{}")
            .await
            .expect_err("calling an unregistered tool should fail");
        assert!(matches!(&err, Error::ToolNotFound(name) if name == "missing"));
    }

    #[tokio::test]
    async fn malformed_arguments_are_rejected() {
        let mut toolset = ToolSet::new();
        toolset.add(EchoTool);
        let err = toolset
            .call("echo", "not json")
            .await
            .expect_err("invalid JSON arguments should fail");
        assert!(matches!(&err, Error::ToolArguments { tool_name, .. } if tool_name == "echo"));
    }

    #[tokio::test]
    async fn builder_replaces_duplicate_names() {
        let toolset = ToolSetBuilder::new()
            .tool(EchoTool)
            .tool(EchoTool)
            .build();
        assert_eq!(toolset.len(), 1);

        let defs = toolset.definitions().await;
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
    }
}