1use echo_core::error::{Result, ToolError};
7use echo_core::llm::types::ToolDefinition;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10use std::time::Duration;
11use tokio::sync::Semaphore;
12
13pub use echo_core::tools::{Tool, ToolExecutionConfig, ToolParameters, ToolRegistrar, ToolResult};
14
15impl ToolRegistrar for ToolManager {
16 fn register(&mut self, tool: Box<dyn Tool>) {
17 self.register(tool);
18 }
19}
20
21pub struct ToolManager {
25 tools: HashMap<String, Box<dyn Tool>>,
26 config: ToolExecutionConfig,
27 semaphore: Option<Arc<Semaphore>>,
29 cached_definitions: RwLock<Option<Vec<ToolDefinition>>>,
31}
32
33impl ToolManager {
34 pub fn get_openai_tools(&self) -> Vec<ToolDefinition> {
39 if let Some(ref cached) = *self.cached_definitions.read().unwrap() {
41 return cached.clone();
42 }
43 let definitions: Vec<ToolDefinition> = self
45 .tools
46 .values()
47 .map(|tool| ToolDefinition::from_tool(&**tool))
48 .collect();
49 *self.cached_definitions.write().unwrap() = Some(definitions.clone());
50 definitions
51 }
52
53 fn invalidate_cache(&self) {
55 *self.cached_definitions.write().unwrap() = None;
56 }
57}
58
59impl Default for ToolManager {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl ToolManager {
66 pub fn new() -> Self {
67 Self {
68 tools: HashMap::new(),
69 semaphore: None,
70 config: ToolExecutionConfig::default(),
71 cached_definitions: RwLock::new(None),
72 }
73 }
74
75 pub fn new_with_config(config: ToolExecutionConfig) -> Self {
76 let semaphore = config
77 .max_concurrency
78 .map(|n| Arc::new(Semaphore::new(n.max(1))));
79 Self {
80 tools: HashMap::new(),
81 semaphore,
82 config,
83 cached_definitions: RwLock::new(None),
84 }
85 }
86
87 pub fn max_concurrency(&self) -> Option<usize> {
89 self.config.max_concurrency
90 }
91
92 pub fn register(&mut self, tool: Box<dyn Tool>) {
94 self.tools.insert(tool.name().to_string(), tool);
95 self.invalidate_cache();
96 }
97
98 pub fn register_tools(&mut self, tools: Vec<Box<dyn Tool>>) {
100 for tool in tools {
101 self.tools.insert(tool.name().to_string(), tool);
102 }
103 self.invalidate_cache();
104 }
105
106 pub fn unregister(&mut self, tool_name: &str) -> Option<Box<dyn Tool>> {
108 let tool = self.tools.remove(tool_name);
109 if tool.is_some() {
110 self.invalidate_cache();
111 }
112 tool
113 }
114
115 pub fn list_tools(&self) -> Vec<&str> {
117 self.tools.keys().map(|name| name.as_str()).collect()
118 }
119
120 pub fn get_tool(&self, tool_name: &str) -> Option<&dyn Tool> {
122 self.tools.get(tool_name).map(|tool| &**tool)
123 }
124
125 pub fn get_tool_definitions(&self) -> Vec<ToolDefinition> {
127 self.tools
128 .values()
129 .map(|tool| ToolDefinition::from_tool(&**tool))
130 .collect()
131 }
132
133 pub async fn execute_tool(
137 &self,
138 tool_name: &str,
139 parameters: ToolParameters,
140 ) -> Result<ToolResult> {
141 let tool = self
142 .get_tool(tool_name)
143 .ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;
144
145 let _permit = if let Some(sem) = &self.semaphore {
147 match sem.acquire().await {
148 Ok(permit) => Some(permit),
149 Err(e) => {
150 tracing::warn!("Failed to acquire semaphore permit: {}", e);
151 return Err(ToolError::ExecutionFailed {
152 tool: tool_name.to_string(),
153 message: format!("Concurrency limit error: {}", e),
154 }
155 .into());
156 }
157 }
158 } else {
159 None
160 };
161
162 let max_retries = if self.config.retry_on_fail {
163 self.config.max_retries
164 } else {
165 0
166 };
167
168 let mut last_err: Option<echo_core::error::ReactError> = None;
169
170 for attempt in 0..=max_retries {
171 if attempt > 0 {
172 let delay_ms = self.config.retry_delay_ms * (1u64 << (attempt as u64 - 1).min(5));
173 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
174 }
175
176 let result = if self.config.timeout_ms > 0 {
177 match tokio::time::timeout(
178 Duration::from_millis(self.config.timeout_ms),
179 tool.execute(parameters.clone()),
180 )
181 .await
182 {
183 Ok(r) => r,
184 Err(_) => Err(ToolError::Timeout(tool_name.to_string()).into()),
185 }
186 } else {
187 tool.execute(parameters.clone()).await
188 };
189
190 match result {
191 Ok(r) => return Ok(r),
192 Err(e) if attempt < max_retries => {
193 last_err = Some(e);
194 }
195 Err(e) => return Err(e),
196 }
197 }
198
199 Err(last_err.unwrap_or_else(|| ToolError::NotFound(tool_name.to_string()).into()))
200 }
201
202 pub fn validate_tool_parameters(
204 &self,
205 tool_name: &str,
206 parameters: &ToolParameters,
207 ) -> Result<()> {
208 let tool = self
209 .get_tool(tool_name)
210 .ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;
211 futures::executor::block_on(tool.validate_parameters(parameters))
212 }
213}