use crate::error::{AgentError, Result};
use crate::types::ToolId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::timeout;
use tracing::{debug, info, trace, warn};
/// Declarative description of one parameter accepted by a [`Tool`].
///
/// Used by [`Tool::validate_parameters`] (required/type checks) and
/// [`Tool::apply_defaults`] (default filling).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterSchema {
/// Expected JSON type name checked by `validate_type`: one of `"string"`,
/// `"number"`, `"boolean"`, `"object"`, `"array"`, `"null"`. Any other name is
/// not enforced (every value is accepted).
pub param_type: String,
/// When `true`, validation fails if the parameter is absent.
pub required: bool,
/// Human-readable description of the parameter.
pub description: String,
/// Value inserted by `Tool::apply_defaults` when the parameter is absent.
pub default: Option<serde_json::Value>,
}
/// Outcome of a single [`Tool::execute`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
/// JSON output produced by the tool.
pub output: serde_json::Value,
/// Tool-reported failure message. `ToolRegistry::execute_with_retry` treats
/// `Some` as a failed attempt and retries it.
pub error: Option<String>,
/// Extra key/value information attached by the tool; contents are tool-defined.
pub metadata: HashMap<String, serde_json::Value>,
}
/// A capability an agent can invoke with a map of JSON parameters.
///
/// Implementors describe their inputs through [`ParameterSchema`] entries. The
/// provided [`Tool::validate_parameters`] and [`Tool::apply_defaults`] methods use
/// that schema so callers (e.g. [`ToolRegistry::execute`]) can check a call before
/// dispatching it.
#[async_trait::async_trait]
pub trait Tool: Send + Sync {
    /// Identifier under which the registry stores this tool.
    fn id(&self) -> &ToolId;

    /// Name used for lookup via `ToolRegistry::get_by_name`; the registry rejects
    /// a second tool with the same name.
    fn name(&self) -> &str;

    /// Human-readable description of what the tool does.
    fn description(&self) -> &str;

    /// Schema of accepted parameters, keyed by parameter name.
    fn parameters(&self) -> &HashMap<String, ParameterSchema>;

    /// Runs the tool with the given parameters.
    ///
    /// When invoked through `ToolRegistry::execute`, defaults have already been
    /// applied and the parameters validated against [`Tool::parameters`].
    async fn execute(&self, parameters: HashMap<String, serde_json::Value>) -> Result<ToolResult>;

    /// Checks `parameters` against [`Tool::parameters`].
    ///
    /// Parameters that are not in the schema are accepted without checks.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::InvalidToolParameters`] if any required parameter is
    /// missing (all missing names are reported, sorted) or if a supplied
    /// parameter's JSON type does not match its schema.
    fn validate_parameters(&self, parameters: &HashMap<String, serde_json::Value>) -> Result<()> {
        trace!(tool_name = %self.name(), "Validating tool parameters");
        let schema = self.parameters();

        // Gather every missing required parameter and sort the names, so the
        // error does not depend on HashMap iteration order.
        let mut missing: Vec<&str> = schema
            .iter()
            .filter(|(name, spec)| spec.required && !parameters.contains_key(*name))
            .map(|(name, _)| name.as_str())
            .collect();
        if !missing.is_empty() {
            missing.sort_unstable();
            let missing = missing.join(", ");
            warn!(
                tool_name = %self.name(),
                param_name = %missing,
                "Missing required parameter"
            );
            return Err(AgentError::InvalidToolParameters {
                tool_name: self.name().to_string(),
                reason: format!("Missing required parameter: {}", missing),
            });
        }

        for (param_name, value) in parameters {
            if let Some(param_schema) = schema.get(param_name) {
                if !validate_type(value, &param_schema.param_type) {
                    warn!(
                        tool_name = %self.name(),
                        param_name = %param_name,
                        expected_type = %param_schema.param_type,
                        "Parameter type mismatch"
                    );
                    return Err(AgentError::InvalidToolParameters {
                        tool_name: self.name().to_string(),
                        reason: format!(
                            "Parameter '{}' has wrong type, expected {}",
                            param_name, param_schema.param_type
                        ),
                    });
                }
            }
        }

        debug!(
            tool_name = %self.name(),
            param_count = parameters.len(),
            "Parameter validation successful"
        );
        Ok(())
    }

    /// Inserts the schema default for every parameter that is absent and has one.
    ///
    /// Values already present in `parameters` are never overwritten.
    fn apply_defaults(&self, parameters: &mut HashMap<String, serde_json::Value>) {
        let schema = self.parameters();
        for (param_name, param_schema) in schema {
            if parameters.contains_key(param_name) {
                continue;
            }
            if let Some(default_value) = &param_schema.default {
                trace!(
                    tool_name = %self.name(),
                    param_name = %param_name,
                    "Applying default parameter value"
                );
                parameters.insert(param_name.clone(), default_value.clone());
            }
        }
    }
}
fn validate_type(value: &serde_json::Value, expected_type: &str) -> bool {
use serde_json::Value;
match expected_type {
"string" => matches!(value, Value::String(_)),
"number" => matches!(value, Value::Number(_)),
"boolean" => matches!(value, Value::Bool(_)),
"object" => matches!(value, Value::Object(_)),
"array" => matches!(value, Value::Array(_)),
"null" => matches!(value, Value::Null),
_ => true, }
}
/// Registry of [`Tool`]s, addressable by [`ToolId`] or by name.
///
/// Both indexes are guarded by separate tokio `RwLock`s. `register` and
/// `unregister` lock `tools` before `tools_by_name`; code that holds both guards
/// at once must use that same order to avoid deadlock.
pub struct ToolRegistry {
/// Primary storage: tool id -> shared tool instance.
tools: Arc<RwLock<HashMap<ToolId, Arc<dyn Tool>>>>,
/// Secondary index: tool name -> tool id (`register` rejects duplicate names).
tools_by_name: Arc<RwLock<HashMap<String, ToolId>>>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        info!("Creating new tool registry");
        Self {
            tools: Arc::new(RwLock::new(HashMap::new())),
            tools_by_name: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers `tool` under its id and name and returns the id.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ToolAlreadyRegistered`] if a tool with the same name
    /// or the same id is already registered; the registry is left unchanged.
    pub async fn register(&self, tool: Box<dyn Tool>) -> Result<ToolId> {
        let tool_id = *tool.id();
        let tool_name = tool.name().to_string();
        info!(
            tool_id = %tool_id,
            tool_name = %tool_name,
            "Registering tool"
        );
        // Lock order: `tools` before `tools_by_name`. Any method holding both
        // guards at once must follow the same order to avoid deadlock.
        let mut tools = self.tools.write().await;
        let mut tools_by_name = self.tools_by_name.write().await;
        if tools_by_name.contains_key(&tool_name) {
            warn!(
                tool_name = %tool_name,
                "Attempted to register duplicate tool"
            );
            return Err(AgentError::ToolAlreadyRegistered(tool_name));
        }
        if tools.contains_key(&tool_id) {
            // Inserting would silently replace the existing tool while its old
            // name still mapped to this id, corrupting the name index.
            warn!(
                tool_id = %tool_id,
                tool_name = %tool_name,
                "Attempted to register tool with duplicate id"
            );
            return Err(AgentError::ToolAlreadyRegistered(tool_name));
        }
        tools.insert(tool_id, Arc::from(tool));
        tools_by_name.insert(tool_name.clone(), tool_id);
        debug!(
            tool_id = %tool_id,
            tool_name = %tool_name,
            total_tools = tools.len(),
            "Tool registered successfully"
        );
        Ok(tool_id)
    }

    /// Removes the tool with `tool_id` from both indexes.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ToolNotFound`] if no such tool is registered.
    pub async fn unregister(&self, tool_id: &ToolId) -> Result<()> {
        info!(tool_id = %tool_id, "Unregistering tool");
        // Same lock order as `register`.
        let mut tools = self.tools.write().await;
        let mut tools_by_name = self.tools_by_name.write().await;
        if let Some(tool) = tools.remove(tool_id) {
            let tool_name = tool.name().to_string();
            tools_by_name.remove(&tool_name);
            debug!(
                tool_id = %tool_id,
                tool_name = %tool_name,
                remaining_tools = tools.len(),
                "Tool unregistered successfully"
            );
            Ok(())
        } else {
            warn!(tool_id = %tool_id, "Attempted to unregister unknown tool");
            Err(AgentError::ToolNotFound(*tool_id))
        }
    }

    /// Returns the tool registered under `tool_id`, if any.
    pub async fn get(&self, tool_id: &ToolId) -> Option<Arc<dyn Tool>> {
        let tools = self.tools.read().await;
        tools.get(tool_id).cloned()
    }

    /// Returns the tool registered under `name`, if any.
    pub async fn get_by_name(&self, name: &str) -> Option<Arc<dyn Tool>> {
        // Copy the id out and drop the `tools_by_name` guard before locking
        // `tools`. Holding it while waiting on `tools` inverts the lock order used
        // by `register`/`unregister` and can deadlock against a concurrent writer.
        let tool_id = {
            let tools_by_name = self.tools_by_name.read().await;
            *tools_by_name.get(name)?
        };
        self.get(&tool_id).await
    }

    /// Returns a snapshot of all registered tools, in unspecified order.
    pub async fn list(&self) -> Vec<Arc<dyn Tool>> {
        let tools = self.tools.read().await;
        tools.values().cloned().collect()
    }

    /// Applies schema defaults, validates the parameters, then runs the tool.
    ///
    /// No registry lock is held while the tool executes.
    ///
    /// # Errors
    ///
    /// - [`AgentError::ToolNotFound`] if `tool_id` is not registered.
    /// - [`AgentError::InvalidToolParameters`] if validation fails.
    /// - Any error returned by the tool's own `execute`.
    pub async fn execute(
        &self,
        tool_id: &ToolId,
        mut parameters: HashMap<String, serde_json::Value>,
    ) -> Result<ToolResult> {
        info!(
            tool_id = %tool_id,
            param_count = parameters.len(),
            "Executing tool"
        );
        let tool = self
            .get(tool_id)
            .await
            .ok_or_else(|| AgentError::ToolNotFound(*tool_id))?;
        tool.apply_defaults(&mut parameters);
        tool.validate_parameters(&parameters)?;
        let result = tool.execute(parameters).await?;
        debug!(
            tool_id = %tool_id,
            tool_name = %tool.name(),
            has_error = result.error.is_some(),
            "Tool execution completed"
        );
        Ok(result)
    }

    /// Like [`ToolRegistry::execute`], but gives up after `timeout_duration`.
    ///
    /// On timeout the pending execution future is dropped by `tokio::time::timeout`.
    ///
    /// # Errors
    ///
    /// Everything [`ToolRegistry::execute`] returns, plus
    /// [`AgentError::ToolTimeout`] when the deadline elapses.
    pub async fn execute_with_timeout(
        &self,
        tool_id: &ToolId,
        parameters: HashMap<String, serde_json::Value>,
        timeout_duration: Duration,
    ) -> Result<ToolResult> {
        info!(
            tool_id = %tool_id,
            timeout_secs = timeout_duration.as_secs(),
            "Executing tool with timeout"
        );
        match timeout(timeout_duration, self.execute(tool_id, parameters)).await {
            Ok(result) => result,
            Err(_) => {
                warn!(
                    tool_id = %tool_id,
                    timeout_secs = timeout_duration.as_secs(),
                    "Tool execution timed out"
                );
                Err(AgentError::ToolTimeout {
                    tool_name: self.tool_name_or_unknown(tool_id).await,
                    timeout: timeout_duration,
                })
            }
        }
    }

    /// Runs the tool with a per-attempt timeout, retrying up to `max_retries`
    /// times with exponential backoff.
    ///
    /// An attempt fails if it returns `Err` or a [`ToolResult`] whose `error` is
    /// `Some`. Retry `k` (1-based) waits `base_backoff_ms * 2^(k-1)` ms, saturating
    /// at `u64::MAX` instead of overflowing.
    ///
    /// # Errors
    ///
    /// - [`AgentError::ToolNotFound`] and [`AgentError::InvalidToolParameters`] are
    ///   returned immediately. Retrying with the same id and parameters cannot
    ///   succeed.
    /// - Otherwise, after all attempts fail, returns the last attempt's error.
    ///   A tool-reported error result is wrapped in
    ///   [`AgentError::ToolExecutionFailed`].
    pub async fn execute_with_retry(
        &self,
        tool_id: &ToolId,
        parameters: HashMap<String, serde_json::Value>,
        timeout_duration: Duration,
        max_retries: u32,
        base_backoff_ms: u64,
    ) -> Result<ToolResult> {
        // u64 so `max_retries + 1` cannot overflow when max_retries == u32::MAX.
        let max_attempts = u64::from(max_retries) + 1;
        let mut last_error = None;

        for retry_index in 0..=max_retries {
            let attempt = u64::from(retry_index) + 1;
            info!(
                tool_id = %tool_id,
                attempt = attempt,
                max_attempts = max_attempts,
                "Attempting tool execution"
            );
            match self
                .execute_with_timeout(tool_id, parameters.clone(), timeout_duration)
                .await
            {
                Ok(result) if result.error.is_none() => {
                    debug!(
                        tool_id = %tool_id,
                        attempts = attempt,
                        "Tool execution successful"
                    );
                    return Ok(result);
                }
                Ok(result) => {
                    // The guard above ensures `error` is `Some` here.
                    let reason = result.error.unwrap_or_default();
                    warn!(
                        tool_id = %tool_id,
                        attempt = attempt,
                        error = %reason,
                        "Tool returned error result"
                    );
                    last_error = Some(AgentError::ToolExecutionFailed {
                        tool_name: self.tool_name_or_unknown(tool_id).await,
                        reason,
                    });
                }
                Err(e) => {
                    warn!(
                        tool_id = %tool_id,
                        attempt = attempt,
                        error = %e,
                        "Tool execution failed"
                    );
                    // Lookup and validation failures are deterministic for a fixed
                    // id and parameter set; retrying would only add backoff delay.
                    if matches!(
                        e,
                        AgentError::ToolNotFound(_) | AgentError::InvalidToolParameters { .. }
                    ) {
                        return Err(e);
                    }
                    last_error = Some(e);
                }
            }

            if retry_index < max_retries {
                let backoff_ms = Self::backoff_ms(base_backoff_ms, retry_index);
                debug!(
                    tool_id = %tool_id,
                    backoff_ms = backoff_ms,
                    "Waiting before retry"
                );
                tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
            }
        }

        warn!(
            tool_id = %tool_id,
            total_attempts = max_attempts,
            "All retry attempts exhausted"
        );
        Err(
            last_error.unwrap_or_else(|| AgentError::ToolExecutionFailed {
                tool_name: "unknown".to_string(),
                reason: "All retry attempts failed".to_string(),
            }),
        )
    }

    /// Backoff delay before the retry that follows attempt `retry_index` (0-based):
    /// `base_ms * 2^retry_index`, saturating instead of overflowing or panicking.
    fn backoff_ms(base_ms: u64, retry_index: u32) -> u64 {
        base_ms.saturating_mul(2u64.saturating_pow(retry_index))
    }

    /// Name of the tool registered under `tool_id`, or `"unknown"` if it is not
    /// (or no longer) registered.
    async fn tool_name_or_unknown(&self, tool_id: &ToolId) -> String {
        self.get(tool_id)
            .await
            .map(|tool| tool.name().to_string())
            .unwrap_or_else(|| "unknown".to_string())
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool with one required string parameter `message`; echoes its parameters.
    struct TestTool {
        id: ToolId,
        parameters: HashMap<String, ParameterSchema>,
    }

    impl TestTool {
        fn new() -> Self {
            let mut parameters = HashMap::new();
            parameters.insert(
                "message".to_string(),
                ParameterSchema {
                    param_type: "string".to_string(),
                    required: true,
                    description: "A test message".to_string(),
                    default: None,
                },
            );
            Self {
                id: ToolId::new(),
                parameters,
            }
        }
    }

    #[async_trait::async_trait]
    impl Tool for TestTool {
        fn id(&self) -> &ToolId {
            &self.id
        }
        fn name(&self) -> &str {
            "test"
        }
        fn description(&self) -> &str {
            "A test tool"
        }
        fn parameters(&self) -> &HashMap<String, ParameterSchema> {
            &self.parameters
        }
        async fn execute(
            &self,
            parameters: HashMap<String, serde_json::Value>,
        ) -> Result<ToolResult> {
            Ok(ToolResult {
                output: serde_json::to_value(parameters).unwrap(),
                error: None,
                metadata: HashMap::new(),
            })
        }
    }

    #[test]
    fn test_parameter_validation_missing_required() {
        let tool = TestTool::new();
        let params = HashMap::new();
        match tool.validate_parameters(&params) {
            Err(AgentError::InvalidToolParameters { tool_name, reason }) => {
                assert_eq!(tool_name, "test");
                assert!(reason.contains("message"));
            }
            _ => panic!("Expected InvalidToolParameters error"),
        }
    }

    #[test]
    fn test_parameter_validation_type_mismatch() {
        let tool = TestTool::new();
        let mut params = HashMap::new();
        params.insert("message".to_string(), serde_json::json!(123));
        let result = tool.validate_parameters(&params);
        assert!(result.is_err());
    }

    #[test]
    fn test_parameter_validation_success() {
        let tool = TestTool::new();
        let mut params = HashMap::new();
        params.insert("message".to_string(), serde_json::json!("Hello"));
        let result = tool.validate_parameters(&params);
        assert!(result.is_ok());
    }

    #[test]
    fn test_apply_defaults_fills_only_missing() {
        let mut tool = TestTool::new();
        tool.parameters.insert(
            "count".to_string(),
            ParameterSchema {
                param_type: "number".to_string(),
                required: false,
                description: "An optional count".to_string(),
                default: Some(serde_json::json!(1)),
            },
        );

        let mut params = HashMap::new();
        params.insert("message".to_string(), serde_json::json!("hi"));
        tool.apply_defaults(&mut params);
        assert_eq!(params.get("count"), Some(&serde_json::json!(1)));
        // Parameters without a default stay untouched.
        assert_eq!(params.get("message"), Some(&serde_json::json!("hi")));

        // Explicitly supplied values are never overwritten.
        params.insert("count".to_string(), serde_json::json!(5));
        tool.apply_defaults(&mut params);
        assert_eq!(params.get("count"), Some(&serde_json::json!(5)));
    }

    #[test]
    fn test_type_validation() {
        assert!(validate_type(&serde_json::json!("hello"), "string"));
        assert!(validate_type(&serde_json::json!(123), "number"));
        assert!(validate_type(&serde_json::json!(true), "boolean"));
        assert!(validate_type(&serde_json::json!({}), "object"));
        assert!(validate_type(&serde_json::json!([]), "array"));
        assert!(validate_type(&serde_json::json!(null), "null"));
        assert!(!validate_type(&serde_json::json!(123), "string"));
        assert!(!validate_type(&serde_json::json!("hello"), "number"));
        // Unknown type names are not enforced.
        assert!(validate_type(&serde_json::json!(1), "custom"));
    }

    #[tokio::test]
    async fn test_register_duplicate_name_rejected() {
        let registry = ToolRegistry::new();
        registry.register(Box::new(TestTool::new())).await.unwrap();
        let result = registry.register(Box::new(TestTool::new())).await;
        assert!(matches!(
            result,
            Err(AgentError::ToolAlreadyRegistered(name)) if name == "test"
        ));
        assert_eq!(registry.list().await.len(), 1);
    }

    #[tokio::test]
    async fn test_get_by_name_and_unregister() {
        let registry = ToolRegistry::new();
        let tool_id = registry.register(Box::new(TestTool::new())).await.unwrap();

        let found = registry
            .get_by_name("test")
            .await
            .expect("tool should be registered");
        assert_eq!(found.name(), "test");
        assert!(registry.get_by_name("missing").await.is_none());

        registry.unregister(&tool_id).await.unwrap();
        assert!(registry.get_by_name("test").await.is_none());
        assert!(registry.get(&tool_id).await.is_none());
        assert!(matches!(
            registry.unregister(&tool_id).await,
            Err(AgentError::ToolNotFound(_))
        ));
    }

    /// Tool that sleeps for `delay` before succeeding; used for timeout tests.
    struct SlowTool {
        id: ToolId,
        delay: Duration,
        parameters: HashMap<String, ParameterSchema>,
    }

    impl SlowTool {
        fn new_with_delay(delay: Duration) -> Self {
            Self {
                id: ToolId::new(),
                delay,
                parameters: HashMap::new(),
            }
        }
    }

    #[async_trait::async_trait]
    impl Tool for SlowTool {
        fn id(&self) -> &ToolId {
            &self.id
        }
        fn name(&self) -> &str {
            "slow"
        }
        fn description(&self) -> &str {
            "A slow tool for timeout testing"
        }
        fn parameters(&self) -> &HashMap<String, ParameterSchema> {
            &self.parameters
        }
        async fn execute(
            &self,
            _parameters: HashMap<String, serde_json::Value>,
        ) -> Result<ToolResult> {
            tokio::time::sleep(self.delay).await;
            Ok(ToolResult {
                output: serde_json::json!({ "result": "slow execution completed" }),
                error: None,
                metadata: HashMap::new(),
            })
        }
    }

    /// Tool that returns an error result for its first `fail_until` calls and
    /// succeeds afterwards; used for retry tests.
    struct FlakyTool {
        id: ToolId,
        failure_count: Arc<tokio::sync::Mutex<u32>>,
        fail_until: u32,
        parameters: HashMap<String, ParameterSchema>,
    }

    impl FlakyTool {
        fn new_with_failures(fail_until: u32) -> Self {
            Self {
                id: ToolId::new(),
                failure_count: Arc::new(tokio::sync::Mutex::new(0)),
                fail_until,
                parameters: HashMap::new(),
            }
        }
    }

    #[async_trait::async_trait]
    impl Tool for FlakyTool {
        fn id(&self) -> &ToolId {
            &self.id
        }
        fn name(&self) -> &str {
            "flaky"
        }
        fn description(&self) -> &str {
            "A flaky tool for retry testing"
        }
        fn parameters(&self) -> &HashMap<String, ParameterSchema> {
            &self.parameters
        }
        async fn execute(
            &self,
            _parameters: HashMap<String, serde_json::Value>,
        ) -> Result<ToolResult> {
            let mut count = self.failure_count.lock().await;
            *count += 1;
            if *count <= self.fail_until {
                Ok(ToolResult {
                    output: serde_json::json!({}),
                    error: Some(format!("Simulated failure #{}", *count)),
                    metadata: HashMap::new(),
                })
            } else {
                Ok(ToolResult {
                    output: serde_json::json!({ "result": "success after retries" }),
                    error: None,
                    metadata: HashMap::new(),
                })
            }
        }
    }

    #[tokio::test]
    async fn test_tool_timeout() {
        let registry = ToolRegistry::new();
        let slow_tool = SlowTool::new_with_delay(Duration::from_secs(2));
        let tool_id = slow_tool.id;
        registry.register(Box::new(slow_tool)).await.unwrap();
        let result = registry
            .execute_with_timeout(&tool_id, HashMap::new(), Duration::from_secs(1))
            .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            AgentError::ToolTimeout { tool_name, timeout } => {
                assert_eq!(tool_name, "slow");
                assert_eq!(timeout, Duration::from_secs(1));
            }
            _ => panic!("Expected ToolTimeout error"),
        }
    }

    #[tokio::test]
    async fn test_tool_timeout_success() {
        let registry = ToolRegistry::new();
        let slow_tool = SlowTool::new_with_delay(Duration::from_secs(1));
        let tool_id = slow_tool.id;
        registry.register(Box::new(slow_tool)).await.unwrap();
        let result = registry
            .execute_with_timeout(&tool_id, HashMap::new(), Duration::from_secs(2))
            .await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.error.is_none());
    }

    #[tokio::test]
    async fn test_tool_retry_success_first_attempt() {
        let registry = ToolRegistry::new();
        let fast_tool = SlowTool::new_with_delay(Duration::from_millis(10));
        let tool_id = fast_tool.id;
        registry.register(Box::new(fast_tool)).await.unwrap();
        let result = registry
            .execute_with_retry(&tool_id, HashMap::new(), Duration::from_secs(1), 3, 100)
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_tool_retry_success_after_failures() {
        let registry = ToolRegistry::new();
        let flaky_tool = FlakyTool::new_with_failures(2);
        let tool_id = flaky_tool.id;
        registry.register(Box::new(flaky_tool)).await.unwrap();
        let result = registry
            .execute_with_retry(&tool_id, HashMap::new(), Duration::from_secs(1), 3, 50)
            .await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.error.is_none());
    }

    #[tokio::test]
    async fn test_tool_retry_all_attempts_fail() {
        let registry = ToolRegistry::new();
        // Fails far more often than the number of attempts allowed.
        let flaky_tool = FlakyTool::new_with_failures(10);
        let tool_id = flaky_tool.id;
        registry.register(Box::new(flaky_tool)).await.unwrap();
        let result = registry
            .execute_with_retry(&tool_id, HashMap::new(), Duration::from_secs(1), 2, 50)
            .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            AgentError::ToolExecutionFailed { tool_name, reason } => {
                assert_eq!(tool_name, "flaky");
                assert!(reason.contains("Simulated failure"));
            }
            _ => panic!("Expected ToolExecutionFailed error"),
        }
    }

    #[tokio::test]
    async fn test_tool_retry_exponential_backoff() {
        let registry = ToolRegistry::new();
        let flaky_tool = FlakyTool::new_with_failures(2);
        let tool_id = flaky_tool.id;
        registry.register(Box::new(flaky_tool)).await.unwrap();
        let start = std::time::Instant::now();
        registry
            .execute_with_retry(&tool_id, HashMap::new(), Duration::from_secs(1), 2, 100)
            .await
            .unwrap();
        let elapsed = start.elapsed();
        // Two failed attempts -> backoffs of 100 ms and 200 ms before the third.
        assert!(elapsed >= Duration::from_millis(300));
    }
}