Skip to main content

agent_io/tools/
tool.rs

1//! Tool trait and implementations
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize, de::DeserializeOwned};
5use serde_json::{Value, json};
6
7use crate::Result;
8use crate::llm::ToolDefinition;
9
10/// Tool execution result
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ToolResult {
13    /// Tool call ID
14    pub tool_call_id: String,
15    /// Result content
16    pub content: String,
17    /// Whether this result should be ephemeral (removed after use)
18    #[serde(default)]
19    pub ephemeral: bool,
20}
21
22impl ToolResult {
23    pub fn new(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
24        Self {
25            tool_call_id: tool_call_id.into(),
26            content: content.into(),
27            ephemeral: false,
28        }
29    }
30
31    pub fn with_ephemeral(mut self, ephemeral: bool) -> Self {
32        self.ephemeral = ephemeral;
33        self
34    }
35}
36
37/// Trait for defining tools that can be called by an LLM
38#[async_trait]
39pub trait Tool: Send + Sync {
40    /// Get the tool name
41    fn name(&self) -> &str;
42
43    /// Get the tool description
44    fn description(&self) -> &str;
45
46    /// Get the tool definition (JSON Schema)
47    fn definition(&self) -> ToolDefinition;
48
49    /// Execute the tool with given arguments
50    async fn execute(
51        &self,
52        args: Value,
53        overrides: Option<DependencyOverrides>,
54    ) -> Result<ToolResult>;
55
56    /// Whether tool outputs should be ephemeral (removed from context after use)
57    fn ephemeral(&self) -> EphemeralConfig {
58        EphemeralConfig::None
59    }
60}
61
/// Controls whether, and for how long, a tool's outputs stay in context.
///
/// Implements `Eq` and `Hash` in addition to `PartialEq`, so it can be used
/// as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum EphemeralConfig {
    /// Outputs are not ephemeral. This is the default.
    #[default]
    None,
    /// Output is ephemeral and removed after a single use.
    Single,
    /// Only the last `n` outputs are kept in context.
    Count(usize),
}
73
74/// A tool implementation using a function
75pub struct FunctionTool<T, F>
76where
77    T: DeserializeOwned + Send + Sync + 'static,
78    F: Fn(T) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
79        + Send
80        + Sync,
81{
82    name: String,
83    description: String,
84    parameters_schema: serde_json::Map<String, Value>,
85    func: F,
86    ephemeral_config: EphemeralConfig,
87    _marker: std::marker::PhantomData<T>,
88}
89
90impl<T, F> FunctionTool<T, F>
91where
92    T: DeserializeOwned + Send + Sync + 'static,
93    F: Fn(T) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
94        + Send
95        + Sync,
96{
97    /// Create a new function tool
98    pub fn new(
99        name: impl Into<String>,
100        description: impl Into<String>,
101        parameters_schema: serde_json::Map<String, Value>,
102        func: F,
103    ) -> Self {
104        Self {
105            name: name.into(),
106            description: description.into(),
107            parameters_schema,
108            func,
109            ephemeral_config: EphemeralConfig::None,
110            _marker: std::marker::PhantomData,
111        }
112    }
113
114    /// Set ephemeral configuration
115    pub fn with_ephemeral(mut self, config: EphemeralConfig) -> Self {
116        self.ephemeral_config = config;
117        self
118    }
119}
120
121#[async_trait]
122impl<T, F> Tool for FunctionTool<T, F>
123where
124    T: DeserializeOwned + Send + Sync + 'static,
125    F: Fn(T) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
126        + Send
127        + Sync,
128{
129    fn name(&self) -> &str {
130        &self.name
131    }
132
133    fn description(&self) -> &str {
134        &self.description
135    }
136
137    fn definition(&self) -> ToolDefinition {
138        ToolDefinition::new(
139            &self.name,
140            &self.description,
141            self.parameters_schema.clone(),
142        )
143    }
144
145    async fn execute(
146        &self,
147        args: Value,
148        _overrides: Option<DependencyOverrides>,
149    ) -> Result<ToolResult> {
150        let parsed: T = serde_json::from_value(args)?;
151        let content = (self.func)(parsed).await?;
152        Ok(ToolResult::new("", content)
153            .with_ephemeral(self.ephemeral_config != EphemeralConfig::None))
154    }
155
156    fn ephemeral(&self) -> EphemeralConfig {
157        self.ephemeral_config
158    }
159}
160
161/// Builder for creating tools
162pub struct ToolBuilder {
163    name: String,
164    description: String,
165    parameters_schema: serde_json::Map<String, Value>,
166    ephemeral: EphemeralConfig,
167}
168
169impl ToolBuilder {
170    pub fn new(name: impl Into<String>) -> Self {
171        Self {
172            name: name.into(),
173            description: String::new(),
174            parameters_schema: serde_json::Map::new(),
175            ephemeral: EphemeralConfig::None,
176        }
177    }
178
179    pub fn description(mut self, desc: impl Into<String>) -> Self {
180        self.description = desc.into();
181        self
182    }
183
184    pub fn parameter(mut self, name: &str, schema: Value) -> Self {
185        self.parameters_schema.insert(name.to_string(), schema);
186        self
187    }
188
189    pub fn string_param(self, name: &str, description: &str) -> Self {
190        self.parameter(
191            name,
192            json!({
193                "type": "string",
194                "description": description
195            }),
196        )
197    }
198
199    pub fn number_param(self, name: &str, description: &str) -> Self {
200        self.parameter(
201            name,
202            json!({
203                "type": "number",
204                "description": description
205            }),
206        )
207    }
208
209    pub fn boolean_param(self, name: &str, description: &str) -> Self {
210        self.parameter(
211            name,
212            json!({
213                "type": "boolean",
214                "description": description
215            }),
216        )
217    }
218
219    pub fn ephemeral(mut self, config: EphemeralConfig) -> Self {
220        self.ephemeral = config;
221        self
222    }
223
224    pub fn build<F, T>(self, func: F) -> Box<dyn Tool>
225    where
226        T: DeserializeOwned + Send + Sync + 'static,
227        F: Fn(T) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
228            + Send
229            + Sync
230            + 'static,
231    {
232        let mut tool = FunctionTool::new(self.name, self.description, self.parameters_schema, func);
233        tool.ephemeral_config = self.ephemeral;
234        Box::new(tool)
235    }
236}
237
/// Named, type-erased values that can be handed to [`Tool::execute`] to
/// override a tool's dependencies (intended for testing).
///
/// Values are stored as `dyn Any + Send + Sync`, so a consumer must downcast
/// them to the concrete type it expects.
pub type DependencyOverrides = std::collections::HashMap<
    String,
    Box<dyn std::any::Any + Send + Sync>,
>;
241
/// A [`Tool`] that takes no arguments and simply awaits a zero-argument
/// async function.
///
/// Trait bounds on `F` live on the `impl` blocks that need them rather than
/// on the struct definition.
pub struct SimpleTool<F> {
    /// Tool name.
    name: String,
    /// Tool description.
    description: String,
    /// Function invoked on every execution.
    func: F,
}
253
254impl<F> SimpleTool<F>
255where
256    F: Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
257        + Send
258        + Sync,
259{
260    pub fn new(name: impl Into<String>, description: impl Into<String>, func: F) -> Self {
261        Self {
262            name: name.into(),
263            description: description.into(),
264            func,
265        }
266    }
267}
268
269#[async_trait]
270impl<F> Tool for SimpleTool<F>
271where
272    F: Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send>>
273        + Send
274        + Sync,
275{
276    fn name(&self) -> &str {
277        &self.name
278    }
279
280    fn description(&self) -> &str {
281        &self.description
282    }
283
284    fn definition(&self) -> ToolDefinition {
285        ToolDefinition::new(&self.name, &self.description, serde_json::Map::new())
286    }
287
288    async fn execute(
289        &self,
290        _args: Value,
291        _overrides: Option<DependencyOverrides>,
292    ) -> Result<ToolResult> {
293        let content = (self.func)().await?;
294        Ok(ToolResult::new("", content))
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Arguments for the echo tool used across these tests.
    #[derive(Deserialize)]
    struct EchoArgs {
        message: String,
    }

    #[tokio::test]
    async fn test_simple_tool() {
        let tool = SimpleTool::new("ping", "Returns pong", || {
            Box::pin(async { Ok("pong".to_string()) })
        });

        assert_eq!(tool.name(), "ping");

        let result = tool.execute(json!({}), None).await.unwrap();
        assert_eq!(result.content, "pong");
        assert!(!result.ephemeral);
    }

    #[tokio::test]
    async fn test_function_tool() {
        let tool = FunctionTool::new(
            "echo",
            "Echoes the message back",
            json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                },
                "required": ["message"]
            })
            .as_object()
            .unwrap()
            .clone(),
            |args: EchoArgs| Box::pin(async move { Ok(args.message) }),
        );

        let result = tool
            .execute(json!({"message": "hello"}), None)
            .await
            .unwrap();
        assert_eq!(result.content, "hello");
        assert!(!result.ephemeral);
    }

    #[test]
    fn test_tool_result_defaults() {
        let result = ToolResult::new("call-1", "done");
        assert_eq!(result.tool_call_id, "call-1");
        assert_eq!(result.content, "done");
        assert!(!result.ephemeral);
        assert!(result.with_ephemeral(true).ephemeral);
    }

    #[tokio::test]
    async fn test_function_tool_ephemeral_flag() {
        let tool = FunctionTool::new(
            "echo",
            "Echoes the message back",
            serde_json::Map::new(),
            |args: EchoArgs| Box::pin(async move { Ok(args.message) }),
        )
        .with_ephemeral(EphemeralConfig::Count(2));

        assert_eq!(tool.ephemeral(), EphemeralConfig::Count(2));

        let result = tool
            .execute(json!({"message": "hi"}), None)
            .await
            .unwrap();
        assert_eq!(result.content, "hi");
        // Any non-`None` configuration marks the result as ephemeral.
        assert!(result.ephemeral);
    }

    #[tokio::test]
    async fn test_function_tool_rejects_bad_args() {
        let tool = FunctionTool::new(
            "echo",
            "Echoes the message back",
            serde_json::Map::new(),
            |args: EchoArgs| Box::pin(async move { Ok(args.message) }),
        );

        // `message` is missing, so deserialization into `EchoArgs` must fail.
        assert!(tool.execute(json!({"msg": "hi"}), None).await.is_err());
    }
}