helios_engine/
tools.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use crate::error::{HeliosError, Result};
6
/// Schema entry describing a single tool parameter.
///
/// Serializes as an object with `type` and `description` keys. The `required`
/// flag is not part of this entry's serialized form; [`Tool::to_definition`]
/// collects it into the parent schema's `required` list instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
    /// Schema type name such as `"string"`; serialized under the key `type`.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Human-readable explanation of the parameter.
    pub description: String,
    /// Whether callers must supply this parameter; `None` is treated as `false`.
    /// Skipped by serde in both directions, so it is always `None` after
    /// deserialization.
    #[serde(skip)]
    pub required: Option<bool>,
}
15
/// Serializable description of a tool: a `type` tag plus the function's name,
/// description, and parameter schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool kind, serialized under the key `type`; the default
    /// [`Tool::to_definition`] sets it to `"function"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function this tool exposes.
    pub function: FunctionDefinition,
}
22
/// Name, description, and argument schema of a tool's function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name; the default [`Tool::to_definition`] copies [`Tool::name`].
    pub name: String,
    /// Human-readable description; copied from [`Tool::description`].
    pub description: String,
    /// Schema of the arguments the function accepts.
    pub parameters: ParametersSchema,
}
29
/// Object schema listing a function's parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParametersSchema {
    /// Schema type, serialized under the key `type`; the default
    /// [`Tool::to_definition`] sets it to `"object"`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Per-parameter schemas keyed by parameter name.
    pub properties: HashMap<String, ToolParameter>,
    /// Names of the required parameters. When `None`, the key is omitted from
    /// the serialized output entirely.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}
38
/// Outcome of a tool invocation: a status flag plus the text to report back.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// `true` for results built with [`ToolResult::success`], `false` for
    /// those built with [`ToolResult::error`].
    pub success: bool,
    /// The produced output on success, or the error message on failure.
    pub output: String,
}

impl ToolResult {
    /// Creates a successful result whose output is `output`.
    pub fn success(output: impl Into<String>) -> Self {
        Self::with_status(true, output.into())
    }

    /// Creates a failed result whose output is `message`.
    pub fn error(message: impl Into<String>) -> Self {
        Self::with_status(false, message.into())
    }

    /// Shared constructor used by both public builders.
    fn with_status(success: bool, output: String) -> Self {
        ToolResult { success, output }
    }
}
60
61#[async_trait]
62pub trait Tool: Send + Sync {
63    fn name(&self) -> &str;
64    fn description(&self) -> &str;
65    fn parameters(&self) -> HashMap<String, ToolParameter>;
66    async fn execute(&self, args: Value) -> Result<ToolResult>;
67
68    fn to_definition(&self) -> ToolDefinition {
69        let required: Vec<String> = self
70            .parameters()
71            .iter()
72            .filter(|(_, param)| param.required.unwrap_or(false))
73            .map(|(name, _)| name.clone())
74            .collect();
75
76        ToolDefinition {
77            tool_type: "function".to_string(),
78            function: FunctionDefinition {
79                name: self.name().to_string(),
80                description: self.description().to_string(),
81                parameters: ParametersSchema {
82                    schema_type: "object".to_string(),
83                    properties: self.parameters(),
84                    required: if required.is_empty() { None } else { Some(required) },
85                },
86            },
87        }
88    }
89}
90
91pub struct ToolRegistry {
92    tools: HashMap<String, Box<dyn Tool>>,
93}
94
95impl ToolRegistry {
96    pub fn new() -> Self {
97        Self {
98            tools: HashMap::new(),
99        }
100    }
101
102    pub fn register(&mut self, tool: Box<dyn Tool>) {
103        let name = tool.name().to_string();
104        self.tools.insert(name, tool);
105    }
106
107    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
108        self.tools.get(name).map(|b| &**b)
109    }
110
111    pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
112        let tool = self
113            .tools
114            .get(name)
115            .ok_or_else(|| HeliosError::ToolError(format!("Tool '{}' not found", name)))?;
116        
117        tool.execute(args).await
118    }
119
120    pub fn get_definitions(&self) -> Vec<ToolDefinition> {
121        self.tools
122            .values()
123            .map(|tool| tool.to_definition())
124            .collect()
125    }
126
127    pub fn list_tools(&self) -> Vec<String> {
128        self.tools.keys().cloned().collect()
129    }
130}
131
132impl Default for ToolRegistry {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138// Example built-in tools
139
140pub struct CalculatorTool;
141
142#[async_trait]
143impl Tool for CalculatorTool {
144    fn name(&self) -> &str {
145        "calculator"
146    }
147
148    fn description(&self) -> &str {
149        "Perform basic arithmetic operations. Supports +, -, *, / operations."
150    }
151
152    fn parameters(&self) -> HashMap<String, ToolParameter> {
153        let mut params = HashMap::new();
154        params.insert(
155            "expression".to_string(),
156            ToolParameter {
157                param_type: "string".to_string(),
158                description: "Mathematical expression to evaluate (e.g., '2 + 2')".to_string(),
159                required: Some(true),
160            },
161        );
162        params
163    }
164
165    async fn execute(&self, args: Value) -> Result<ToolResult> {
166        let expression = args
167            .get("expression")
168            .and_then(|v| v.as_str())
169            .ok_or_else(|| HeliosError::ToolError("Missing 'expression' parameter".to_string()))?;
170
171        // Simple expression evaluator
172        let result = evaluate_expression(expression)?;
173        Ok(ToolResult::success(result.to_string()))
174    }
175}
176
177fn evaluate_expression(expr: &str) -> Result<f64> {
178    let expr = expr.replace(" ", "");
179    
180    // Simple parsing for basic operations
181    for op in &['*', '/', '+', '-'] {
182        if let Some(pos) = expr.rfind(*op) {
183            if pos == 0 {
184                continue; // Skip if operator is at the beginning (negative number)
185            }
186            let left = &expr[..pos];
187            let right = &expr[pos + 1..];
188            
189            let left_val = evaluate_expression(left)?;
190            let right_val = evaluate_expression(right)?;
191            
192            return Ok(match op {
193                '+' => left_val + right_val,
194                '-' => left_val - right_val,
195                '*' => left_val * right_val,
196                '/' => {
197                    if right_val == 0.0 {
198                        return Err(HeliosError::ToolError("Division by zero".to_string()));
199                    }
200                    left_val / right_val
201                }
202                _ => unreachable!(),
203            });
204        }
205    }
206    
207    expr.parse::<f64>()
208        .map_err(|_| HeliosError::ToolError(format!("Invalid expression: {}", expr)))
209}
210
211pub struct EchoTool;
212
213#[async_trait]
214impl Tool for EchoTool {
215    fn name(&self) -> &str {
216        "echo"
217    }
218
219    fn description(&self) -> &str {
220        "Echo back the provided message."
221    }
222
223    fn parameters(&self) -> HashMap<String, ToolParameter> {
224        let mut params = HashMap::new();
225        params.insert(
226            "message".to_string(),
227            ToolParameter {
228                param_type: "string".to_string(),
229                description: "The message to echo back".to_string(),
230                required: Some(true),
231            },
232        );
233        params
234    }
235
236    async fn execute(&self, args: Value) -> Result<ToolResult> {
237        let message = args
238            .get("message")
239            .and_then(|v| v.as_str())
240            .ok_or_else(|| HeliosError::ToolError("Missing 'message' parameter".to_string()))?;
241
242        Ok(ToolResult::success(format!("Echo: {}", message)))
243    }
244}