use crate::state::GraphState;
use async_trait::async_trait;
use std::collections::HashMap;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// The value produced by a successful [`Tool::execute`] call.
///
/// `output` carries the tool's primary JSON payload, while `metadata` is a
/// free-form key/value side channel. The built-in tools in this module leave
/// `metadata` empty.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ToolResult {
    /// Primary JSON payload returned by the tool.
    pub output: serde_json::Value,
    /// Extra key/value information attached by the tool.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// Errors reported by a [`Tool`] or by the [`ToolRegistry`].
///
/// `PermissionDenied`, `Timeout` and `Network` are not produced by the code in
/// this module; they are available for other tool implementations.
#[derive(thiserror::Error, Debug)]
pub enum ToolError {
    /// The tool ran but could not complete its work (e.g. division by zero).
    #[error("Tool execution error: {message}")]
    Execution { message: String },
    /// The supplied JSON arguments were missing, mistyped or unsupported.
    #[error("Invalid arguments: {message}")]
    InvalidArguments { message: String },
    /// No tool is registered under the requested name.
    #[error("Tool not found: {name}")]
    NotFound { name: String },
    /// Use of the named tool was refused.
    #[error("Permission denied for tool: {name}")]
    PermissionDenied { name: String },
    /// The named tool did not finish in time.
    #[error("Tool timeout: {name}")]
    Timeout { name: String },
    /// A network-level failure occurred while running a tool.
    #[error("Network error: {message}")]
    Network { message: String },
    /// Catch-all wrapper; any `anyhow::Error` converts into this via `?`.
    #[error("Other error: {0}")]
    Other(#[from] anyhow::Error),
}
/// Descriptive configuration record for a tool.
///
/// NOTE(review): nothing in this module reads a `ToolConfig`; in particular
/// `timeout_ms` and `requires_auth` are not enforced by [`ToolRegistry`] —
/// confirm how callers elsewhere consume it.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ToolConfig {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Version string of the tool.
    pub version: String,
    /// Whether the tool is flagged as needing authentication.
    pub requires_auth: bool,
    /// Optional timeout, in milliseconds (`None` means no timeout is configured).
    pub timeout_ms: Option<u64>,
    /// Tool-specific settings as arbitrary JSON.
    pub config: serde_json::Value,
}
/// A named, asynchronously executable operation driven by JSON arguments.
///
/// Implementors must provide [`execute`](Tool::execute), [`name`](Tool::name)
/// and [`description`](Tool::description); every other method has a permissive
/// default.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Runs the tool with the given JSON `arguments` and the current graph state.
    ///
    /// # Errors
    /// Any [`ToolError`] the implementation chooses to report.
    async fn execute(
        &self,
        arguments: &serde_json::Value,
        state: &GraphState,
    ) -> Result<ToolResult, ToolError>;

    /// Name used as the lookup key when the tool is added to a [`ToolRegistry`].
    fn name(&self) -> &str;

    /// Human-readable summary of what the tool does.
    fn description(&self) -> &str;

    /// JSON Schema describing the accepted arguments.
    ///
    /// The default describes an object with no declared properties that
    /// accepts any additional properties.
    fn argument_schema(&self) -> serde_json::Value {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {},
            "additionalProperties": true
        });
        schema
    }

    /// Checks `arguments` before execution; [`ToolRegistry::execute`] calls
    /// this ahead of [`execute`](Tool::execute).
    ///
    /// The default accepts every input.
    fn validate_arguments(&self, _arguments: &serde_json::Value) -> Result<(), ToolError> {
        Ok(())
    }

    /// Whether the tool declares that it needs authentication (default: `false`).
    ///
    /// NOTE(review): [`ToolRegistry`] in this module does not consult this flag.
    fn requires_auth(&self) -> bool {
        false
    }

    /// Extra descriptive key/value pairs about the tool (default: empty).
    fn metadata(&self) -> HashMap<String, serde_json::Value> {
        HashMap::default()
    }
}
/// Built-in tool that returns its `message` argument plus a UTC timestamp.
pub struct EchoTool {
    /// Registry key; always `"echo"` for instances built by [`EchoTool::new`].
    name: String,
}

impl EchoTool {
    /// Creates an echo tool named `"echo"`.
    pub fn new() -> Self {
        let name = String::from("echo");
        Self { name }
    }
}

impl Default for EchoTool {
    /// Same as [`EchoTool::new`].
    fn default() -> Self {
        EchoTool::new()
    }
}
#[async_trait]
impl Tool for EchoTool {
async fn execute(
&self,
arguments: &serde_json::Value,
_state: &GraphState,
) -> Result<ToolResult, ToolError> {
let message = arguments
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("Hello from EchoTool!");
Ok(ToolResult {
output: serde_json::json!({
"echo": message,
"timestamp": chrono::Utc::now().to_rfc3339()
}),
metadata: HashMap::new(),
})
}
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
"A simple tool that echoes back the input message"
}
fn argument_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"message": {
"type": "string",
"description": "The message to echo back"
}
},
"required": ["message"]
})
}
}
/// Built-in tool performing one binary arithmetic operation on two numbers.
pub struct CalculatorTool {
    /// Registry key; always `"calculator"` for instances built by
    /// [`CalculatorTool::new`].
    name: String,
}

impl CalculatorTool {
    /// Creates a calculator tool named `"calculator"`.
    pub fn new() -> Self {
        let name = String::from("calculator");
        Self { name }
    }
}

impl Default for CalculatorTool {
    /// Same as [`CalculatorTool::new`].
    fn default() -> Self {
        CalculatorTool::new()
    }
}
#[async_trait]
impl Tool for CalculatorTool {
    /// Applies `operation` (`add`, `subtract`, `multiply` or `divide`) to the
    /// numeric operands `a` and `b`.
    ///
    /// # Errors
    /// - [`ToolError::InvalidArguments`] if `operation` is missing, not a string
    ///   or unknown, or if `a`/`b` is missing or not a number.
    /// - [`ToolError::Execution`] on division by zero, or when the result is not
    ///   a finite number (e.g. `1e308 * 10.0` overflows to infinity).
    async fn execute(
        &self,
        arguments: &serde_json::Value,
        _state: &GraphState,
    ) -> Result<ToolResult, ToolError> {
        let operation = arguments
            .get("operation")
            .and_then(|v| v.as_str())
            .ok_or_else(|| ToolError::InvalidArguments {
                message: "Missing 'operation' field".to_string(),
            })?;
        let a = arguments.get("a").and_then(|v| v.as_f64()).ok_or_else(|| {
            ToolError::InvalidArguments {
                message: "Missing or invalid 'a' field".to_string(),
            }
        })?;
        let b = arguments.get("b").and_then(|v| v.as_f64()).ok_or_else(|| {
            ToolError::InvalidArguments {
                message: "Missing or invalid 'b' field".to_string(),
            }
        })?;
        let result = match operation {
            "add" => a + b,
            "subtract" => a - b,
            "multiply" => a * b,
            "divide" => {
                // `-0.0 == 0.0`, so negative zero is rejected as well.
                if b == 0.0 {
                    return Err(ToolError::Execution {
                        message: "Division by zero".to_string(),
                    });
                }
                a / b
            }
            _ => {
                return Err(ToolError::InvalidArguments {
                    message: format!("Unknown operation: {}", operation),
                })
            }
        };
        // Overflow yields ±inf. serde_json cannot represent non-finite floats and
        // would encode the result as `null`, silently reporting success.
        if !result.is_finite() {
            return Err(ToolError::Execution {
                message: format!("Result of '{}' is not a finite number", operation),
            });
        }
        Ok(ToolResult {
            output: serde_json::json!({
                "operation": operation,
                "operands": [a, b],
                "result": result
            }),
            metadata: HashMap::new(),
        })
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn description(&self) -> &str {
        "A calculator tool for basic arithmetic operations"
    }

    /// JSON Schema: required string `operation` (one of four values) and
    /// required numbers `a` and `b`.
    fn argument_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "operation": {
                    "type": "string",
                    "enum": ["add", "subtract", "multiply", "divide"],
                    "description": "The arithmetic operation to perform"
                },
                "a": {
                    "type": "number",
                    "description": "First operand"
                },
                "b": {
                    "type": "number",
                    "description": "Second operand"
                }
            },
            "required": ["operation", "a", "b"]
        })
    }

    /// Checks that `arguments` is an object with a supported string
    /// `operation` and numeric `a` and `b`.
    ///
    /// # Errors
    /// [`ToolError::InvalidArguments`] describing the first problem found.
    fn validate_arguments(&self, arguments: &serde_json::Value) -> Result<(), ToolError> {
        if !arguments.is_object() {
            return Err(ToolError::InvalidArguments {
                message: "Arguments must be an object".to_string(),
            });
        }
        for field in ["operation", "a", "b"] {
            if arguments.get(field).is_none() {
                return Err(ToolError::InvalidArguments {
                    message: format!("Missing required field: {}", field),
                });
            }
        }
        let op = arguments
            .get("operation")
            .and_then(|v| v.as_str())
            .ok_or_else(|| ToolError::InvalidArguments {
                message: "Field 'operation' must be a string".to_string(),
            })?;
        if !["add", "subtract", "multiply", "divide"].contains(&op) {
            return Err(ToolError::InvalidArguments {
                message: format!("Invalid operation: {}", op),
            });
        }
        // Mirror the `as_f64` extraction used by `execute`, so anything that
        // passes here is accepted there too.
        for field in ["a", "b"] {
            if arguments.get(field).and_then(|v| v.as_f64()).is_none() {
                return Err(ToolError::InvalidArguments {
                    message: format!("Field '{}' must be a number", field),
                });
            }
        }
        Ok(())
    }
}
/// Name-indexed collection of boxed [`Tool`] trait objects.
pub struct ToolRegistry {
    /// Tools keyed by the string their [`Tool::name`] returned at registration.
    tools: HashMap<String, Box<dyn Tool>>,
}

impl ToolRegistry {
    /// Creates an empty registry.
    ///
    /// `ToolRegistry::default()` instead starts with the built-in echo and
    /// calculator tools.
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
        }
    }

    /// Adds `tool` under its [`Tool::name`].
    ///
    /// A tool previously registered under the same name is replaced.
    pub fn register(&mut self, tool: Box<dyn Tool>) {
        let name = tool.name().to_string();
        self.tools.insert(name, tool);
    }

    /// Looks up a tool by name, returning `None` if it is not registered.
    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
        self.tools.get(name).map(|t| t.as_ref())
    }

    /// Returns the names of all registered tools, sorted lexicographically.
    pub fn tool_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.tools.keys().cloned().collect();
        // `HashMap` iteration order is randomized per process. Sort so listings
        // (prompts, UIs, snapshot tests) are stable across runs.
        names.sort_unstable();
        names
    }

    /// Validates `arguments` with the named tool's
    /// [`Tool::validate_arguments`], then runs [`Tool::execute`].
    ///
    /// # Errors
    /// - [`ToolError::NotFound`] if no tool is registered as `tool_name`.
    /// - Any error returned by the tool's validation or execution.
    pub async fn execute(
        &self,
        tool_name: &str,
        arguments: &serde_json::Value,
        state: &GraphState,
    ) -> Result<ToolResult, ToolError> {
        let tool = self.get(tool_name).ok_or_else(|| ToolError::NotFound {
            name: tool_name.to_string(),
        })?;
        tool.validate_arguments(arguments)?;
        tool.execute(arguments, state).await
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
let mut registry = Self::new();
registry.register(Box::new(EchoTool::new()));
registry.register(Box::new(CalculatorTool::new()));
registry
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The echo tool returns the supplied message and a timestamp field.
    #[tokio::test]
    async fn test_echo_tool() {
        let echo = EchoTool::new();
        let state = GraphState::new();
        let args = serde_json::json!({ "message": "Hello, World!" });

        let outcome = echo.execute(&args, &state).await.unwrap();

        assert_eq!(outcome.output["echo"], "Hello, World!");
        assert!(outcome.output.get("timestamp").is_some());
    }

    /// Addition succeeds; division by zero is reported as an error.
    #[tokio::test]
    async fn test_calculator_tool() {
        let calc = CalculatorTool::new();
        let state = GraphState::new();

        let add_args = serde_json::json!({ "operation": "add", "a": 5.0, "b": 3.0 });
        let sum = calc.execute(&add_args, &state).await.unwrap();
        assert_eq!(sum.output["result"], 8.0);

        let div_args = serde_json::json!({ "operation": "divide", "a": 5.0, "b": 0.0 });
        let by_zero = calc.execute(&div_args, &state).await;
        assert!(by_zero.is_err());
    }

    /// Registration, lookup, name listing and dispatch through the registry.
    #[tokio::test]
    async fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool::new()));

        assert!(registry.get("echo").is_some());
        assert!(registry.get("nonexistent").is_none());
        assert!(registry.tool_names().iter().any(|n| n == "echo"));

        let state = GraphState::new();
        let args = serde_json::json!({ "message": "Test" });
        let outcome = registry.execute("echo", &args, &state).await.unwrap();
        assert_eq!(outcome.output["echo"], "Test");
    }

    /// Valid input passes; an unknown operation or a missing operand fails.
    #[test]
    fn test_calculator_validation() {
        let calc = CalculatorTool::new();

        let valid = serde_json::json!({ "operation": "add", "a": 1.0, "b": 2.0 });
        assert!(calc.validate_arguments(&valid).is_ok());

        let bad_op = serde_json::json!({ "operation": "invalid", "a": 1.0, "b": 2.0 });
        assert!(calc.validate_arguments(&bad_op).is_err());

        let no_b = serde_json::json!({ "operation": "add", "a": 1.0 });
        assert!(calc.validate_arguments(&no_b).is_err());
    }

    /// Both built-in schemas are objects that declare their properties.
    #[test]
    fn test_tool_schemas() {
        let echo_schema = EchoTool::new().argument_schema();
        assert_eq!(echo_schema["type"], "object");
        assert!(echo_schema["properties"].get("message").is_some());

        let calc_schema = CalculatorTool::new().argument_schema();
        assert_eq!(calc_schema["type"], "object");
        for key in ["operation", "a", "b"] {
            assert!(calc_schema["properties"].get(key).is_some());
        }
    }
}