use std::time::Instant;
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
use crate::tool::ToolRegistry;
use crate::tool_error::ToolError;
use crate::tool_result::ToolResult;
/// Dispatches tool calls to tools held in a [`ToolRegistry`], checking
/// availability, validating arguments and enforcing each tool's timeout.
pub struct ToolExecutor {
    /// Registry the executor resolves tool names against.
    registry: ToolRegistry,
    /// When `true`, a debug-level audit record is logged after every
    /// successful execution.
    audit_enabled: bool,
    /// Concurrency cap requested via `with_max_parallel`; the constructors
    /// default it to `0`.
    max_parallel: usize,
}
impl ToolExecutor {
pub fn new(registry: ToolRegistry) -> Self {
Self {
registry,
audit_enabled: true,
max_parallel: 0, }
}
pub fn without_audit(registry: ToolRegistry) -> Self {
Self {
registry,
audit_enabled: false,
max_parallel: 0,
}
}
pub fn with_max_parallel(mut self, max: usize) -> Self {
self.max_parallel = max;
self
}
pub async fn execute(
&self,
tool_name: &str,
args: serde_json::Value,
) -> Result<ToolResult, ToolError> {
let tool = self.registry.get(tool_name).ok_or_else(|| {
warn!(tool = tool_name, "Tool not found");
ToolError::not_found(tool_name)
})?;
if !tool.is_available() {
warn!(tool = tool_name, "Tool is unavailable");
return Err(ToolError::unavailable(
tool_name,
"Tool is currently disabled",
));
}
debug!(tool = tool_name, "Validating arguments");
tool.validate(&args)?;
let tool_timeout = tool.timeout();
let start = Instant::now();
debug!(
tool = tool_name,
timeout_ms = tool_timeout.as_millis(),
"Executing tool"
);
let output = timeout(tool_timeout, tool.execute(args.clone()))
.await
.map_err(|_| {
error!(
tool = tool_name,
timeout_ms = tool_timeout.as_millis(),
"Tool execution timed out"
);
ToolError::timeout(tool_name, tool_timeout.as_millis() as u64)
})??;
let elapsed = start.elapsed();
let result = ToolResult::new(tool_name, &args, output, elapsed);
info!(
tool = tool_name,
execution_ms = elapsed.as_millis(),
hash = %result.hash,
"Tool executed successfully"
);
if self.audit_enabled {
debug!(
tool = tool_name,
result_hash = %result.hash,
"Audit entry created"
);
}
Ok(result)
}
pub async fn execute_parallel(
&self,
calls: Vec<(String, serde_json::Value)>,
) -> Vec<Result<ToolResult, ToolError>> {
debug!(count = calls.len(), "Executing tools in parallel");
let futures: Vec<_> = calls
.into_iter()
.map(|(name, args)| {
async move { self.execute(&name, args).await }
})
.collect();
futures::future::join_all(futures).await
}
pub fn registry(&self) -> &ToolRegistry {
&self.registry
}
pub fn registry_mut(&mut self) -> &mut ToolRegistry {
&mut self.registry
}
pub fn has_tool(&self, name: &str) -> bool {
self.registry.contains(name)
}
pub fn tool_names(&self) -> Vec<&str> {
self.registry.names()
}
}
/// Debug output lists the registered tool names rather than the registry
/// itself, followed by the executor's configuration.
impl std::fmt::Debug for ToolExecutor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tools = self.registry.names();
        let mut out = f.debug_struct("ToolExecutor");
        out.field("tools", &tools);
        out.field("audit_enabled", &self.audit_enabled);
        out.field("max_parallel", &self.max_parallel);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::{Tool, ToolDefinition};
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::time::Duration;

    /// Succeeds immediately, wrapping its arguments as `{"echo": args}`.
    struct EchoTool {
        definition: ToolDefinition,
    }

    impl EchoTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new(
                    "echo",
                    "Echo back the input",
                    r#"{"type": "object"}"#,
                ),
            }
        }
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(serde_json::json!({ "echo": args }))
        }
    }

    /// Always fails with an `ExecutionFailed` error.
    struct FailTool {
        definition: ToolDefinition,
    }

    impl FailTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("fail", "Always fails", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for FailTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Err(ToolError::execution_failed("fail", "Intentional failure"))
        }
    }

    /// Sleeps far longer than its 50 ms timeout, to exercise the timeout path.
    struct SlowTool {
        definition: ToolDefinition,
    }

    impl SlowTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("slow", "Takes forever", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for SlowTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        fn timeout(&self) -> Duration {
            Duration::from_millis(50)
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(serde_json::json!({"done": true}))
        }
    }

    #[tokio::test]
    async fn test_execute_success() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        let executor = ToolExecutor::new(registry);

        let result = executor
            .execute("echo", serde_json::json!({"message": "hello"}))
            .await
            .unwrap();

        assert_eq!(result.tool_name, "echo");
        assert_eq!(result.output["echo"]["message"], "hello");
        assert!(!result.hash.to_string().is_empty());
    }

    #[tokio::test]
    async fn test_execute_not_found() {
        let executor = ToolExecutor::new(ToolRegistry::new());
        let result = executor.execute("nonexistent", serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_execute_failure() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(FailTool::new()));
        let executor = ToolExecutor::new(registry);

        let result = executor.execute("fail", serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    #[tokio::test]
    async fn test_execute_timeout() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(SlowTool::new()));
        let executor = ToolExecutor::new(registry);

        let result = executor.execute("slow", serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::Timeout { .. })));
    }

    #[tokio::test]
    async fn test_without_audit_still_executes() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        let executor = ToolExecutor::without_audit(registry);

        let result = executor.execute("echo", serde_json::json!({})).await;
        assert!(result.is_ok());
        assert!(format!("{:?}", executor).contains("audit_enabled: false"));
    }

    #[tokio::test]
    async fn test_execute_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        let executor = ToolExecutor::new(registry);
        let calls = vec![
            ("echo".to_string(), serde_json::json!({"n": 1})),
            ("echo".to_string(), serde_json::json!({"n": 2})),
            ("echo".to_string(), serde_json::json!({"n": 3})),
        ];

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_execute_parallel_preserves_order_and_isolates_failures() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        registry.register(Arc::new(FailTool::new()));
        let executor = ToolExecutor::new(registry);
        let calls = vec![
            ("echo".to_string(), serde_json::json!({"n": 1})),
            ("fail".to_string(), serde_json::json!({})),
            ("missing".to_string(), serde_json::json!({})),
            ("echo".to_string(), serde_json::json!({"n": 2})),
        ];

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 4);
        assert_eq!(
            results[0].as_ref().unwrap().output["echo"]["n"],
            serde_json::json!(1)
        );
        assert!(matches!(results[1], Err(ToolError::ExecutionFailed { .. })));
        assert!(matches!(results[2], Err(ToolError::NotFound { .. })));
        assert_eq!(
            results[3].as_ref().unwrap().output["echo"]["n"],
            serde_json::json!(2)
        );
    }

    #[tokio::test]
    async fn test_execute_parallel_with_max_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        let executor = ToolExecutor::new(registry).with_max_parallel(2);
        let calls: Vec<_> = (1..=5)
            .map(|n| ("echo".to_string(), serde_json::json!({ "n": n })))
            .collect();

        let results = executor.execute_parallel(calls).await;

        assert_eq!(results.len(), 5);
        for (expected, result) in (1..=5).zip(&results) {
            let result = result.as_ref().expect("echo should succeed");
            assert_eq!(result.output["echo"]["n"], serde_json::json!(expected));
        }
    }

    #[tokio::test]
    async fn test_has_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        let executor = ToolExecutor::new(registry);

        assert!(executor.has_tool("echo"));
        assert!(!executor.has_tool("nonexistent"));
    }

    #[test]
    fn test_debug_reports_configuration() {
        let executor = ToolExecutor::new(ToolRegistry::new()).with_max_parallel(3);
        let rendered = format!("{:?}", executor);

        assert!(rendered.contains("audit_enabled: true"));
        assert!(rendered.contains("max_parallel: 3"));
    }
}