helios_engine/
tools.rs

1use crate::error::{HeliosError, Result};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6
/// Describes one named parameter accepted by a tool.
///
/// Instances are the values of the map returned by `Tool::parameters` and of
/// `ParametersSchema::properties`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
    /// Type name of the parameter; serialized under the key `type`.
    /// The built-in tools in this file use `"string"`.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Human-readable explanation of the parameter.
    pub description: String,
    /// Whether the parameter is mandatory (`None` is treated as `false` by
    /// `Tool::to_definition`).
    ///
    /// Marked `serde(skip)`, so it is neither written nor read by serde; the
    /// information is surfaced through `ParametersSchema::required` instead.
    #[serde(skip)]
    pub required: Option<bool>,
}
15
/// Serializable description of a tool, as produced by `Tool::to_definition`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Kind of tool; serialized under the key `type`.
    /// `Tool::to_definition` always sets this to `"function"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Name, description and parameter schema of the callable.
    pub function: FunctionDefinition,
}
22
/// The callable part of a `ToolDefinition`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Tool name; `Tool::to_definition` copies it from `Tool::name`.
    pub name: String,
    /// Tool description; `Tool::to_definition` copies it from `Tool::description`.
    pub description: String,
    /// Schema of the arguments the tool accepts.
    pub parameters: ParametersSchema,
}
29
/// Schema describing the arguments object of a tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParametersSchema {
    /// Type of the arguments container; serialized under the key `type`.
    /// `Tool::to_definition` always sets this to `"object"`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Parameter name mapped to its description.
    pub properties: HashMap<String, ToolParameter>,
    /// Names of the mandatory parameters. Left out of the serialized output
    /// entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}
38
/// Outcome of running a tool: a status flag plus a text payload.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// Set to `true` by [`ToolResult::success`] and to `false` by [`ToolResult::error`].
    pub success: bool,
    /// The tool's output text, or the error message for a failed result.
    pub output: String,
}

impl ToolResult {
    /// Builds a successful result carrying `output`.
    pub fn success(output: impl Into<String>) -> Self {
        Self::with_status(true, output.into())
    }

    /// Builds a failed result whose `output` field holds `message`.
    pub fn error(message: impl Into<String>) -> Self {
        Self::with_status(false, message.into())
    }

    /// Shared constructor behind `success` and `error`.
    fn with_status(success: bool, output: String) -> Self {
        Self { success, output }
    }
}
60
61#[async_trait]
62pub trait Tool: Send + Sync {
63    fn name(&self) -> &str;
64    fn description(&self) -> &str;
65    fn parameters(&self) -> HashMap<String, ToolParameter>;
66    async fn execute(&self, args: Value) -> Result<ToolResult>;
67
68    fn to_definition(&self) -> ToolDefinition {
69        let required: Vec<String> = self
70            .parameters()
71            .iter()
72            .filter(|(_, param)| param.required.unwrap_or(false))
73            .map(|(name, _)| name.clone())
74            .collect();
75
76        ToolDefinition {
77            tool_type: "function".to_string(),
78            function: FunctionDefinition {
79                name: self.name().to_string(),
80                description: self.description().to_string(),
81                parameters: ParametersSchema {
82                    schema_type: "object".to_string(),
83                    properties: self.parameters(),
84                    required: if required.is_empty() {
85                        None
86                    } else {
87                        Some(required)
88                    },
89                },
90            },
91        }
92    }
93}
94
/// Collection of boxed [`Tool`] implementations, looked up by name.
pub struct ToolRegistry {
    /// Registered tools, keyed by the value of `Tool::name` at registration time.
    tools: HashMap<String, Box<dyn Tool>>,
}
98
99impl ToolRegistry {
100    pub fn new() -> Self {
101        Self {
102            tools: HashMap::new(),
103        }
104    }
105
106    pub fn register(&mut self, tool: Box<dyn Tool>) {
107        let name = tool.name().to_string();
108        self.tools.insert(name, tool);
109    }
110
111    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
112        self.tools.get(name).map(|b| &**b)
113    }
114
115    pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
116        let tool = self
117            .tools
118            .get(name)
119            .ok_or_else(|| HeliosError::ToolError(format!("Tool '{}' not found", name)))?;
120
121        tool.execute(args).await
122    }
123
124    pub fn get_definitions(&self) -> Vec<ToolDefinition> {
125        self.tools
126            .values()
127            .map(|tool| tool.to_definition())
128            .collect()
129    }
130
131    pub fn list_tools(&self) -> Vec<String> {
132        self.tools.keys().cloned().collect()
133    }
134}
135
136impl Default for ToolRegistry {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142// Example built-in tools
143
144pub struct CalculatorTool;
145
146#[async_trait]
147impl Tool for CalculatorTool {
148    fn name(&self) -> &str {
149        "calculator"
150    }
151
152    fn description(&self) -> &str {
153        "Perform basic arithmetic operations. Supports +, -, *, / operations."
154    }
155
156    fn parameters(&self) -> HashMap<String, ToolParameter> {
157        let mut params = HashMap::new();
158        params.insert(
159            "expression".to_string(),
160            ToolParameter {
161                param_type: "string".to_string(),
162                description: "Mathematical expression to evaluate (e.g., '2 + 2')".to_string(),
163                required: Some(true),
164            },
165        );
166        params
167    }
168
169    async fn execute(&self, args: Value) -> Result<ToolResult> {
170        let expression = args
171            .get("expression")
172            .and_then(|v| v.as_str())
173            .ok_or_else(|| HeliosError::ToolError("Missing 'expression' parameter".to_string()))?;
174
175        // Simple expression evaluator
176        let result = evaluate_expression(expression)?;
177        Ok(ToolResult::success(result.to_string()))
178    }
179}
180
181fn evaluate_expression(expr: &str) -> Result<f64> {
182    let expr = expr.replace(" ", "");
183
184    // Simple parsing for basic operations
185    for op in &['*', '/', '+', '-'] {
186        if let Some(pos) = expr.rfind(*op) {
187            if pos == 0 {
188                continue; // Skip if operator is at the beginning (negative number)
189            }
190            let left = &expr[..pos];
191            let right = &expr[pos + 1..];
192
193            let left_val = evaluate_expression(left)?;
194            let right_val = evaluate_expression(right)?;
195
196            return Ok(match op {
197                '+' => left_val + right_val,
198                '-' => left_val - right_val,
199                '*' => left_val * right_val,
200                '/' => {
201                    if right_val == 0.0 {
202                        return Err(HeliosError::ToolError("Division by zero".to_string()));
203                    }
204                    left_val / right_val
205                }
206                _ => unreachable!(),
207            });
208        }
209    }
210
211    expr.parse::<f64>()
212        .map_err(|_| HeliosError::ToolError(format!("Invalid expression: {}", expr)))
213}
214
215pub struct EchoTool;
216
217#[async_trait]
218impl Tool for EchoTool {
219    fn name(&self) -> &str {
220        "echo"
221    }
222
223    fn description(&self) -> &str {
224        "Echo back the provided message."
225    }
226
227    fn parameters(&self) -> HashMap<String, ToolParameter> {
228        let mut params = HashMap::new();
229        params.insert(
230            "message".to_string(),
231            ToolParameter {
232                param_type: "string".to_string(),
233                description: "The message to echo back".to_string(),
234                required: Some(true),
235            },
236        );
237        params
238    }
239
240    async fn execute(&self, args: Value) -> Result<ToolResult> {
241        let message = args
242            .get("message")
243            .and_then(|v| v.as_str())
244            .ok_or_else(|| HeliosError::ToolError("Missing 'message' parameter".to_string()))?;
245
246        Ok(ToolResult::success(format!("Echo: {}", message)))
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs the calculator tool on `expression` and returns its raw result.
    async fn run_calculator(expression: &str) -> Result<ToolResult> {
        CalculatorTool
            .execute(json!({ "expression": expression }))
            .await
    }

    /// Registry with both built-in tools registered.
    fn registry_with_builtins() -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));
        registry.register(Box::new(EchoTool));
        registry
    }

    #[test]
    fn test_tool_result_success() {
        let ToolResult { success, output } = ToolResult::success("test output");
        assert!(success);
        assert_eq!(output, "test output");
    }

    #[test]
    fn test_tool_result_error() {
        let ToolResult { success, output } = ToolResult::error("test error");
        assert!(!success);
        assert_eq!(output, "test error");
    }

    #[tokio::test]
    async fn test_calculator_tool() {
        let calculator = CalculatorTool;
        assert_eq!(calculator.name(), "calculator");
        assert_eq!(
            calculator.description(),
            "Perform basic arithmetic operations. Supports +, -, *, / operations."
        );

        let outcome = calculator
            .execute(json!({"expression": "2 + 2"}))
            .await
            .unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.output, "4");
    }

    #[tokio::test]
    async fn test_calculator_tool_multiplication() {
        let outcome = run_calculator("3 * 4").await.unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.output, "12");
    }

    #[tokio::test]
    async fn test_calculator_tool_division() {
        let outcome = run_calculator("8 / 2").await.unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.output, "4");
    }

    #[tokio::test]
    async fn test_calculator_tool_division_by_zero() {
        assert!(run_calculator("8 / 0").await.is_err());
    }

    #[tokio::test]
    async fn test_calculator_tool_invalid_expression() {
        assert!(run_calculator("invalid").await.is_err());
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let echo = EchoTool;
        assert_eq!(echo.name(), "echo");
        assert_eq!(echo.description(), "Echo back the provided message.");

        let outcome = echo
            .execute(json!({"message": "Hello, world!"}))
            .await
            .unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.output, "Echo: Hello, world!");
    }

    #[tokio::test]
    async fn test_echo_tool_missing_parameter() {
        let outcome = EchoTool.execute(json!({})).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_tool_registry_new() {
        let registry = ToolRegistry::new();
        assert!(registry.tools.is_empty());
    }

    #[tokio::test]
    async fn test_tool_registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));

        let found = registry.get("calculator");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name(), "calculator");
    }

    #[tokio::test]
    async fn test_tool_registry_execute() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(CalculatorTool));

        let outcome = registry
            .execute("calculator", json!({"expression": "5 * 6"}))
            .await
            .unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.output, "30");
    }

    #[tokio::test]
    async fn test_tool_registry_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let outcome = registry
            .execute("nonexistent", json!({"expression": "5 * 6"}))
            .await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_tool_registry_get_definitions() {
        let definitions = registry_with_builtins().get_definitions();
        assert_eq!(definitions.len(), 2);

        // Both built-in tools must be present, in whatever order.
        for expected in ["calculator", "echo"].iter() {
            assert!(definitions.iter().any(|d| d.function.name == *expected));
        }
    }

    #[test]
    fn test_tool_registry_list_tools() {
        let names = registry_with_builtins().list_tools();
        assert_eq!(names.len(), 2);
        for expected in ["calculator", "echo"].iter() {
            assert!(names.iter().any(|name| name == expected));
        }
    }
}