1use crate::state::GraphState;
7use async_trait::async_trait;
9use std::collections::HashMap;
10
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone)]
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17pub struct ToolResult {
18 pub output: serde_json::Value,
20
21 pub metadata: HashMap<String, serde_json::Value>,
23}
24
/// Every failure a [`Tool`] or the [`ToolRegistry`] can report.
///
/// Display strings are generated by `thiserror` from the `#[error]` attributes.
#[derive(thiserror::Error, Debug)]
pub enum ToolError {
    /// The tool accepted its input but could not complete (for example, the
    /// calculator's division by zero).
    #[error("Tool execution error: {message}")]
    Execution { message: String },

    /// The supplied arguments were missing, mistyped, or otherwise unusable.
    #[error("Invalid arguments: {message}")]
    InvalidArguments { message: String },

    /// No tool is registered under `name`.
    #[error("Tool not found: {name}")]
    NotFound { name: String },

    /// The caller is not permitted to run the tool `name`.
    #[error("Permission denied for tool: {name}")]
    PermissionDenied { name: String },

    /// The tool `name` did not finish in time.
    #[error("Tool timeout: {name}")]
    Timeout { name: String },

    /// A network-level failure occurred while the tool was running.
    #[error("Network error: {message}")]
    Network { message: String },

    /// Catch-all wrapper; `#[from]` lets `?` convert an `anyhow::Error` into this variant.
    #[error("Other error: {0}")]
    Other(#[from] anyhow::Error),
}
49
50#[derive(Debug, Clone)]
52#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
53pub struct ToolConfig {
54 pub name: String,
56
57 pub description: String,
59
60 pub version: String,
62
63 pub requires_auth: bool,
65
66 pub timeout_ms: Option<u64>,
68
69 pub config: serde_json::Value,
71}
72
73#[async_trait]
75pub trait Tool: Send + Sync {
76 async fn execute(
78 &self,
79 arguments: &serde_json::Value,
80 state: &GraphState,
81 ) -> Result<ToolResult, ToolError>;
82
83 fn name(&self) -> &str;
85
86 fn description(&self) -> &str;
88
89 fn argument_schema(&self) -> serde_json::Value {
91 serde_json::json!({
92 "type": "object",
93 "properties": {},
94 "additionalProperties": true
95 })
96 }
97
98 fn validate_arguments(&self, _arguments: &serde_json::Value) -> Result<(), ToolError> {
100 Ok(())
101 }
102
103 fn requires_auth(&self) -> bool {
105 false
106 }
107
108 fn metadata(&self) -> HashMap<String, serde_json::Value> {
110 HashMap::new()
111 }
112}
113
/// A trivial tool that returns its `message` argument, registered as `"echo"`.
pub struct EchoTool {
    /// Registration name; always `"echo"`.
    name: String,
}

impl EchoTool {
    /// Creates an echo tool named `"echo"`.
    pub fn new() -> Self {
        let name = String::from("echo");
        Self { name }
    }
}

impl Default for EchoTool {
    /// Same as [`EchoTool::new`].
    fn default() -> Self {
        EchoTool::new()
    }
}
132
133#[async_trait]
134impl Tool for EchoTool {
135 async fn execute(
136 &self,
137 arguments: &serde_json::Value,
138 _state: &GraphState,
139 ) -> Result<ToolResult, ToolError> {
140 let message = arguments
141 .get("message")
142 .and_then(|v| v.as_str())
143 .unwrap_or("Hello from EchoTool!");
144
145 Ok(ToolResult {
146 output: serde_json::json!({
147 "echo": message,
148 "timestamp": chrono::Utc::now().to_rfc3339()
149 }),
150 metadata: HashMap::new(),
151 })
152 }
153
154 fn name(&self) -> &str {
155 &self.name
156 }
157
158 fn description(&self) -> &str {
159 "A simple tool that echoes back the input message"
160 }
161
162 fn argument_schema(&self) -> serde_json::Value {
163 serde_json::json!({
164 "type": "object",
165 "properties": {
166 "message": {
167 "type": "string",
168 "description": "The message to echo back"
169 }
170 },
171 "required": ["message"]
172 })
173 }
174}
175
/// A basic arithmetic tool, registered as `"calculator"`.
pub struct CalculatorTool {
    /// Registration name; always `"calculator"`.
    name: String,
}

impl CalculatorTool {
    /// Creates a calculator tool named `"calculator"`.
    pub fn new() -> Self {
        let name = String::from("calculator");
        Self { name }
    }
}

impl Default for CalculatorTool {
    /// Same as [`CalculatorTool::new`].
    fn default() -> Self {
        CalculatorTool::new()
    }
}
194
195#[async_trait]
196impl Tool for CalculatorTool {
197 async fn execute(
198 &self,
199 arguments: &serde_json::Value,
200 _state: &GraphState,
201 ) -> Result<ToolResult, ToolError> {
202 let operation = arguments
203 .get("operation")
204 .and_then(|v| v.as_str())
205 .ok_or_else(|| ToolError::InvalidArguments {
206 message: "Missing 'operation' field".to_string(),
207 })?;
208
209 let a = arguments.get("a").and_then(|v| v.as_f64()).ok_or_else(|| {
210 ToolError::InvalidArguments {
211 message: "Missing or invalid 'a' field".to_string(),
212 }
213 })?;
214
215 let b = arguments.get("b").and_then(|v| v.as_f64()).ok_or_else(|| {
216 ToolError::InvalidArguments {
217 message: "Missing or invalid 'b' field".to_string(),
218 }
219 })?;
220
221 let result = match operation {
222 "add" => a + b,
223 "subtract" => a - b,
224 "multiply" => a * b,
225 "divide" => {
226 if b == 0.0 {
227 return Err(ToolError::Execution {
228 message: "Division by zero".to_string(),
229 });
230 }
231 a / b
232 }
233 _ => {
234 return Err(ToolError::InvalidArguments {
235 message: format!("Unknown operation: {}", operation),
236 })
237 }
238 };
239
240 Ok(ToolResult {
241 output: serde_json::json!({
242 "operation": operation,
243 "operands": [a, b],
244 "result": result
245 }),
246 metadata: HashMap::new(),
247 })
248 }
249
250 fn name(&self) -> &str {
251 &self.name
252 }
253
254 fn description(&self) -> &str {
255 "A calculator tool for basic arithmetic operations"
256 }
257
258 fn argument_schema(&self) -> serde_json::Value {
259 serde_json::json!({
260 "type": "object",
261 "properties": {
262 "operation": {
263 "type": "string",
264 "enum": ["add", "subtract", "multiply", "divide"],
265 "description": "The arithmetic operation to perform"
266 },
267 "a": {
268 "type": "number",
269 "description": "First operand"
270 },
271 "b": {
272 "type": "number",
273 "description": "Second operand"
274 }
275 },
276 "required": ["operation", "a", "b"]
277 })
278 }
279
280 fn validate_arguments(&self, arguments: &serde_json::Value) -> Result<(), ToolError> {
281 if !arguments.is_object() {
282 return Err(ToolError::InvalidArguments {
283 message: "Arguments must be an object".to_string(),
284 });
285 }
286
287 let required_fields = ["operation", "a", "b"];
289 for field in &required_fields {
290 if !arguments.get(field).is_some() {
291 return Err(ToolError::InvalidArguments {
292 message: format!("Missing required field: {}", field),
293 });
294 }
295 }
296
297 if let Some(op) = arguments.get("operation").and_then(|v| v.as_str()) {
299 if !["add", "subtract", "multiply", "divide"].contains(&op) {
300 return Err(ToolError::InvalidArguments {
301 message: format!("Invalid operation: {}", op),
302 });
303 }
304 }
305
306 Ok(())
307 }
308}
309
310pub struct ToolRegistry {
312 tools: HashMap<String, Box<dyn Tool>>,
313}
314
315impl ToolRegistry {
316 pub fn new() -> Self {
318 Self {
319 tools: HashMap::new(),
320 }
321 }
322
323 pub fn register(&mut self, tool: Box<dyn Tool>) {
325 let name = tool.name().to_string();
326 self.tools.insert(name, tool);
327 }
328
329 pub fn get(&self, name: &str) -> Option<&dyn Tool> {
331 self.tools.get(name).map(|t| t.as_ref())
332 }
333
334 pub fn tool_names(&self) -> Vec<String> {
336 self.tools.keys().cloned().collect()
337 }
338
339 pub async fn execute(
341 &self,
342 tool_name: &str,
343 arguments: &serde_json::Value,
344 state: &GraphState,
345 ) -> Result<ToolResult, ToolError> {
346 let tool = self.get(tool_name).ok_or_else(|| ToolError::NotFound {
347 name: tool_name.to_string(),
348 })?;
349
350 tool.validate_arguments(arguments)?;
352
353 tool.execute(arguments, state).await
355 }
356}
357
358impl Default for ToolRegistry {
359 fn default() -> Self {
360 let mut registry = Self::new();
361
362 registry.register(Box::new(EchoTool::new()));
364 registry.register(Box::new(CalculatorTool::new()));
365
366 registry
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_echo_tool() {
        let tool = EchoTool::new();
        let state = GraphState::new();
        let arguments = serde_json::json!({
            "message": "Hello, World!"
        });

        let result = tool.execute(&arguments, &state).await.unwrap();

        assert_eq!(result.output["echo"], "Hello, World!");
        assert!(result.output.get("timestamp").is_some());
    }

    #[tokio::test]
    async fn test_calculator_tool() {
        let tool = CalculatorTool::new();
        let state = GraphState::new();

        let arguments = serde_json::json!({
            "operation": "add",
            "a": 5.0,
            "b": 3.0
        });

        let result = tool.execute(&arguments, &state).await.unwrap();
        assert_eq!(result.output["result"], 8.0);

        let arguments = serde_json::json!({
            "operation": "divide",
            "a": 5.0,
            "b": 0.0
        });

        // Pin the specific variant, not just "some error".
        let result = tool.execute(&arguments, &state).await;
        assert!(matches!(result, Err(ToolError::Execution { .. })));
    }

    #[tokio::test]
    async fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool::new()));

        assert!(registry.get("echo").is_some());
        assert!(registry.get("nonexistent").is_none());

        let tool_names = registry.tool_names();
        assert!(tool_names.contains(&"echo".to_string()));

        let arguments = serde_json::json!({
            "message": "Test"
        });
        let state = GraphState::new();

        let result = registry.execute("echo", &arguments, &state).await.unwrap();
        assert_eq!(result.output["echo"], "Test");

        let missing = registry.execute("nonexistent", &arguments, &state).await;
        assert!(matches!(
            &missing,
            Err(ToolError::NotFound { name }) if name == "nonexistent"
        ));
    }

    #[tokio::test]
    async fn test_registry_validates_before_execute() {
        let registry = ToolRegistry::default();
        let state = GraphState::new();
        let arguments = serde_json::json!({
            "operation": "modulo",
            "a": 1.0,
            "b": 2.0
        });

        let result = registry.execute("calculator", &arguments, &state).await;
        assert!(matches!(result, Err(ToolError::InvalidArguments { .. })));
    }

    #[test]
    fn test_default_registry_has_builtins() {
        let registry = ToolRegistry::default();
        assert!(registry.get("echo").is_some());
        assert!(registry.get("calculator").is_some());
    }

    #[test]
    fn test_calculator_validation() {
        let tool = CalculatorTool::new();

        let valid_args = serde_json::json!({
            "operation": "add",
            "a": 1.0,
            "b": 2.0
        });
        assert!(tool.validate_arguments(&valid_args).is_ok());

        let invalid_args = serde_json::json!({
            "operation": "invalid",
            "a": 1.0,
            "b": 2.0
        });
        assert!(tool.validate_arguments(&invalid_args).is_err());

        let missing_field = serde_json::json!({
            "operation": "add",
            "a": 1.0
        });
        assert!(tool.validate_arguments(&missing_field).is_err());
    }

    #[test]
    fn test_tool_schemas() {
        let echo_tool = EchoTool::new();
        let calc_tool = CalculatorTool::new();

        let echo_schema = echo_tool.argument_schema();
        assert_eq!(echo_schema["type"], "object");
        assert!(echo_schema["properties"].get("message").is_some());

        let calc_schema = calc_tool.argument_schema();
        assert_eq!(calc_schema["type"], "object");
        assert!(calc_schema["properties"].get("operation").is_some());
        assert!(calc_schema["properties"].get("a").is_some());
        assert!(calc_schema["properties"].get("b").is_some());
    }
}