use echo_core::error::{Result, ToolError};
use echo_core::llm::types::ToolDefinition;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::Semaphore;
pub use echo_core::tools::{Tool, ToolExecutionConfig, ToolParameters, ToolResult};
/// Registry of [`Tool`]s together with the policy used to run them: timeout, retry
/// and optional concurrency settings taken from a [`ToolExecutionConfig`].
pub struct ToolManager {
/// Registered tools keyed by `Tool::name()`; registering the same name again replaces
/// the previous entry.
tools: HashMap<String, Box<dyn Tool>>,
/// Execution policy (timeout, retry and concurrency settings) read by `execute_tool`.
config: ToolExecutionConfig,
/// Limits how many `execute_tool` calls run at once; `None` means no limit is enforced.
semaphore: Option<Arc<Semaphore>>,
/// Lazily built tool definitions; reset to `None` whenever the set of tools changes.
cached_definitions: RwLock<Option<Vec<ToolDefinition>>>,
}
impl ToolManager {
pub fn get_openai_tools(&self) -> Vec<ToolDefinition> {
if let Some(ref cached) = *self.cached_definitions.read().unwrap() {
return cached.clone();
}
let definitions: Vec<ToolDefinition> = self
.tools
.values()
.map(|tool| ToolDefinition::from_tool(&**tool))
.collect();
*self.cached_definitions.write().unwrap() = Some(definitions.clone());
definitions
}
fn invalidate_cache(&self) {
*self.cached_definitions.write().unwrap() = None;
}
}
impl Default for ToolManager {
fn default() -> Self {
Self::new()
}
}
impl ToolManager {
    /// Creates an empty manager using [`ToolExecutionConfig::default()`].
    pub fn new() -> Self {
        // Route through `new_with_config` so that a `max_concurrency` in the default
        // config gets a semaphore, exactly as an explicitly supplied config would.
        Self::new_with_config(ToolExecutionConfig::default())
    }

    /// Creates an empty manager that executes tools according to `config`.
    ///
    /// A `max_concurrency` of `Some(0)` is clamped to a single permit. A zero-permit
    /// semaphore would make every `execute_tool` call wait forever.
    pub fn new_with_config(config: ToolExecutionConfig) -> Self {
        let semaphore = config
            .max_concurrency
            .map(|n| Arc::new(Semaphore::new(n.max(1))));
        Self {
            tools: HashMap::new(),
            semaphore,
            config,
            cached_definitions: RwLock::new(None),
        }
    }

    /// Returns the configured concurrency limit exactly as given in the config. The value
    /// is reported before the clamp to at least one permit is applied.
    pub fn max_concurrency(&self) -> Option<usize> {
        self.config.max_concurrency
    }

    /// Registers `tool` under its `name()`, replacing any tool already using that name.
    pub fn register(&mut self, tool: Box<dyn Tool>) {
        self.tools.insert(tool.name().to_string(), tool);
        self.invalidate_cache();
    }

    /// Registers every tool in `tools`; later entries win on duplicate names.
    pub fn register_tools(&mut self, tools: Vec<Box<dyn Tool>>) {
        self.tools
            .extend(tools.into_iter().map(|tool| (tool.name().to_string(), tool)));
        self.invalidate_cache();
    }

    /// Removes and returns the tool named `tool_name`, or `None` if it is not registered.
    pub fn unregister(&mut self, tool_name: &str) -> Option<Box<dyn Tool>> {
        let tool = self.tools.remove(tool_name);
        if tool.is_some() {
            self.invalidate_cache();
        }
        tool
    }

    /// Returns the names of all registered tools, sorted so the listing is deterministic.
    pub fn list_tools(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.tools.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Looks up a registered tool by name.
    pub fn get_tool(&self, tool_name: &str) -> Option<&dyn Tool> {
        self.tools.get(tool_name).map(|tool| &**tool)
    }

    /// Returns a definition for every registered tool.
    ///
    /// Shares the cache maintained by `get_openai_tools`, so both methods return the same
    /// list in the same order.
    pub fn get_tool_definitions(&self) -> Vec<ToolDefinition> {
        self.get_openai_tools()
    }

    /// Executes the tool named `tool_name` with `parameters`.
    ///
    /// Execution policy:
    /// - If a concurrency limit is configured, a semaphore permit is held for the whole
    ///   call, including any retry back-off sleeps.
    /// - Each attempt is bounded by `timeout_ms` when that value is non-zero.
    /// - When `retry_on_fail` is set, a failed attempt is retried up to `max_retries`
    ///   times with exponential back-off (see `retry_delay`).
    ///
    /// # Errors
    /// - `ToolError::NotFound` if no tool is registered under `tool_name`.
    /// - `ToolError::ExecutionFailed` if a permit cannot be acquired (tokio reports this
    ///   only for a closed semaphore).
    /// - Otherwise, the error from the final attempt, which is `ToolError::Timeout` if
    ///   that attempt timed out.
    pub async fn execute_tool(
        &self,
        tool_name: &str,
        parameters: ToolParameters,
    ) -> Result<ToolResult> {
        let tool = self
            .get_tool(tool_name)
            .ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;

        let _permit = match &self.semaphore {
            Some(sem) => match sem.acquire().await {
                Ok(permit) => Some(permit),
                Err(e) => {
                    tracing::warn!("Failed to acquire semaphore permit: {}", e);
                    return Err(ToolError::ExecutionFailed {
                        tool: tool_name.to_string(),
                        message: format!("Concurrency limit error: {}", e),
                    }
                    .into());
                }
            },
            None => None,
        };

        let max_retries = if self.config.retry_on_fail {
            self.config.max_retries
        } else {
            0
        };

        // Every path out of this loop returns the outcome of a real attempt, so no error
        // has to be invented once retries run out.
        let mut attempt = 0;
        loop {
            if attempt > 0 {
                tokio::time::sleep(self.retry_delay(attempt as u64)).await;
            }
            match self.run_attempt(tool, tool_name, parameters.clone()).await {
                Ok(result) => return Ok(result),
                Err(e) if attempt < max_retries => {
                    // Log the error; otherwise only the final attempt's error would be visible.
                    tracing::warn!(
                        "Tool '{}' failed on attempt {}, retrying: {}",
                        tool_name,
                        attempt + 1,
                        e
                    );
                    attempt += 1;
                }
                Err(e) => return Err(e),
            }
        }
    }

    /// Runs one execution attempt, bounded by `config.timeout_ms` when that is non-zero.
    async fn run_attempt(
        &self,
        tool: &dyn Tool,
        tool_name: &str,
        parameters: ToolParameters,
    ) -> Result<ToolResult> {
        if self.config.timeout_ms > 0 {
            match tokio::time::timeout(
                Duration::from_millis(self.config.timeout_ms),
                tool.execute(parameters),
            )
            .await
            {
                Ok(result) => result,
                Err(_) => Err(ToolError::Timeout(tool_name.to_string()).into()),
            }
        } else {
            tool.execute(parameters).await
        }
    }

    /// Returns the back-off delay before retry number `retry` (1-based).
    ///
    /// The delay is `retry_delay_ms * 2^(retry - 1)`, with the exponent capped at 5 (32x).
    /// The multiplication saturates instead of overflowing for very large configured delays.
    fn retry_delay(&self, retry: u64) -> Duration {
        let exponent = retry.saturating_sub(1).min(5);
        Duration::from_millis(self.config.retry_delay_ms.saturating_mul(1u64 << exponent))
    }

    /// Synchronously validates `parameters` against the tool named `tool_name`.
    ///
    /// This drives the tool's async validation with `futures::executor::block_on`, which
    /// blocks the current thread. Calling it from inside an async runtime stalls that
    /// worker thread and can deadlock, for example on a current-thread runtime. Async
    /// callers should use [`ToolManager::validate_tool_parameters_async`] instead.
    ///
    /// # Errors
    /// `ToolError::NotFound` if the tool is not registered; otherwise whatever the tool's
    /// `validate_parameters` returns.
    pub fn validate_tool_parameters(
        &self,
        tool_name: &str,
        parameters: &ToolParameters,
    ) -> Result<()> {
        let tool = self
            .get_tool(tool_name)
            .ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;
        futures::executor::block_on(tool.validate_parameters(parameters))
    }

    /// Async counterpart of `validate_tool_parameters` that awaits the tool's validation
    /// instead of blocking the thread.
    ///
    /// # Errors
    /// `ToolError::NotFound` if the tool is not registered; otherwise whatever the tool's
    /// `validate_parameters` returns.
    pub async fn validate_tool_parameters_async(
        &self,
        tool_name: &str,
        parameters: &ToolParameters,
    ) -> Result<()> {
        let tool = self
            .get_tool(tool_name)
            .ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;
        tool.validate_parameters(parameters).await
    }
}