Skip to main content

echo_execution/
tools.rs

1//! Tool system core — `ToolManager` and tool trait re-exports.
2//!
3//! The [`ToolManager`] handles registration, execution, concurrency control,
4//! and timeout/retry for all tools in an agent session.
5
6use echo_core::error::{Result, ToolError};
7use echo_core::llm::types::ToolDefinition;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10use std::time::Duration;
11use tokio::sync::Semaphore;
12
13pub use echo_core::tools::{Tool, ToolExecutionConfig, ToolParameters, ToolRegistrar, ToolResult};
14
15impl ToolRegistrar for ToolManager {
16    fn register(&mut self, tool: Box<dyn Tool>) {
17        self.register(tool);
18    }
19}
20
21/// 工具管理器
22///
23/// 负责工具的注册、执行、并发控制和超时重试。
24pub struct ToolManager {
25    tools: HashMap<String, Box<dyn Tool>>,
26    config: ToolExecutionConfig,
27    /// 并发限流器
28    semaphore: Option<Arc<Semaphore>>,
29    /// 缓存的工具定义
30    cached_definitions: RwLock<Option<Vec<ToolDefinition>>>,
31}
32
33impl ToolManager {
34    /// 获取 OpenAI 格式的工具定义列表(带缓存)
35    ///
36    /// 首次调用时构建并缓存,后续直接返回缓存值。
37    /// 注册新工具后缓存会自动失效。
38    pub fn get_openai_tools(&self) -> Vec<ToolDefinition> {
39        // Fast path: read cached
40        if let Some(ref cached) = *self.cached_definitions.read().unwrap() {
41            return cached.clone();
42        }
43        // Build + cache
44        let definitions: Vec<ToolDefinition> = self
45            .tools
46            .values()
47            .map(|tool| ToolDefinition::from_tool(&**tool))
48            .collect();
49        *self.cached_definitions.write().unwrap() = Some(definitions.clone());
50        definitions
51    }
52
53    /// 使缓存失效(注册/注销工具时调用)
54    fn invalidate_cache(&self) {
55        *self.cached_definitions.write().unwrap() = None;
56    }
57}
58
59impl Default for ToolManager {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl ToolManager {
66    pub fn new() -> Self {
67        Self {
68            tools: HashMap::new(),
69            semaphore: None,
70            config: ToolExecutionConfig::default(),
71            cached_definitions: RwLock::new(None),
72        }
73    }
74
75    pub fn new_with_config(config: ToolExecutionConfig) -> Self {
76        let semaphore = config
77            .max_concurrency
78            .map(|n| Arc::new(Semaphore::new(n.max(1))));
79        Self {
80            tools: HashMap::new(),
81            semaphore,
82            config,
83            cached_definitions: RwLock::new(None),
84        }
85    }
86
87    /// 返回并发度限制(`None` = 不限制)
88    pub fn max_concurrency(&self) -> Option<usize> {
89        self.config.max_concurrency
90    }
91
92    /// 注册单个工具
93    pub fn register(&mut self, tool: Box<dyn Tool>) {
94        self.tools.insert(tool.name().to_string(), tool);
95        self.invalidate_cache();
96    }
97
98    /// 批量注册工具
99    pub fn register_tools(&mut self, tools: Vec<Box<dyn Tool>>) {
100        for tool in tools {
101            self.tools.insert(tool.name().to_string(), tool);
102        }
103        self.invalidate_cache();
104    }
105
106    /// 注销工具
107    pub fn unregister(&mut self, tool_name: &str) -> Option<Box<dyn Tool>> {
108        let tool = self.tools.remove(tool_name);
109        if tool.is_some() {
110            self.invalidate_cache();
111        }
112        tool
113    }
114
115    /// 列出所有已注册的工具名称
116    pub fn list_tools(&self) -> Vec<&str> {
117        self.tools.keys().map(|name| name.as_str()).collect()
118    }
119
120    /// 获取工具引用
121    pub fn get_tool(&self, tool_name: &str) -> Option<&dyn Tool> {
122        self.tools.get(tool_name).map(|tool| &**tool)
123    }
124
125    /// 获取工具定义列表(用于展示或调试)
126    pub fn get_tool_definitions(&self) -> Vec<ToolDefinition> {
127        self.tools
128            .values()
129            .map(|tool| ToolDefinition::from_tool(&**tool))
130            .collect()
131    }
132
133    /// 执行工具
134    ///
135    /// 支持并发控制、超时和重试。
136    pub async fn execute_tool(
137        &self,
138        tool_name: &str,
139        parameters: ToolParameters,
140    ) -> Result<ToolResult> {
141        let tool = self
142            .get_tool(tool_name)
143            .ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;
144
145        // 并发控制:获取信号量许可
146        let _permit = if let Some(sem) = &self.semaphore {
147            match sem.acquire().await {
148                Ok(permit) => Some(permit),
149                Err(e) => {
150                    tracing::warn!("Failed to acquire semaphore permit: {}", e);
151                    return Err(ToolError::ExecutionFailed {
152                        tool: tool_name.to_string(),
153                        message: format!("Concurrency limit error: {}", e),
154                    }
155                    .into());
156                }
157            }
158        } else {
159            None
160        };
161
162        let max_retries = if self.config.retry_on_fail {
163            self.config.max_retries
164        } else {
165            0
166        };
167
168        let mut last_err: Option<echo_core::error::ReactError> = None;
169
170        for attempt in 0..=max_retries {
171            if attempt > 0 {
172                let delay_ms = self.config.retry_delay_ms * (1u64 << (attempt as u64 - 1).min(5));
173                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
174            }
175
176            let result = if self.config.timeout_ms > 0 {
177                match tokio::time::timeout(
178                    Duration::from_millis(self.config.timeout_ms),
179                    tool.execute(parameters.clone()),
180                )
181                .await
182                {
183                    Ok(r) => r,
184                    Err(_) => Err(ToolError::Timeout(tool_name.to_string()).into()),
185                }
186            } else {
187                tool.execute(parameters.clone()).await
188            };
189
190            match result {
191                Ok(r) => return Ok(r),
192                Err(e) if attempt < max_retries => {
193                    last_err = Some(e);
194                }
195                Err(e) => return Err(e),
196            }
197        }
198
199        Err(last_err.unwrap_or_else(|| ToolError::NotFound(tool_name.to_string()).into()))
200    }
201
202    /// 验证工具参数
203    pub fn validate_tool_parameters(
204        &self,
205        tool_name: &str,
206        parameters: &ToolParameters,
207    ) -> Result<()> {
208        let tool = self
209            .get_tool(tool_name)
210            .ok_or_else(|| ToolError::NotFound(tool_name.to_string()))?;
211        futures::executor::block_on(tool.validate_parameters(parameters))
212    }
213}