Skip to main content

lellm_agent/runtime/tools/
executor.rs

1//! 工具执行器 — 注册、分派、批量执行、并行安全分级。
2//!
3//! 通过 `ToolCatalog` 消费工具快照,不持有工具所有权。
4
5use std::borrow::Cow;
6use std::sync::Arc;
7
8use lellm_core::{Message, ToolCall, ToolError, ToolErrorKind, ToolResult};
9
10use super::super::event::AgentEvent;
11use super::super::retry::RetryPolicy;
12use super::{ToolCatalog, ToolFn, ToolSnapshot};
13use tokio::sync::mpsc::Sender;
14
/// Parallel-safety level of a tool: how its calls may be scheduled relative to other calls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParallelSafety {
    /// May run concurrently with any other call.
    Safe,
    /// Serialized with other calls whose tools share the same `ToolCategory`; a registration
    /// without a category is treated like `Exclusive` by the batch executor.
    CategoryExclusive,
    /// Fully serial: exclusive calls are executed one at a time. Calls naming an unknown
    /// tool are also classified this way by the batch executor.
    Exclusive,
}
22
/// Tool category: the grouping key for `ParallelSafety::CategoryExclusive` tools.
///
/// Identity is the wrapped string, so a custom category with the same name as a predefined
/// one compares equal to it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ToolCategory(pub Cow<'static, str>);

impl ToolCategory {
    /// Predefined `"file_io"` category.
    pub const FILE_IO: Self = Self(Cow::Borrowed("file_io"));
    /// Predefined `"network"` category.
    pub const NETWORK: Self = Self(Cow::Borrowed("network"));
    /// Predefined `"database"` category.
    pub const DATABASE: Self = Self(Cow::Borrowed("database"));

    /// Builds a category from a `&'static str` (borrowed) or an owned `String`.
    pub fn custom(name: impl Into<Cow<'static, str>>) -> Self {
        let label: Cow<'static, str> = name.into();
        ToolCategory(label)
    }
}
36
37/// 工具注册信息 — Schema、安全分级、执行函数合一。
38///
39/// 用户通过 `ToolRegistration::safe()` 等工厂方法构造。
40/// 字段 `pub(crate)` — 外部通过工厂方法访问,内部通过 `ToolSnapshot` 消费。
41#[derive(Clone)]
42pub struct ToolRegistration {
43    pub(crate) definition: lellm_core::ToolDefinition,
44    pub(crate) safety: ParallelSafety,
45    pub(crate) category: Option<ToolCategory>,
46    pub(crate) func: ToolFn,
47}
48
49impl ToolRegistration {
50    /// 获取工具定义的引用。
51    pub fn definition(&self) -> &lellm_core::ToolDefinition {
52        &self.definition
53    }
54
55    /// 获取并行安全级别。
56    pub fn safety(&self) -> &ParallelSafety {
57        &self.safety
58    }
59
60    /// 获取工具类别(如果有)。
61    pub fn category(&self) -> Option<&ToolCategory> {
62        self.category.as_ref()
63    }
64
65    pub fn safe<F, Fut>(def: lellm_core::ToolDefinition, f: F) -> Self
66    where
67        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
68        Fut: std::future::Future<Output = ToolResult> + Send + 'static,
69    {
70        Self {
71            definition: def,
72            safety: ParallelSafety::Safe,
73            category: None,
74            func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
75        }
76    }
77
78    /// 强类型便捷构造 — 自动反序列化参数。
79    ///
80    /// 与 `safe()` 的区别:闭包接收反序列化后的 `T`,而非原始 `serde_json::Value`。
81    /// 反序列化失败时返回 `ToolErrorKind::InvalidInput`。
82    pub fn safe_fn<T, F, Fut>(def: lellm_core::ToolDefinition, f: F) -> Self
83    where
84        T: for<'de> serde::Deserialize<'de> + Send + 'static,
85        F: Fn(T) -> Fut + Send + Sync + 'static,
86        Fut: std::future::Future<Output = ToolResult> + Send + 'static,
87    {
88        let f = Arc::new(f);
89        Self::safe(def, move |value| {
90            let f = Arc::clone(&f);
91            let result = serde_json::from_value::<T>(value.clone());
92            Box::pin(async move {
93                match result {
94                    Ok(parsed) => f(parsed).await,
95                    Err(e) => Err(ToolError::invalid_input(format!(
96                        "invalid tool arguments: {e}"
97                    ))),
98                }
99            })
100        })
101    }
102
103    pub fn category_exclusive<F, Fut>(
104        def: lellm_core::ToolDefinition,
105        category: ToolCategory,
106        f: F,
107    ) -> Self
108    where
109        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
110        Fut: std::future::Future<Output = ToolResult> + Send + 'static,
111    {
112        Self {
113            definition: def,
114            safety: ParallelSafety::CategoryExclusive,
115            category: Some(category),
116            func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
117        }
118    }
119
120    pub fn exclusive<F, Fut>(def: lellm_core::ToolDefinition, f: F) -> Self
121    where
122        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
123        Fut: std::future::Future<Output = ToolResult> + Send + 'static,
124    {
125        Self {
126            definition: def,
127            safety: ParallelSafety::Exclusive,
128            category: None,
129            func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
130        }
131    }
132}
133
134/// 批量执行结果 — 长度、顺序、完整性三重保证。
135///
136/// **不变量:**
137/// 1. `results.len() == calls.len()` 永远成立
138/// 2. `results[i]` 对应 `calls[i]` 的执行结果(原始顺序)
139/// 3. panic 永远被转换成 `ToolResult(is_error: true)`,不会丢失
140/// 4. `panicked` 仅作为观测信号,不改变结果完整性
141#[derive(Debug)]
142pub struct BatchExecutionResult {
143    /// 按原始调用顺序排列的工具结果,长度等于输入 calls 长度。
144    pub results: Vec<Message>,
145    /// 是否有任意 spawned task panic(仅作为观测信号)
146    pub panicked: bool,
147}
148
149/// 工具执行器 — 按名称分派 ToolCall 到实际工具函数。
150///
151/// 内部持有 `ToolCatalog`,通过 `snapshot()` 获取冻结工具快照。
152/// Clone 为 O(1)(Arc 浅拷贝)。
153#[derive(Clone)]
154pub struct ToolExecutor {
155    catalog: Arc<dyn ToolCatalog>,
156    retry_policy: RetryPolicy,
157}
158
159impl ToolExecutor {
160    /// 绑定工具目录。
161    pub fn new(catalog: Arc<dyn ToolCatalog>) -> Self {
162        Self {
163            catalog,
164            retry_policy: RetryPolicy::default(),
165        }
166    }
167
168    /// 绑定工具目录,使用默认重试策略。
169    pub fn with_catalog(catalog: Arc<dyn ToolCatalog>) -> Self {
170        Self::new(catalog)
171    }
172
173    /// 构造时绑定全局重试策略。
174    pub fn with_retry_policy(catalog: Arc<dyn ToolCatalog>, policy: RetryPolicy) -> Self {
175        Self {
176            catalog,
177            retry_policy: policy,
178        }
179    }
180
181    /// 设置/替换重试策略。
182    pub fn set_retry_policy(&mut self, policy: RetryPolicy) {
183        self.retry_policy = policy;
184    }
185
186    /// 获取当前重试策略的克隆。
187    pub fn retry_policy(&self) -> RetryPolicy {
188        self.retry_policy.clone()
189    }
190
191    /// 获取冻结工具快照。
192    ///
193    /// 每轮迭代调用一次,固定本轮工具集。
194    pub async fn snapshot(&self) -> Arc<ToolSnapshot> {
195        self.catalog.snapshot().await
196    }
197
198    /// 执行单个工具调用,自带重试。
199    ///
200    /// 使用预解析的快照执行。
201    pub async fn execute_with_snapshot(
202        &self,
203        call: &ToolCall,
204        snapshot: &ToolSnapshot,
205    ) -> ToolResult {
206        match snapshot.get(&call.name) {
207            Some(entry) => {
208                self.retry_policy
209                    .execute_with_retry(&entry.func, &call.arguments)
210                    .await
211            }
212            None => Err(ToolError::not_found(format!("unknown tool: {}", call.name))),
213        }
214    }
215
216    /// 执行单个工具调用,自带重试 + Retry 事件发射。
217    pub async fn execute_with_emission(
218        &self,
219        call: &ToolCall,
220        snapshot: &ToolSnapshot,
221        tx: &Sender<AgentEvent>,
222    ) -> ToolResult {
223        match snapshot.get(&call.name) {
224            Some(entry) => {
225                self.retry_policy
226                    .execute_with_retry_and_emission(&entry.func, &call.arguments, tx, &call.id)
227                    .await
228            }
229            None => Err(ToolError::not_found(format!("unknown tool: {}", call.name))),
230        }
231    }
232}
233
234/// 使用预解析的快照批量执行 tool_calls。
235///
236/// 这是动态目录模式的核心执行函数。
237///
238/// # 用法
239///
240/// ```ignore
241/// let snapshot = executor.snapshot().await;
242/// let result = execute_batch_with(&tool_calls, &snapshot, &executor.retry_policy()).await;
243/// ```
244///
245/// # ParallelSafety 契约
246///
247/// - `Safe`: 全并发(每个 tool 独立 spawn)
248/// - `CategoryExclusive(cat)`: 组内串行,组间并发
249/// - `Exclusive`: 全串行
250///
251/// # 一致性保证
252///
253/// `snapshot` 快照在函数执行期间固定不变,不会因目录刷新而漂移。
254pub async fn execute_batch_with(
255    calls: &[ToolCall],
256    snapshot: &ToolSnapshot,
257    retry_policy: &RetryPolicy,
258) -> BatchExecutionResult {
259    if calls.is_empty() {
260        return BatchExecutionResult {
261            results: Vec::new(),
262            panicked: false,
263        };
264    }
265
266    // 分组时保留原始索引
267    let mut safe_calls: Vec<(usize, ToolCall)> = Vec::new();
268    let mut category_calls: std::collections::HashMap<ToolCategory, Vec<(usize, ToolCall)>> =
269        std::collections::HashMap::new();
270    let mut exclusive_calls: Vec<(usize, ToolCall)> = Vec::new();
271
272    for (idx, call) in calls.iter().enumerate() {
273        let safety = snapshot
274            .get(&call.name)
275            .map(|t| t.safety.clone())
276            .unwrap_or(ParallelSafety::Exclusive);
277
278        match safety {
279            ParallelSafety::Safe => safe_calls.push((idx, call.clone())),
280            ParallelSafety::CategoryExclusive => {
281                if let Some(cat) = snapshot.get(&call.name).and_then(|t| t.category.clone()) {
282                    category_calls
283                        .entry(cat)
284                        .or_default()
285                        .push((idx, call.clone()));
286                } else {
287                    exclusive_calls.push((idx, call.clone()));
288                }
289            }
290            ParallelSafety::Exclusive => exclusive_calls.push((idx, call.clone())),
291        }
292    }
293
294    // 构建 group handles
295    let mut group_handles: Vec<tokio::task::JoinHandle<Vec<(usize, Message)>>> = Vec::new();
296    let mut group_indices: Vec<Vec<usize>> = Vec::new();
297
298    let snapshot = Arc::new(snapshot.clone_for_spawn());
299    let retry_policy = retry_policy.clone();
300
301    // Safe: 每个 tool 独立 spawn(全并发)
302    if !safe_calls.is_empty() {
303        let s = Arc::clone(&snapshot);
304        let rp = retry_policy.clone();
305        let indices: Vec<usize> = safe_calls.iter().map(|(i, _)| *i).collect();
306        group_handles.push(tokio::spawn(async move {
307            run_parallel_indexed_with(&s, &rp, safe_calls).await
308        }));
309        group_indices.push(indices);
310    }
311
312    // CategoryExclusive: 按 category 分组,组内串行、组间并发
313    for group_calls in category_calls.into_values() {
314        let s = Arc::clone(&snapshot);
315        let rp = retry_policy.clone();
316        let indices: Vec<usize> = group_calls.iter().map(|(i, _)| *i).collect();
317        group_handles.push(tokio::spawn(async move {
318            run_serial_indexed_with(&s, &rp, group_calls).await
319        }));
320        group_indices.push(indices);
321    }
322
323    // Exclusive: 全部串行,一个 task
324    if !exclusive_calls.is_empty() {
325        let s = Arc::clone(&snapshot);
326        let rp = retry_policy.clone();
327        let indices: Vec<usize> = exclusive_calls.iter().map(|(i, _)| *i).collect();
328        group_handles.push(tokio::spawn(async move {
329            run_serial_indexed_with(&s, &rp, exclusive_calls).await
330        }));
331        group_indices.push(indices);
332    }
333
334    // 按原始索引回填结果;panic 的 group 按索引列表精准回填错误
335    let mut results: Vec<Option<Message>> = vec![None; calls.len()];
336    let mut panicked = false;
337    let all_handles = futures_util::future::join_all(group_handles).await;
338
339    for (handle_result, indices) in all_handles.into_iter().zip(group_indices.into_iter()) {
340        match handle_result {
341            Ok(indexed_messages) => {
342                for (idx, msg) in indexed_messages {
343                    results[idx] = Some(msg);
344                }
345            }
346            Err(join_err) => {
347                panicked = true;
348                for idx in indices {
349                    let call = &calls[idx];
350                    results[idx] = Some(Message::tool_result(
351                        call,
352                        &Err(ToolError {
353                            kind: ToolErrorKind::Internal,
354                            message: format!("tool group task panicked: {join_err}"),
355                        }),
356                    ));
357                }
358            }
359        }
360    }
361
362    BatchExecutionResult {
363        results: results.into_iter().flatten().collect(),
364        panicked,
365    }
366}
367
368// ─── 内部辅助:快照克隆 ──────────────────────────────────────────
369
370impl ToolSnapshot {
371    /// 克隆内部工具映射,供 spawn 使用。
372    pub fn clone_for_spawn(&self) -> Arc<indexmap::IndexMap<String, ToolRegistration>> {
373        self.tools.clone()
374    }
375}
376
377// ─── Safe group: 全并发 ──────────────────────────────────────────
378
379async fn run_parallel_indexed_with(
380    tools: &Arc<indexmap::IndexMap<String, ToolRegistration>>,
381    retry_policy: &RetryPolicy,
382    calls: Vec<(usize, ToolCall)>,
383) -> Vec<(usize, Message)> {
384    let handles: Vec<_> = calls
385        .iter()
386        .map(|(idx, call)| {
387            let tools = Arc::clone(tools);
388            let rp = retry_policy.clone();
389            let call = call.clone();
390            let idx = *idx;
391            tokio::spawn(async move {
392                let result = match tools.get(&call.name) {
393                    Some(entry) => rp.execute_with_retry(&entry.func, &call.arguments).await,
394                    None => Err(ToolError::not_found(format!("unknown tool: {}", call.name))),
395                };
396                (idx, Message::tool_result(&call, &result))
397            })
398        })
399        .collect();
400
401    let raw = futures_util::future::join_all(handles).await;
402    raw.into_iter()
403        .zip(calls.into_iter())
404        .map(|(h, (idx, call))| match h {
405            Ok((_, msg)) => (idx, msg),
406            Err(join_err) => (
407                idx,
408                Message::tool_result(
409                    &call,
410                    &Err(ToolError {
411                        kind: ToolErrorKind::Internal,
412                        message: format!("tool '{}' task panicked: {join_err}", call.name),
413                    }),
414                ),
415            ),
416        })
417        .collect()
418}
419
420// ─── 组内串行 ────────────────────────────────────────────────────
421
422async fn run_serial_indexed_with(
423    tools: &Arc<indexmap::IndexMap<String, ToolRegistration>>,
424    retry_policy: &RetryPolicy,
425    calls: Vec<(usize, ToolCall)>,
426) -> Vec<(usize, Message)> {
427    let mut results = Vec::with_capacity(calls.len());
428    for (idx, call) in calls {
429        let exec_result = match tools.get(&call.name) {
430            Some(entry) => {
431                retry_policy
432                    .execute_with_retry(&entry.func, &call.arguments)
433                    .await
434            }
435            None => Err(ToolError::not_found(format!("unknown tool: {}", call.name))),
436        };
437        results.push((idx, Message::tool_result(&call, &exec_result)));
438    }
439    results
440}