use crate::error::{Result, ToolError};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// A capability that can be looked up by name and invoked with JSON parameters.
///
/// The `Send + Sync` supertraits allow `dyn Tool` objects to be shared across threads.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Identifier of the tool.
    ///
    /// `ToolExecutor` registers the tool under this name and also reports it as the
    /// function name in the tool definitions it builds.
    fn name(&self) -> &str;

    /// Human-readable description, forwarded as the function description in tool
    /// definitions.
    fn description(&self) -> &str;

    /// Description of the accepted parameters, forwarded unchanged as the function
    /// `parameters` in tool definitions.
    ///
    /// NOTE(review): presumably a JSON Schema object — confirm against the LLM client.
    fn parameters_schema(&self) -> serde_json::Value;

    /// Runs the tool for `call`.
    ///
    /// # Errors
    ///
    /// Implementation-defined. When invoked through `ToolExecutor::execute`, an `Err` is
    /// converted into a failed [`ToolResult`] instead of being propagated.
    async fn execute(&self, call: ToolCall) -> Result<ToolResult>;

    /// Whether a confirmation step is wanted before running this tool. Defaults to `false`.
    ///
    /// NOTE(review): nothing in this module consults the flag; presumably callers check it
    /// before calling `execute` — confirm.
    fn requires_confirmation(&self) -> bool {
        false
    }

    /// Illustrative invocations of the tool. Defaults to none.
    fn examples(&self) -> Vec<ToolExample> {
        vec![]
    }
}
/// A request to run the tool named `name` with JSON `parameters`.
///
/// `PartialEq` is derived (every field supports it) so calls can be compared directly,
/// e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call; [`ToolCall::new`] fills it with a random UUID v4 string.
    pub id: String,
    /// Name of the tool to invoke; `ToolExecutor::execute` looks the tool up by this name.
    pub name: String,
    /// Tool arguments. [`ToolCall::get_parameter`] reads entries by key, so this is
    /// expected to be a JSON object.
    pub parameters: serde_json::Value,
    /// Optional free-form extra data attached to the call.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// Outcome of running a tool for one [`ToolCall`].
///
/// `PartialEq` is derived (every field supports it) so results can be compared directly,
/// e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolResult {
    /// Identifier of the call this result answers.
    pub tool_call_id: String,
    /// `false` for results built with [`ToolResult::error`], `true` for
    /// [`ToolResult::success`].
    pub success: bool,
    /// Textual output; error results carry the message prefixed with `"Error: "`.
    pub content: String,
    /// Optional structured payload (see [`ToolResult::with_data`]).
    pub data: Option<serde_json::Value>,
    /// Execution time in milliseconds; filled in by `ToolExecutor::execute`.
    pub duration_ms: Option<u64>,
    /// Optional free-form extra data attached to the result.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// An illustrative invocation of a tool, as returned by `Tool::examples`.
///
/// `PartialEq` is derived (every field supports it) so examples can be compared directly.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolExample {
    /// What the example demonstrates.
    pub description: String,
    /// Example arguments.
    ///
    /// NOTE(review): presumably shaped like `ToolCall::parameters` — confirm.
    pub parameters: serde_json::Value,
    /// Textual description of the result the example should produce.
    pub expected_result: String,
}
/// Registry of [`Tool`]s keyed by name that dispatches [`ToolCall`]s to them.
pub struct ToolExecutor {
    /// Registered tools, keyed by the value `Tool::name` returned at registration time.
    tools: HashMap<String, Box<dyn Tool>>,
}
impl ToolCall {
    /// Creates a call to the tool named `name` with the given JSON `parameters`.
    ///
    /// The call gets a fresh random (UUID v4) `id` and no metadata.
    pub fn new<S: Into<String>>(name: S, parameters: serde_json::Value) -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
            name: name.into(),
            parameters,
            metadata: None,
        }
    }

    /// Deserializes the parameter stored under `key` into `T`.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::InvalidParameters` (converted into the crate error type) when:
    /// - `key` is absent, including when `parameters` is not a JSON object
    ///   (message `"Missing parameter: <key>"`);
    /// - the value cannot be deserialized into `T` (message
    ///   `"Invalid parameter type for: <key>: <serde error>"`).
    pub fn get_parameter<T>(&self, key: &str) -> Result<T>
    where
        T: for<'de> Deserialize<'de>,
    {
        let value = self
            .parameters
            .get(key)
            .ok_or_else(|| ToolError::InvalidParameters {
                message: format!("Missing parameter: {}", key),
            })?;
        // `&serde_json::Value` is itself a serde `Deserializer`, so the value does not
        // need to be deep-cloned first. The serde error is kept in the message so callers
        // can see *why* the type did not match.
        T::deserialize(value).map_err(|err| {
            ToolError::InvalidParameters {
                message: format!("Invalid parameter type for: {}: {}", key, err),
            }
            .into()
        })
    }

    /// Like [`ToolCall::get_parameter`], but returns `default` when the parameter is
    /// missing or cannot be deserialized into `T`.
    pub fn get_parameter_or<T>(&self, key: &str, default: T) -> T
    where
        T: for<'de> Deserialize<'de>,
    {
        // Go straight to `Option` rather than via `get_parameter`, so no error message
        // is built only to be discarded.
        self.parameters
            .get(key)
            .and_then(|value| T::deserialize(value).ok())
            .unwrap_or(default)
    }
}
impl ToolResult {
    /// Builds a successful result for the call `tool_call_id` with `content` as its text.
    ///
    /// The two arguments may be different string-like types (e.g. a `&str` id with a
    /// `String` body produced by `format!`).
    pub fn success<I, C>(tool_call_id: I, content: C) -> Self
    where
        I: Into<String>,
        C: Into<String>,
    {
        Self {
            tool_call_id: tool_call_id.into(),
            success: true,
            content: content.into(),
            data: None,
            duration_ms: None,
            metadata: None,
        }
    }

    /// Builds a failed result for the call `tool_call_id`.
    ///
    /// `content` is set to `error` prefixed with `"Error: "`. As with
    /// [`ToolResult::success`], the two arguments may be different string-like types.
    pub fn error<I, E>(tool_call_id: I, error: E) -> Self
    where
        I: Into<String>,
        E: Into<String>,
    {
        let tool_call_id: String = tool_call_id.into();
        let message: String = error.into();
        Self {
            tool_call_id,
            success: false,
            content: format!("Error: {}", message),
            data: None,
            duration_ms: None,
            metadata: None,
        }
    }

    /// Attaches a structured payload, replacing any previous one.
    #[must_use]
    pub fn with_data(mut self, data: serde_json::Value) -> Self {
        self.data = Some(data);
        self
    }

    /// Records the execution time in milliseconds, replacing any previous value.
    #[must_use]
    pub fn with_duration(mut self, duration_ms: u64) -> Self {
        self.duration_ms = Some(duration_ms);
        self
    }

    /// Attaches free-form metadata, replacing any previous map.
    #[must_use]
    pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
        self.metadata = Some(metadata);
        self
    }
}
impl ToolExecutor {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
}
}
pub fn register_tool(&mut self, tool: Box<dyn Tool>) {
self.tools.insert(tool.name().to_string(), tool);
}
pub fn get_tool(&self, name: &str) -> Option<&dyn Tool> {
self.tools.get(name).map(|t| t.as_ref())
}
pub fn list_tools(&self) -> Vec<&str> {
self.tools.keys().map(|s| s.as_str()).collect()
}
pub async fn execute(&self, call: ToolCall) -> Result<ToolResult> {
let tool = self.get_tool(&call.name)
.ok_or_else(|| ToolError::NotFound {
name: call.name.clone(),
})?;
let start_time = std::time::Instant::now();
let call_id = call.id.clone();
let result = tool.execute(call).await;
let duration = start_time.elapsed().as_millis() as u64;
match result {
Ok(mut result) => {
result.duration_ms = Some(duration);
Ok(result)
}
Err(e) => Ok(ToolResult::error(&call_id, &e.to_string()).with_duration(duration)),
}
}
pub fn get_tool_definitions(&self) -> Vec<crate::llm::ToolDefinition> {
self.tools
.values()
.map(|tool| crate::llm::ToolDefinition {
tool_type: "function".to_string(),
function: crate::llm::FunctionDefinition {
name: tool.name().to_string(),
description: tool.description().to_string(),
parameters: tool.parameters_schema(),
},
})
.collect()
}
}
impl Default for ToolExecutor {
fn default() -> Self {
Self::new()
}
}