Skip to main content

lellm_agent/tools/
executor.rs

1//! 工具执行器 — 注册、分派、批量执行、并行安全分级。
2
3use std::borrow::Cow;
4use std::collections::HashMap;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use lellm_core::{Message, ToolCall};
9
10use super::ToolCallResult;
11
/// Parallel-safety level a tool declares when it is registered.
///
/// `ToolExecutor::partition_calls` places only `Safe` calls in the parallel group; both other
/// levels are placed in the serial group.
///
/// The type is a fieldless enum, so it is `Copy`: passing it around never needs `.clone()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParallelSafety {
    /// No scheduling restriction declared; eligible for the parallel group.
    Safe,
    /// Exclusive with respect to the `ToolCategory` supplied at registration
    /// (see `ToolRegistration::category_exclusive`).
    CategoryExclusive,
    /// Fully exclusive. This is also what `ToolExecutor::safety_for` reports for tool names
    /// that were never registered.
    Exclusive,
}
19
/// Grouping key attached to `CategoryExclusive` tools.
///
/// The name is a `Cow<'static, str>`, so the built-in categories below can be compile-time
/// constants while custom categories may own their name.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ToolCategory(pub Cow<'static, str>);

impl ToolCategory {
    /// Built-in category named `"file_io"`.
    pub const FILE_IO: ToolCategory = ToolCategory(Cow::Borrowed("file_io"));
    /// Built-in category named `"network"`.
    pub const NETWORK: ToolCategory = ToolCategory(Cow::Borrowed("network"));
    /// Built-in category named `"database"`.
    pub const DATABASE: ToolCategory = ToolCategory(Cow::Borrowed("database"));

    /// Builds a user-defined category from anything convertible into `Cow<'static, str>`,
    /// e.g. a `&'static str` (kept borrowed) or a `String` (kept owned).
    pub fn custom(name: impl Into<Cow<'static, str>>) -> Self {
        let label: Cow<'static, str> = name.into();
        ToolCategory(label)
    }
}
33
34/// 工具注册信息(包含安全分级 + 执行函数)。
35pub struct ToolRegistration {
36    safety: ParallelSafety,
37    #[allow(dead_code)]
38    category: Option<ToolCategory>,
39    func: ToolFn,
40}
41
42/// 异步工具函数类型
43type ToolFn = Arc<
44    dyn Fn(&serde_json::Value) -> Pin<Box<dyn std::future::Future<Output = ToolCallResult> + Send>>
45        + Send
46        + Sync,
47>;
48
49impl ToolRegistration {
50    pub fn safe<F, Fut>(f: F) -> Self
51    where
52        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
53        Fut: std::future::Future<Output = ToolCallResult> + Send + 'static,
54    {
55        Self {
56            safety: ParallelSafety::Safe,
57            category: None,
58            func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
59        }
60    }
61
62    pub fn category_exclusive<F, Fut>(category: ToolCategory, f: F) -> Self
63    where
64        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
65        Fut: std::future::Future<Output = ToolCallResult> + Send + 'static,
66    {
67        Self {
68            safety: ParallelSafety::CategoryExclusive,
69            category: Some(category),
70            func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
71        }
72    }
73
74    pub fn exclusive<F, Fut>(f: F) -> Self
75    where
76        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
77        Fut: std::future::Future<Output = ToolCallResult> + Send + 'static,
78    {
79        Self {
80            safety: ParallelSafety::Exclusive,
81            category: None,
82            func: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
83        }
84    }
85}
86
87/// 工具执行器 — 按名称分派 ToolCall 到实际工具函数。
88#[derive(Default)]
89pub struct ToolExecutor {
90    tools: HashMap<String, ToolFn>,
91    safety: HashMap<String, ParallelSafety>,
92}
93
94impl ToolExecutor {
95    pub fn new() -> Self {
96        Self {
97            tools: HashMap::new(),
98            safety: HashMap::new(),
99        }
100    }
101
102    pub fn register(&mut self, name: &str, reg: ToolRegistration) {
103        self.safety.insert(name.to_string(), reg.safety.clone());
104        self.tools.insert(name.to_string(), reg.func);
105    }
106
107    pub fn safety_for(&self, name: &str) -> ParallelSafety {
108        self.safety
109            .get(name)
110            .cloned()
111            .unwrap_or(ParallelSafety::Exclusive)
112    }
113
114    pub async fn execute(&self, call: &ToolCall) -> ToolCallResult {
115        match self.tools.get(&call.name) {
116            Some(tool_fn) => tool_fn(&call.arguments).await,
117            None => ToolCallResult::Err(format!("unknown tool: {}", call.name)),
118        }
119    }
120
121    pub async fn execute_batch(&self, calls: &[ToolCall]) -> Vec<Message> {
122        let mut results = Vec::new();
123        for call in calls {
124            let result = self.execute(call).await;
125            let content = match result {
126                ToolCallResult::Ok(s) => s,
127                ToolCallResult::Err(e) => format!("tool error: {e}"),
128            };
129            results.push(Message::ToolResult {
130                tool_call_id: call.id.clone(),
131                content: lellm_core::text_block(content),
132            });
133        }
134        results
135    }
136
137    /// 按安全分级将 tool_calls 分为可并行和需串行两组
138    pub fn partition_calls(&self, calls: &[ToolCall]) -> (Vec<ToolCall>, Vec<ToolCall>) {
139        let mut safe = Vec::new();
140        let mut exclusive = Vec::new();
141        for call in calls {
142            let safety = self.safety_for(&call.name);
143            match safety {
144                ParallelSafety::Safe => safe.push(call.clone()),
145                ParallelSafety::CategoryExclusive | ParallelSafety::Exclusive => {
146                    exclusive.push(call.clone());
147                }
148            }
149        }
150        (safe, exclusive)
151    }
152}