1use crate::error::{HeliosError, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6
/// Describes one argument accepted by a tool.
///
/// Instances are stored in the `properties` map of [`ParametersSchema`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
    /// Type name of the argument (e.g. `"string"`); serialized under the key `"type"`.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Free-text description of the argument.
    pub description: String,
    /// Whether the argument must be supplied.
    ///
    /// Skipped by serde in both directions (so it is `None` after deserialization);
    /// `Tool::to_definition` instead collects the names of parameters with
    /// `Some(true)` into [`ParametersSchema::required`].
    #[serde(skip)]
    pub required: Option<bool>,
}
15
/// Serializable description of a tool, as built by `Tool::to_definition`.
///
/// NOTE(review): the `{"type": "function", "function": {...}}` layout looks like
/// an OpenAI-style function-calling payload — confirm against the provider in use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Kind of tool; `Tool::to_definition` always sets `"function"`.
    /// Serialized under the key `"type"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Name, description and parameter schema of the function.
    pub function: FunctionDefinition,
}
22
/// Function portion of a [`ToolDefinition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name; `Tool::to_definition` copies it from `Tool::name`.
    pub name: String,
    /// Description; `Tool::to_definition` copies it from `Tool::description`.
    pub description: String,
    /// Schema of the arguments the function accepts.
    pub parameters: ParametersSchema,
}
29
/// Object schema describing a function's arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParametersSchema {
    /// Schema type; `Tool::to_definition` always sets `"object"`.
    /// Serialized under the key `"type"`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Argument descriptions keyed by argument name.
    pub properties: HashMap<String, ToolParameter>,
    /// Names of mandatory arguments; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}
38
/// Outcome of running a tool.
///
/// `output` holds the tool's text when `success` is `true`, or an error
/// message when it is `false`.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// `true` for results built with [`ToolResult::success`], `false` for [`ToolResult::error`].
    pub success: bool,
    /// Text produced by the tool, or the error message.
    pub output: String,
}

impl ToolResult {
    /// Common constructor behind [`ToolResult::success`] and [`ToolResult::error`].
    fn with_status(success: bool, output: impl Into<String>) -> Self {
        let output = output.into();
        Self { success, output }
    }

    /// Builds a successful result carrying `output`.
    pub fn success(output: impl Into<String>) -> Self {
        Self::with_status(true, output)
    }

    /// Builds a failed result whose `output` is `message`.
    pub fn error(message: impl Into<String>) -> Self {
        Self::with_status(false, message)
    }
}
60
61#[async_trait]
62pub trait Tool: Send + Sync {
63 fn name(&self) -> &str;
64 fn description(&self) -> &str;
65 fn parameters(&self) -> HashMap<String, ToolParameter>;
66 async fn execute(&self, args: Value) -> Result<ToolResult>;
67
68 fn to_definition(&self) -> ToolDefinition {
69 let required: Vec<String> = self
70 .parameters()
71 .iter()
72 .filter(|(_, param)| param.required.unwrap_or(false))
73 .map(|(name, _)| name.clone())
74 .collect();
75
76 ToolDefinition {
77 tool_type: "function".to_string(),
78 function: FunctionDefinition {
79 name: self.name().to_string(),
80 description: self.description().to_string(),
81 parameters: ParametersSchema {
82 schema_type: "object".to_string(),
83 properties: self.parameters(),
84 required: if required.is_empty() {
85 None
86 } else {
87 Some(required)
88 },
89 },
90 },
91 }
92 }
93}
94
95pub struct ToolRegistry {
96 tools: HashMap<String, Box<dyn Tool>>,
97}
98
99impl ToolRegistry {
100 pub fn new() -> Self {
101 Self {
102 tools: HashMap::new(),
103 }
104 }
105
106 pub fn register(&mut self, tool: Box<dyn Tool>) {
107 let name = tool.name().to_string();
108 self.tools.insert(name, tool);
109 }
110
111 pub fn get(&self, name: &str) -> Option<&dyn Tool> {
112 self.tools.get(name).map(|b| &**b)
113 }
114
115 pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
116 let tool = self
117 .tools
118 .get(name)
119 .ok_or_else(|| HeliosError::ToolError(format!("Tool '{}' not found", name)))?;
120
121 tool.execute(args).await
122 }
123
124 pub fn get_definitions(&self) -> Vec<ToolDefinition> {
125 self.tools
126 .values()
127 .map(|tool| tool.to_definition())
128 .collect()
129 }
130
131 pub fn list_tools(&self) -> Vec<String> {
132 self.tools.keys().cloned().collect()
133 }
134}
135
136impl Default for ToolRegistry {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142pub struct CalculatorTool;
145
146#[async_trait]
147impl Tool for CalculatorTool {
148 fn name(&self) -> &str {
149 "calculator"
150 }
151
152 fn description(&self) -> &str {
153 "Perform basic arithmetic operations. Supports +, -, *, / operations."
154 }
155
156 fn parameters(&self) -> HashMap<String, ToolParameter> {
157 let mut params = HashMap::new();
158 params.insert(
159 "expression".to_string(),
160 ToolParameter {
161 param_type: "string".to_string(),
162 description: "Mathematical expression to evaluate (e.g., '2 + 2')".to_string(),
163 required: Some(true),
164 },
165 );
166 params
167 }
168
169 async fn execute(&self, args: Value) -> Result<ToolResult> {
170 let expression = args
171 .get("expression")
172 .and_then(|v| v.as_str())
173 .ok_or_else(|| HeliosError::ToolError("Missing 'expression' parameter".to_string()))?;
174
175 let result = evaluate_expression(expression)?;
177 Ok(ToolResult::success(result.to_string()))
178 }
179}
180
181fn evaluate_expression(expr: &str) -> Result<f64> {
182 let expr = expr.replace(" ", "");
183
184 for op in &['*', '/', '+', '-'] {
186 if let Some(pos) = expr.rfind(*op) {
187 if pos == 0 {
188 continue; }
190 let left = &expr[..pos];
191 let right = &expr[pos + 1..];
192
193 let left_val = evaluate_expression(left)?;
194 let right_val = evaluate_expression(right)?;
195
196 return Ok(match op {
197 '+' => left_val + right_val,
198 '-' => left_val - right_val,
199 '*' => left_val * right_val,
200 '/' => {
201 if right_val == 0.0 {
202 return Err(HeliosError::ToolError("Division by zero".to_string()));
203 }
204 left_val / right_val
205 }
206 _ => unreachable!(),
207 });
208 }
209 }
210
211 expr.parse::<f64>()
212 .map_err(|_| HeliosError::ToolError(format!("Invalid expression: {}", expr)))
213}
214
215pub struct EchoTool;
216
217#[async_trait]
218impl Tool for EchoTool {
219 fn name(&self) -> &str {
220 "echo"
221 }
222
223 fn description(&self) -> &str {
224 "Echo back the provided message."
225 }
226
227 fn parameters(&self) -> HashMap<String, ToolParameter> {
228 let mut params = HashMap::new();
229 params.insert(
230 "message".to_string(),
231 ToolParameter {
232 param_type: "string".to_string(),
233 description: "The message to echo back".to_string(),
234 required: Some(true),
235 },
236 );
237 params
238 }
239
240 async fn execute(&self, args: Value) -> Result<ToolResult> {
241 let message = args
242 .get("message")
243 .and_then(|v| v.as_str())
244 .ok_or_else(|| HeliosError::ToolError("Missing 'message' parameter".to_string()))?;
245
246 Ok(ToolResult::success(format!("Echo: {}", message)))
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
    }

    #[test]
    fn test_tool_result_error() {
        let result = ToolResult::error("test error");
        assert!(!result.success);
        assert_eq!(result.output, "test error");
    }

    #[test]
    fn test_tool_parameter_serialization_omits_required() {
        // `required` is `#[serde(skip)]`; it is reported via the parent schema instead.
        let param = ToolParameter {
            param_type: "string".to_string(),
            description: "d".to_string(),
            required: Some(true),
        };
        let value = serde_json::to_value(&param).unwrap();
        assert_eq!(value, json!({"type": "string", "description": "d"}));
    }

    #[test]
    fn test_to_definition_lists_required_parameters() {
        let definition = CalculatorTool.to_definition();
        assert_eq!(definition.tool_type, "function");
        assert_eq!(definition.function.name, "calculator");
        assert_eq!(definition.function.parameters.schema_type, "object");
        assert!(definition
            .function
            .parameters
            .properties
            .contains_key("expression"));
        assert_eq!(
            definition.function.parameters.required,
            Some(vec!["expression".to_string()])
        );
    }

    #[tokio::test]
    async fn test_calculator_tool() {
        let tool = CalculatorTool;
        assert_eq!(tool.name(), "calculator");
        assert_eq!(
            tool.description(),
            "Perform basic arithmetic operations. Supports +, -, *, / operations."
        );

        let args = json!({"expression": "2 + 2"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "4");
    }

    #[tokio::test]
    async fn test_calculator_tool_multiplication() {
        let tool = CalculatorTool;
        let args = json!({"expression": "3 * 4"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "12");
    }

    #[tokio::test]
    async fn test_calculator_tool_division() {
        let tool = CalculatorTool;
        let args = json!({"expression": "8 / 2"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "4");
    }

    #[tokio::test]
    async fn test_calculator_tool_division_by_zero() {
        let tool = CalculatorTool;
        let args = json!({"expression": "8 / 0"});
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_calculator_tool_invalid_expression() {
        let tool = CalculatorTool;
        let args = json!({"expression": "invalid"});
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_calculator_tool_signed_operands() {
        let tool = CalculatorTool;

        let result = tool.execute(json!({"expression": "2 * -3"})).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "-6");

        let result = tool.execute(json!({"expression": "-2 + 5"})).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "3");
    }

    #[tokio::test]
    async fn test_calculator_tool_missing_parameter() {
        let tool = CalculatorTool;
        let result = tool.execute(json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let tool = EchoTool;
        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Echo back the provided message.");

        let args = json!({"message": "Hello, world!"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "Echo: Hello, world!");
    }

    #[tokio::test]
    async fn test_echo_tool_missing_parameter() {
        let tool = EchoTool;
        let args = json!({});
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_registry_new() {
        let registry = ToolRegistry::new();
        assert!(registry.tools.is_empty());
    }

    // Purely synchronous: no async runtime is needed for register/get.
    #[test]
    fn test_tool_registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));

        let tool = registry.get("calculator");
        assert!(tool.is_some());
        assert_eq!(tool.unwrap().name(), "calculator");
        assert!(registry.get("missing").is_none());
    }

    #[tokio::test]
    async fn test_tool_registry_execute() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));

        let args = json!({"expression": "5 * 6"});
        let result = registry.execute("calculator", args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "30");
    }

    #[tokio::test]
    async fn test_tool_registry_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let args = json!({"expression": "5 * 6"});
        let result = registry.execute("nonexistent", args).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_registry_get_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));
        registry.register(Box::new(EchoTool));

        let definitions = registry.get_definitions();
        assert_eq!(definitions.len(), 2);

        let names: Vec<String> = definitions
            .iter()
            .map(|d| d.function.name.clone())
            .collect();
        assert!(names.contains(&"calculator".to_string()));
        assert!(names.contains(&"echo".to_string()));
    }

    #[test]
    fn test_tool_registry_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));
        registry.register(Box::new(EchoTool));

        let tools = registry.list_tools();
        assert_eq!(tools.len(), 2);
        assert!(tools.contains(&"calculator".to_string()));
        assert!(tools.contains(&"echo".to_string()));
    }
}