rrag_graph/
tools.rs

1//! # Tool System for RGraph Agents
2//!
3//! This module provides the tool system that allows agents to interact with
4//! external systems, perform computations, and access data.
5
6use crate::state::GraphState;
// `async_trait` lets the `Tool` trait and its implementations declare `async fn` methods
8use async_trait::async_trait;
9use std::collections::HashMap;
10
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
14/// Result of tool execution
15#[derive(Debug, Clone)]
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17pub struct ToolResult {
18    /// Tool output
19    pub output: serde_json::Value,
20
21    /// Additional metadata
22    pub metadata: HashMap<String, serde_json::Value>,
23}
24
/// Error that can occur during tool execution.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute
/// (generated by `thiserror`).
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool ran but failed; `message` describes the failure.
    #[error("Tool execution error: {message}")]
    Execution { message: String },

    /// The supplied arguments were missing, malformed, or otherwise unacceptable;
    /// `message` says what was wrong.
    #[error("Invalid arguments: {message}")]
    InvalidArguments { message: String },

    /// No tool is known under `name`.
    #[error("Tool not found: {name}")]
    NotFound { name: String },

    /// Use of the tool called `name` was not permitted.
    #[error("Permission denied for tool: {name}")]
    PermissionDenied { name: String },

    /// The tool called `name` timed out.
    #[error("Tool timeout: {name}")]
    Timeout { name: String },

    /// A network-related failure; `message` carries the details.
    #[error("Network error: {message}")]
    Network { message: String },

    /// Any other failure, wrapped as an `anyhow::Error`.
    ///
    /// `#[from]` generates `From<anyhow::Error>`, so `?` converts such errors
    /// into this variant automatically.
    #[error("Other error: {0}")]
    Other(#[from] anyhow::Error),
}
49
50/// Configuration for a tool
51#[derive(Debug, Clone)]
52#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
53pub struct ToolConfig {
54    /// Tool name
55    pub name: String,
56
57    /// Tool description
58    pub description: String,
59
60    /// Tool version
61    pub version: String,
62
63    /// Whether the tool requires authentication
64    pub requires_auth: bool,
65
66    /// Maximum execution time in milliseconds
67    pub timeout_ms: Option<u64>,
68
69    /// Tool-specific configuration
70    pub config: serde_json::Value,
71}
72
73/// Core trait for all tools
74#[async_trait]
75pub trait Tool: Send + Sync {
76    /// Execute the tool with given arguments
77    async fn execute(
78        &self,
79        arguments: &serde_json::Value,
80        state: &GraphState,
81    ) -> Result<ToolResult, ToolError>;
82
83    /// Get the tool name
84    fn name(&self) -> &str;
85
86    /// Get the tool description
87    fn description(&self) -> &str;
88
89    /// Get the tool's argument schema (JSON Schema)
90    fn argument_schema(&self) -> serde_json::Value {
91        serde_json::json!({
92            "type": "object",
93            "properties": {},
94            "additionalProperties": true
95        })
96    }
97
98    /// Validate tool arguments
99    fn validate_arguments(&self, _arguments: &serde_json::Value) -> Result<(), ToolError> {
100        Ok(())
101    }
102
103    /// Check if the tool requires authentication
104    fn requires_auth(&self) -> bool {
105        false
106    }
107
108    /// Get tool metadata
109    fn metadata(&self) -> HashMap<String, serde_json::Value> {
110        HashMap::new()
111    }
112}
113
/// Simple echo tool for testing.
///
/// Holds only the name it is registered under, so it is cheap to clone and safe
/// to print with `{:?}`.
#[derive(Debug, Clone)]
pub struct EchoTool {
    /// Name reported by the tool; always `"echo"` when built via [`EchoTool::new`].
    name: String,
}

impl EchoTool {
    /// Creates an echo tool named `"echo"`.
    pub fn new() -> Self {
        Self {
            name: "echo".to_string(),
        }
    }
}

impl Default for EchoTool {
    /// Equivalent to [`EchoTool::new`].
    fn default() -> Self {
        Self::new()
    }
}
132
133#[async_trait]
134impl Tool for EchoTool {
135    async fn execute(
136        &self,
137        arguments: &serde_json::Value,
138        _state: &GraphState,
139    ) -> Result<ToolResult, ToolError> {
140        let message = arguments
141            .get("message")
142            .and_then(|v| v.as_str())
143            .unwrap_or("Hello from EchoTool!");
144
145        Ok(ToolResult {
146            output: serde_json::json!({
147                "echo": message,
148                "timestamp": chrono::Utc::now().to_rfc3339()
149            }),
150            metadata: HashMap::new(),
151        })
152    }
153
154    fn name(&self) -> &str {
155        &self.name
156    }
157
158    fn description(&self) -> &str {
159        "A simple tool that echoes back the input message"
160    }
161
162    fn argument_schema(&self) -> serde_json::Value {
163        serde_json::json!({
164            "type": "object",
165            "properties": {
166                "message": {
167                    "type": "string",
168                    "description": "The message to echo back"
169                }
170            },
171            "required": ["message"]
172        })
173    }
174}
175
/// Calculator tool for basic arithmetic.
///
/// Holds only the name it is registered under, so it is cheap to clone and safe
/// to print with `{:?}`.
#[derive(Debug, Clone)]
pub struct CalculatorTool {
    /// Name reported by the tool; always `"calculator"` when built via
    /// [`CalculatorTool::new`].
    name: String,
}

impl CalculatorTool {
    /// Creates a calculator tool named `"calculator"`.
    pub fn new() -> Self {
        Self {
            name: "calculator".to_string(),
        }
    }
}

impl Default for CalculatorTool {
    /// Equivalent to [`CalculatorTool::new`].
    fn default() -> Self {
        Self::new()
    }
}
194
195#[async_trait]
196impl Tool for CalculatorTool {
197    async fn execute(
198        &self,
199        arguments: &serde_json::Value,
200        _state: &GraphState,
201    ) -> Result<ToolResult, ToolError> {
202        let operation = arguments
203            .get("operation")
204            .and_then(|v| v.as_str())
205            .ok_or_else(|| ToolError::InvalidArguments {
206                message: "Missing 'operation' field".to_string(),
207            })?;
208
209        let a = arguments.get("a").and_then(|v| v.as_f64()).ok_or_else(|| {
210            ToolError::InvalidArguments {
211                message: "Missing or invalid 'a' field".to_string(),
212            }
213        })?;
214
215        let b = arguments.get("b").and_then(|v| v.as_f64()).ok_or_else(|| {
216            ToolError::InvalidArguments {
217                message: "Missing or invalid 'b' field".to_string(),
218            }
219        })?;
220
221        let result = match operation {
222            "add" => a + b,
223            "subtract" => a - b,
224            "multiply" => a * b,
225            "divide" => {
226                if b == 0.0 {
227                    return Err(ToolError::Execution {
228                        message: "Division by zero".to_string(),
229                    });
230                }
231                a / b
232            }
233            _ => {
234                return Err(ToolError::InvalidArguments {
235                    message: format!("Unknown operation: {}", operation),
236                })
237            }
238        };
239
240        Ok(ToolResult {
241            output: serde_json::json!({
242                "operation": operation,
243                "operands": [a, b],
244                "result": result
245            }),
246            metadata: HashMap::new(),
247        })
248    }
249
250    fn name(&self) -> &str {
251        &self.name
252    }
253
254    fn description(&self) -> &str {
255        "A calculator tool for basic arithmetic operations"
256    }
257
258    fn argument_schema(&self) -> serde_json::Value {
259        serde_json::json!({
260            "type": "object",
261            "properties": {
262                "operation": {
263                    "type": "string",
264                    "enum": ["add", "subtract", "multiply", "divide"],
265                    "description": "The arithmetic operation to perform"
266                },
267                "a": {
268                    "type": "number",
269                    "description": "First operand"
270                },
271                "b": {
272                    "type": "number",
273                    "description": "Second operand"
274                }
275            },
276            "required": ["operation", "a", "b"]
277        })
278    }
279
280    fn validate_arguments(&self, arguments: &serde_json::Value) -> Result<(), ToolError> {
281        if !arguments.is_object() {
282            return Err(ToolError::InvalidArguments {
283                message: "Arguments must be an object".to_string(),
284            });
285        }
286
287        // Check required fields
288        let required_fields = ["operation", "a", "b"];
289        for field in &required_fields {
290            if !arguments.get(field).is_some() {
291                return Err(ToolError::InvalidArguments {
292                    message: format!("Missing required field: {}", field),
293                });
294            }
295        }
296
297        // Validate operation
298        if let Some(op) = arguments.get("operation").and_then(|v| v.as_str()) {
299            if !["add", "subtract", "multiply", "divide"].contains(&op) {
300                return Err(ToolError::InvalidArguments {
301                    message: format!("Invalid operation: {}", op),
302                });
303            }
304        }
305
306        Ok(())
307    }
308}
309
310/// Tool registry for managing available tools
311pub struct ToolRegistry {
312    tools: HashMap<String, Box<dyn Tool>>,
313}
314
315impl ToolRegistry {
316    /// Create a new tool registry
317    pub fn new() -> Self {
318        Self {
319            tools: HashMap::new(),
320        }
321    }
322
323    /// Register a tool
324    pub fn register(&mut self, tool: Box<dyn Tool>) {
325        let name = tool.name().to_string();
326        self.tools.insert(name, tool);
327    }
328
329    /// Get a tool by name
330    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
331        self.tools.get(name).map(|t| t.as_ref())
332    }
333
334    /// Get all available tool names
335    pub fn tool_names(&self) -> Vec<String> {
336        self.tools.keys().cloned().collect()
337    }
338
339    /// Execute a tool
340    pub async fn execute(
341        &self,
342        tool_name: &str,
343        arguments: &serde_json::Value,
344        state: &GraphState,
345    ) -> Result<ToolResult, ToolError> {
346        let tool = self.get(tool_name).ok_or_else(|| ToolError::NotFound {
347            name: tool_name.to_string(),
348        })?;
349
350        // Validate arguments
351        tool.validate_arguments(arguments)?;
352
353        // Execute tool
354        tool.execute(arguments, state).await
355    }
356}
357
358impl Default for ToolRegistry {
359    fn default() -> Self {
360        let mut registry = Self::new();
361
362        // Register default tools
363        registry.register(Box::new(EchoTool::new()));
364        registry.register(Box::new(CalculatorTool::new()));
365
366        registry
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a complete calculator argument object for `a <operation> b`.
    fn calculator_args(operation: &str, a: f64, b: f64) -> serde_json::Value {
        serde_json::json!({
            "operation": operation,
            "a": a,
            "b": b
        })
    }

    /// The echo tool returns the supplied message and attaches a timestamp.
    #[tokio::test]
    async fn test_echo_tool() {
        let state = GraphState::new();
        let echo = EchoTool::new();
        let args = serde_json::json!({ "message": "Hello, World!" });

        let outcome = echo.execute(&args, &state).await.unwrap();

        assert_eq!(outcome.output["echo"], "Hello, World!");
        assert!(outcome.output.get("timestamp").is_some());
    }

    /// The calculator adds correctly and refuses to divide by zero.
    #[tokio::test]
    async fn test_calculator_tool() {
        let state = GraphState::new();
        let calculator = CalculatorTool::new();

        // Addition
        let sum = calculator
            .execute(&calculator_args("add", 5.0, 3.0), &state)
            .await
            .unwrap();
        assert_eq!(sum.output["result"], 8.0);

        // Division by zero
        let quotient = calculator
            .execute(&calculator_args("divide", 5.0, 0.0), &state)
            .await;
        assert!(quotient.is_err());
    }

    /// Registration, lookup, name listing and execution through the registry.
    #[tokio::test]
    async fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool::new()));

        assert!(registry.get("echo").is_some());
        assert!(registry.get("nonexistent").is_none());

        let names = registry.tool_names();
        assert!(names.contains(&"echo".to_string()));

        // Execution through the registry
        let state = GraphState::new();
        let args = serde_json::json!({ "message": "Test" });

        let outcome = registry.execute("echo", &args, &state).await.unwrap();
        assert_eq!(outcome.output["echo"], "Test");
    }

    /// Calculator validation accepts well-formed input and rejects bad input.
    #[test]
    fn test_calculator_validation() {
        let calculator = CalculatorTool::new();

        // Valid arguments
        let well_formed = calculator_args("add", 1.0, 2.0);
        assert!(calculator.validate_arguments(&well_formed).is_ok());

        // Invalid operation
        let bad_operation = calculator_args("invalid", 1.0, 2.0);
        assert!(calculator.validate_arguments(&bad_operation).is_err());

        // Missing field
        let without_b = serde_json::json!({
            "operation": "add",
            "a": 1.0
        });
        assert!(calculator.validate_arguments(&without_b).is_err());
    }

    /// Both built-in tools publish object schemas naming their arguments.
    #[test]
    fn test_tool_schemas() {
        let echo_schema = EchoTool::new().argument_schema();
        assert_eq!(echo_schema["type"], "object");
        assert!(echo_schema["properties"].get("message").is_some());

        let calc_schema = CalculatorTool::new().argument_schema();
        assert_eq!(calc_schema["type"], "object");
        for property in &["operation", "a", "b"] {
            assert!(calc_schema["properties"].get(*property).is_some());
        }
    }
}