use std::time::Instant;
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
use crate::tool::{Capability, ToolRegistry};
use crate::tool_error::ToolError;
use crate::tool_result::ToolResult;
pub struct ToolExecutor {
registry: ToolRegistry,
audit_enabled: bool,
max_parallel: usize,
allowed_capabilities: Vec<Capability>,
}
impl ToolExecutor {
pub fn new(registry: ToolRegistry) -> Self {
Self {
registry,
audit_enabled: true,
max_parallel: 0, allowed_capabilities: vec![
Capability::PureComputation,
Capability::Network,
Capability::FileSystem,
Capability::Subprocess,
Capability::Environment,
Capability::Cryptography,
],
}
}
pub fn without_audit(registry: ToolRegistry) -> Self {
Self {
registry,
audit_enabled: false,
max_parallel: 0,
allowed_capabilities: vec![
Capability::PureComputation,
Capability::Network,
Capability::FileSystem,
Capability::Subprocess,
Capability::Environment,
Capability::Cryptography,
],
}
}
pub fn with_max_parallel(mut self, max: usize) -> Self {
self.max_parallel = max;
self
}
pub fn with_allowed_capabilities(mut self, caps: Vec<Capability>) -> Self {
self.allowed_capabilities = caps;
self
}
pub fn register_wasm_tool(
&mut self,
definition: crate::tool::ToolDefinition,
module_bytes: Vec<u8>,
capabilities: Vec<Capability>,
) {
use crate::wasm_tool::WasmTool;
use std::sync::Arc;
let tool = WasmTool::new(definition, module_bytes, capabilities);
self.registry.register(Arc::new(tool));
}
pub fn register_wasm_tool_from_file(
&mut self,
definition: crate::tool::ToolDefinition,
path: impl AsRef<std::path::Path>,
capabilities: Vec<Capability>,
) -> Result<(), std::io::Error> {
let bytes = std::fs::read(path)?;
self.register_wasm_tool(definition, bytes, capabilities);
Ok(())
}
pub async fn execute(
&self,
tool_name: &str,
args: serde_json::Value,
) -> Result<ToolResult, ToolError> {
let tool = self.registry.get(tool_name).ok_or_else(|| {
warn!(tool = tool_name, "Tool not found");
ToolError::not_found(tool_name)
})?;
for cap in tool.capabilities() {
if !self.allowed_capabilities.contains(&cap) {
warn!(
tool = tool_name,
capability = ?cap,
"Tool requires missing capability"
);
return Err(ToolError::unavailable(
tool_name,
format!("Sandbox violation: tool requires {:?}", cap),
));
}
}
if !tool.is_available() {
warn!(tool = tool_name, "Tool is unavailable");
return Err(ToolError::unavailable(
tool_name,
"Tool is currently disabled",
));
}
debug!(tool = tool_name, "Validating arguments against schema");
let schema_str = tool.definition().parameters;
if !schema_str.is_empty() && schema_str != "{}" {
let schema_json: serde_json::Value = serde_json::from_str(schema_str).map_err(|e| {
ToolError::execution_failed(tool_name, format!("Invalid tool schema: {}", e))
})?;
let compiled = jsonschema::JSONSchema::compile(&schema_json).map_err(|e| {
ToolError::execution_failed(
tool_name,
format!("Failed to compile tool schema: {}", e),
)
})?;
if !compiled.is_valid(&args) {
warn!(tool = tool_name, "Schema validation failed");
return Err(ToolError::invalid_args(
tool_name,
"Arguments do not match tool schema",
));
}
}
debug!(tool = tool_name, "Running custom validation");
tool.validate(&args)?;
let tool_timeout = tool.timeout();
let start = Instant::now();
debug!(
tool = tool_name,
timeout_ms = tool_timeout.as_millis(),
"Executing tool"
);
let output = timeout(tool_timeout, tool.execute(args.clone()))
.await
.map_err(|_| {
error!(
tool = tool_name,
timeout_ms = tool_timeout.as_millis(),
"Tool execution timed out"
);
ToolError::timeout(tool_name, tool_timeout.as_millis() as u64)
})??;
let elapsed = start.elapsed();
let result = ToolResult::new(tool_name, &args, output, elapsed);
info!(
tool = tool_name,
execution_ms = elapsed.as_millis(),
hash = %result.hash,
"Tool executed successfully"
);
if self.audit_enabled {
debug!(
tool = tool_name,
result_hash = %result.hash,
"Audit entry created"
);
}
Ok(result)
}
pub async fn execute_parallel(
&self,
calls: Vec<(String, serde_json::Value)>,
) -> Vec<Result<ToolResult, ToolError>> {
debug!(count = calls.len(), "Executing tools in parallel");
if self.max_parallel > 0 {
use futures::stream::{self, StreamExt};
stream::iter(calls)
.map(|(name, args)| async move { self.execute(&name, args).await })
.buffered(self.max_parallel)
.collect()
.await
} else {
let futures: Vec<_> = calls
.into_iter()
.map(|(name, args)| {
async move { self.execute(&name, args).await }
})
.collect();
futures::future::join_all(futures).await
}
}
pub fn registry(&self) -> &ToolRegistry {
&self.registry
}
pub fn registry_mut(&mut self) -> &mut ToolRegistry {
&mut self.registry
}
pub fn has_tool(&self, name: &str) -> bool {
self.registry.contains(name)
}
pub fn tool_names(&self) -> Vec<&str> {
self.registry.names()
}
}
/// Debug view of the executor's configuration. The registry is rendered as
/// its list of tool names rather than the tool objects themselves.
impl std::fmt::Debug for ToolExecutor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ToolExecutor")
            .field("tools", &self.registry.names())
            .field("audit_enabled", &self.audit_enabled)
            .field("max_parallel", &self.max_parallel)
            // The sandbox policy is part of the executor's configuration, so
            // include it alongside the other settings.
            .field("allowed_capabilities", &self.allowed_capabilities)
            .finish()
    }
}
// Unit tests for `ToolExecutor`, driven by three in-file mock tools:
// `EchoTool` (succeeds), `FailTool` (always errors) and `SlowTool`
// (runs longer than its own declared timeout).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::{Tool, ToolDefinition};
    use async_trait::async_trait;
    use std::sync::Arc;
    use std::time::Duration;

    /// Mock tool "echo": returns its arguments wrapped as `{"echo": args}`.
    struct EchoTool {
        definition: ToolDefinition,
    }

    impl EchoTool {
        fn new() -> Self {
            Self {
                // Schema is neither empty nor "{}", so `execute` performs
                // schema validation for this tool.
                definition: ToolDefinition::new(
                    "echo",
                    "Echo back the input",
                    r#"{"type": "object"}"#,
                ),
            }
        }
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(serde_json::json!({ "echo": args }))
        }
    }

    /// Mock tool "fail": its `execute` always returns
    /// `ToolError::execution_failed`.
    struct FailTool {
        definition: ToolDefinition,
    }

    impl FailTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("fail", "Always fails", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for FailTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Err(ToolError::execution_failed("fail", "Intentional failure"))
        }
    }

    /// Mock tool "slow": declares a 50 ms timeout but sleeps for 10 s, so the
    /// executor's timeout wrapper should fire first.
    struct SlowTool {
        definition: ToolDefinition,
    }

    impl SlowTool {
        fn new() -> Self {
            Self {
                definition: ToolDefinition::new("slow", "Takes forever", r#"{"type": "object"}"#),
            }
        }
    }

    #[async_trait]
    impl Tool for SlowTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        // Short override so the timeout test finishes quickly.
        fn timeout(&self) -> Duration {
            Duration::from_millis(50)
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(serde_json::json!({"done": true}))
        }
    }

    // Happy path: the result carries the tool name, the echoed arguments and a
    // non-empty hash.
    #[tokio::test]
    async fn test_execute_success() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        let executor = ToolExecutor::new(registry);
        let result = executor
            .execute("echo", serde_json::json!({"message": "hello"}))
            .await
            .unwrap();
        assert_eq!(result.tool_name, "echo");
        assert!(result.output["echo"]["message"] == "hello");
        assert!(!result.hash.to_string().is_empty());
    }

    // An unregistered tool name is reported as `ToolError::NotFound`.
    #[tokio::test]
    async fn test_execute_not_found() {
        let registry = ToolRegistry::new();
        let executor = ToolExecutor::new(registry);
        let result = executor.execute("nonexistent", serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::NotFound { .. })));
    }

    // An error returned by the tool itself propagates out of `execute`.
    #[tokio::test]
    async fn test_execute_failure() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(FailTool::new()));
        let executor = ToolExecutor::new(registry);
        let result = executor.execute("fail", serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::ExecutionFailed { .. })));
    }

    // A tool exceeding its own `timeout()` yields `ToolError::Timeout`
    // (takes ~50 ms of wall-clock time).
    #[tokio::test]
    async fn test_execute_timeout() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(SlowTool::new()));
        let executor = ToolExecutor::new(registry);
        let result = executor.execute("slow", serde_json::json!({})).await;
        assert!(matches!(result, Err(ToolError::Timeout { .. })));
    }

    // With the default `max_parallel` of 0 (the `join_all` path), three
    // concurrent calls all succeed and one result is returned per call.
    #[tokio::test]
    async fn test_execute_parallel() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        let executor = ToolExecutor::new(registry);
        let calls = vec![
            ("echo".to_string(), serde_json::json!({"n": 1})),
            ("echo".to_string(), serde_json::json!({"n": 2})),
            ("echo".to_string(), serde_json::json!({"n": 3})),
        ];
        let results = executor.execute_parallel(calls).await;
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    // `has_tool` reflects registry membership.
    // NOTE(review): this test awaits nothing; a plain `#[test]` would suffice.
    #[tokio::test]
    async fn test_has_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool::new()));
        let executor = ToolExecutor::new(registry);
        assert!(executor.has_tool("echo"));
        assert!(!executor.has_tool("nonexistent"));
    }
}