Skip to main content

agent_runtime/
tool.rs

1use crate::types::{ToolError, ToolExecutionResult, ToolResult};
2use async_trait::async_trait;
3use serde_json::Value as JsonValue;
4use std::collections::HashMap;
5
6/// Tool trait that all tools must implement
7#[async_trait]
8pub trait Tool: Send + Sync {
9    /// Unique name for this tool
10    fn name(&self) -> &str;
11
12    /// Human-readable description for LLM
13    fn description(&self) -> &str;
14
15    /// JSON schema for input parameters
16    fn input_schema(&self) -> JsonValue;
17
18    /// Execute the tool with given parameters
19    async fn execute(&self, params: HashMap<String, JsonValue>) -> ToolExecutionResult;
20}
21
22/// Example: Echo tool that returns its input
23pub struct EchoTool;
24
25#[async_trait]
26impl Tool for EchoTool {
27    fn name(&self) -> &str {
28        "echo"
29    }
30
31    fn description(&self) -> &str {
32        "Echoes back the input message"
33    }
34
35    fn input_schema(&self) -> JsonValue {
36        serde_json::json!({
37            "type": "object",
38            "properties": {
39                "message": {
40                    "type": "string",
41                    "description": "The message to echo"
42                }
43            },
44            "required": ["message"]
45        })
46    }
47
48    async fn execute(&self, params: HashMap<String, JsonValue>) -> ToolExecutionResult {
49        let start = std::time::Instant::now();
50
51        let message = params
52            .get("message")
53            .and_then(|v| v.as_str())
54            .ok_or_else(|| ToolError::InvalidParameters("missing 'message' parameter".into()))?;
55
56        let output = serde_json::json!({
57            "echoed": message
58        });
59
60        Ok(ToolResult {
61            output,
62            duration_ms: start.elapsed().as_millis() as u64,
63        })
64    }
65}
66
67/// Example: Calculator tool for simple math
68pub struct CalculatorTool;
69
70#[async_trait]
71impl Tool for CalculatorTool {
72    fn name(&self) -> &str {
73        "calculator"
74    }
75
76    fn description(&self) -> &str {
77        "Performs basic arithmetic operations (add, subtract, multiply, divide)"
78    }
79
80    fn input_schema(&self) -> JsonValue {
81        serde_json::json!({
82            "type": "object",
83            "properties": {
84                "operation": {
85                    "type": "string",
86                    "enum": ["add", "subtract", "multiply", "divide"]
87                },
88                "a": { "type": "number" },
89                "b": { "type": "number" }
90            },
91            "required": ["operation", "a", "b"]
92        })
93    }
94
95    async fn execute(&self, params: HashMap<String, JsonValue>) -> ToolExecutionResult {
96        let start = std::time::Instant::now();
97
98        let operation = params
99            .get("operation")
100            .and_then(|v| v.as_str())
101            .ok_or_else(|| ToolError::InvalidParameters("missing 'operation'".into()))?;
102
103        let a = params
104            .get("a")
105            .and_then(|v| v.as_f64())
106            .ok_or_else(|| ToolError::InvalidParameters("missing 'a'".into()))?;
107
108        let b = params
109            .get("b")
110            .and_then(|v| v.as_f64())
111            .ok_or_else(|| ToolError::InvalidParameters("missing 'b'".into()))?;
112
113        let result = match operation {
114            "add" => a + b,
115            "subtract" => a - b,
116            "multiply" => a * b,
117            "divide" => {
118                if b == 0.0 {
119                    return Err(ToolError::ExecutionFailed("division by zero".into()));
120                }
121                a / b
122            }
123            _ => {
124                return Err(ToolError::InvalidParameters(format!(
125                    "unknown operation: {}",
126                    operation
127                )))
128            }
129        };
130
131        Ok(ToolResult {
132            output: serde_json::json!({ "result": result }),
133            duration_ms: start.elapsed().as_millis() as u64,
134        })
135    }
136}