1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use crate::error::{HeliosError, Result};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct ToolParameter {
9 #[serde(rename = "type")]
10 pub param_type: String,
11 pub description: String,
12 #[serde(skip)]
13 pub required: Option<bool>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ToolDefinition {
18 #[serde(rename = "type")]
19 pub tool_type: String,
20 pub function: FunctionDefinition,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct FunctionDefinition {
25 pub name: String,
26 pub description: String,
27 pub parameters: ParametersSchema,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ParametersSchema {
32 #[serde(rename = "type")]
33 pub schema_type: String,
34 pub properties: HashMap<String, ToolParameter>,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub required: Option<Vec<String>>,
37}
38
/// Outcome of running a tool: a success flag plus the text it produced.
#[derive(Clone, Debug)]
pub struct ToolResult {
    /// `true` when the tool completed successfully.
    pub success: bool,
    /// The tool's output on success, or the error message on failure.
    pub output: String,
}

impl ToolResult {
    /// Builds a successful result carrying `output`.
    pub fn success(output: impl Into<String>) -> Self {
        let output = output.into();
        Self { success: true, output }
    }

    /// Builds a failed result whose output is `message`.
    pub fn error(message: impl Into<String>) -> Self {
        let output = message.into();
        Self { success: false, output }
    }
}
60
61#[async_trait]
62pub trait Tool: Send + Sync {
63 fn name(&self) -> &str;
64 fn description(&self) -> &str;
65 fn parameters(&self) -> HashMap<String, ToolParameter>;
66 async fn execute(&self, args: Value) -> Result<ToolResult>;
67
68 fn to_definition(&self) -> ToolDefinition {
69 let required: Vec<String> = self
70 .parameters()
71 .iter()
72 .filter(|(_, param)| param.required.unwrap_or(false))
73 .map(|(name, _)| name.clone())
74 .collect();
75
76 ToolDefinition {
77 tool_type: "function".to_string(),
78 function: FunctionDefinition {
79 name: self.name().to_string(),
80 description: self.description().to_string(),
81 parameters: ParametersSchema {
82 schema_type: "object".to_string(),
83 properties: self.parameters(),
84 required: if required.is_empty() { None } else { Some(required) },
85 },
86 },
87 }
88 }
89}
90
91pub struct ToolRegistry {
92 tools: HashMap<String, Box<dyn Tool>>,
93}
94
95impl ToolRegistry {
96 pub fn new() -> Self {
97 Self {
98 tools: HashMap::new(),
99 }
100 }
101
102 pub fn register(&mut self, tool: Box<dyn Tool>) {
103 let name = tool.name().to_string();
104 self.tools.insert(name, tool);
105 }
106
107 pub fn get(&self, name: &str) -> Option<&dyn Tool> {
108 self.tools.get(name).map(|b| &**b)
109 }
110
111 pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
112 let tool = self
113 .tools
114 .get(name)
115 .ok_or_else(|| HeliosError::ToolError(format!("Tool '{}' not found", name)))?;
116
117 tool.execute(args).await
118 }
119
120 pub fn get_definitions(&self) -> Vec<ToolDefinition> {
121 self.tools
122 .values()
123 .map(|tool| tool.to_definition())
124 .collect()
125 }
126
127 pub fn list_tools(&self) -> Vec<String> {
128 self.tools.keys().cloned().collect()
129 }
130}
131
132impl Default for ToolRegistry {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138pub struct CalculatorTool;
141
142#[async_trait]
143impl Tool for CalculatorTool {
144 fn name(&self) -> &str {
145 "calculator"
146 }
147
148 fn description(&self) -> &str {
149 "Perform basic arithmetic operations. Supports +, -, *, / operations."
150 }
151
152 fn parameters(&self) -> HashMap<String, ToolParameter> {
153 let mut params = HashMap::new();
154 params.insert(
155 "expression".to_string(),
156 ToolParameter {
157 param_type: "string".to_string(),
158 description: "Mathematical expression to evaluate (e.g., '2 + 2')".to_string(),
159 required: Some(true),
160 },
161 );
162 params
163 }
164
165 async fn execute(&self, args: Value) -> Result<ToolResult> {
166 let expression = args
167 .get("expression")
168 .and_then(|v| v.as_str())
169 .ok_or_else(|| HeliosError::ToolError("Missing 'expression' parameter".to_string()))?;
170
171 let result = evaluate_expression(expression)?;
173 Ok(ToolResult::success(result.to_string()))
174 }
175}
176
177fn evaluate_expression(expr: &str) -> Result<f64> {
178 let expr = expr.replace(" ", "");
179
180 for op in &['*', '/', '+', '-'] {
182 if let Some(pos) = expr.rfind(*op) {
183 if pos == 0 {
184 continue; }
186 let left = &expr[..pos];
187 let right = &expr[pos + 1..];
188
189 let left_val = evaluate_expression(left)?;
190 let right_val = evaluate_expression(right)?;
191
192 return Ok(match op {
193 '+' => left_val + right_val,
194 '-' => left_val - right_val,
195 '*' => left_val * right_val,
196 '/' => {
197 if right_val == 0.0 {
198 return Err(HeliosError::ToolError("Division by zero".to_string()));
199 }
200 left_val / right_val
201 }
202 _ => unreachable!(),
203 });
204 }
205 }
206
207 expr.parse::<f64>()
208 .map_err(|_| HeliosError::ToolError(format!("Invalid expression: {}", expr)))
209}
210
211pub struct EchoTool;
212
213#[async_trait]
214impl Tool for EchoTool {
215 fn name(&self) -> &str {
216 "echo"
217 }
218
219 fn description(&self) -> &str {
220 "Echo back the provided message."
221 }
222
223 fn parameters(&self) -> HashMap<String, ToolParameter> {
224 let mut params = HashMap::new();
225 params.insert(
226 "message".to_string(),
227 ToolParameter {
228 param_type: "string".to_string(),
229 description: "The message to echo back".to_string(),
230 required: Some(true),
231 },
232 );
233 params
234 }
235
236 async fn execute(&self, args: Value) -> Result<ToolResult> {
237 let message = args
238 .get("message")
239 .and_then(|v| v.as_str())
240 .ok_or_else(|| HeliosError::ToolError("Missing 'message' parameter".to_string()))?;
241
242 Ok(ToolResult::success(format!("Echo: {}", message)))
243 }
244}