Skip to main content

lellm_core/
tool.rs

1//! 工具系统 — 协议层 + 可执行工具描述。
2//!
3//! 本模块定义了工具的基础类型,不依赖任何运行时(tokio/futures)。
4//!
5//! **分层:**
6//! - 协议层:`ToolArgs`, `ToolDefinition`, `ParallelSafety`, `ToolCategory`
7//! - 可执行描述:`ExecutableTool`(定义 + 执行器,但不负责调度/重试/目录)
8//!
9//! `ExecutableTool` 是 `#[tool]` 宏的产物,可在 graph 层直接使用。
10//! 真正的运行时(lookup, dispatch, retry, parallel, snapshot)全部留给 lellm-agent。
11
12use std::borrow::Cow;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16
17use crate::ToolResult;
18
19// ─── ToolArgParser ──────────────────────────────────────────────
20
21/// 工具参数解析 trait — 将原始 JSON Value 反序列化为强类型结构体。
22///
23/// **为什么需要这个 trait?**
24/// - `#[tool]` 宏生成的代码不知道 `serde_json` 的存在
25/// - 宏只依赖稳定的 `ToolArgParser::parse()` API
26/// - 所有解析策略(JSON、MessagePack、CBOR…)集中在 core 层
27/// - 以后更换序列化格式,只需修改此 trait 的实现
28///
29/// **依赖方向:**
30/// ```text
31/// lellm-derive
32///       │
33///       ▼
34/// ToolArgParser::parse()
35///       │
36///       ▼
37/// lellm-core
38///       │
39///       ▼
40/// serde_json
41/// ```
42pub trait ToolArgParser: Sized {
43    /// 从原始 JSON Value 解析工具参数。
44    ///
45    /// 解析失败时返回 `serde_json::Error`,调用方负责转换为 `ToolError`。
46    fn parse(value: serde_json::Value) -> Result<Self, serde_json::Error>;
47}
48
49impl<T> ToolArgParser for T
50where
51    T: for<'de> serde::Deserialize<'de>,
52{
53    fn parse(value: serde_json::Value) -> Result<Self, serde_json::Error> {
54        serde_json::from_value(value)
55    }
56}
57
58// ─── ToolArgs ───────────────────────────────────────────────────
59
60/// 工具参数 trait — 由 `#[tool]` 宏自动生成。
61///
62/// 实现了此 trait 的结构体,即可通过 `tool_definition()` 方法
63/// 自动获得 `ToolDefinition`(含 JSON Schema)。
64///
65/// # 示例
66/// ```ignore
67/// use lellm_derive::tool;
68///
69/// #[tool(name = "search", description = "搜索互联网信息")]
70/// async fn search(query: String, limit: u32) -> String {
71///     format!("results for {}", query)
72/// }
73/// // 生成 SearchArgs struct + search_tool() 工厂函数
74/// ```
75pub trait ToolArgs: ToolArgParser {
76    /// 工具名称(蛇形命名)
77    const NAME: &'static str;
78    /// 工具描述
79    const DESCRIPTION: &'static str;
80    /// 由 `#[tool]` 宏生成的 JSON Schema(LazyLock 缓存)
81    fn __schema() -> serde_json::Value;
82    /// 自动生成 ToolDefinition(含 JSON Schema)
83    fn tool_definition() -> ToolDefinition {
84        ToolDefinition {
85            name: Self::NAME.to_string(),
86            description: Self::DESCRIPTION.to_string(),
87            parameters: Self::__schema(),
88            cache_control: None,
89        }
90    }
91}
92
93// ─── ParallelSafety ─────────────────────────────────────────────
94
95/// 工具并行安全分级
96#[derive(Debug, Clone, PartialEq, Eq)]
97pub enum ParallelSafety {
98    /// 可并行执行(默认)
99    Safe,
100    /// 同类别内互斥,类别间可并行
101    CategoryExclusive,
102    /// 全局互斥
103    Exclusive,
104}
105
106// ─── ToolCategory ───────────────────────────────────────────────
107
108/// 工具类别 — 用于 `CategoryExclusive` 的分组
109#[derive(Debug, Clone, PartialEq, Eq, Hash)]
110pub struct ToolCategory(pub Cow<'static, str>);
111
112impl ToolCategory {
113    pub const FILE_IO: Self = Self(Cow::Borrowed("file_io"));
114    pub const NETWORK: Self = Self(Cow::Borrowed("network"));
115    pub const DATABASE: Self = Self(Cow::Borrowed("database"));
116
117    pub fn custom(name: impl Into<Cow<'static, str>>) -> Self {
118        Self(name.into())
119    }
120}
121
122// ─── ToolDefinition ─────────────────────────────────────────────
123
124/// 工具定义(纯数据,协议层)。
125///
126/// Schema 由 `schemars` 在编译期生成,经清洗后存入 `parameters` 字段。
127/// Provider 将此结构序列化后发送给 LLM。
128///
129/// **与 `ExecutableTool` 的区别:**
130/// - `ToolDefinition`(core):纯数据,Provider 序列化发送给 LLM
131/// - `ExecutableTool`(core):可执行,Agent 调用时查找并执行
132#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
133pub struct ToolDefinition {
134    /// 工具名称
135    pub name: String,
136    /// 工具描述
137    pub description: String,
138    /// JSON Schema 参数定义
139    pub parameters: serde_json::Value,
140    /// 缓存控制标记。Anthropic 支持 Tool Definition 级别的缓存。
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub cache_control: Option<crate::message::CacheControl>,
143}
144
145impl ToolDefinition {
146    /// 克隆并设置缓存标记。
147    pub fn with_cache(self, cache: crate::message::CacheControl) -> Self {
148        Self {
149            cache_control: Some(cache),
150            ..self
151        }
152    }
153
154    /// 从 `schemars::JsonSchema` 类型计算并清洗 JSON Schema。
155    ///
156    /// 供 `#[tool]` 宏生成的 `LazyLock` 调用,不在泛型函数中使用 `LazyLock`。
157    ///
158    /// **清洗规则:** 去除 `$schema`, `$id`, `title`, `description` 等根部元数据,
159    /// 保留 `type`, `properties`, `required`, `definitions` 等核心 JSON Schema 字段。
160    pub fn compute_and_clean_schema<S: schemars::JsonSchema>() -> serde_json::Value {
161        let root = schemars::schema_for!(S);
162        let val = serde_json::to_value(&root)
163            .expect("Failed to serialize JsonSchema; this is a bug in schemars");
164        Self::clean_schema(val)
165    }
166
167    /// 清洗 schemars 生成的 RootSchema,去除根部元数据噪音。
168    ///
169    /// 保留 `type`, `properties`, `required`, `definitions`, `additionalProperties`
170    /// 等核心 JSON Schema 字段。Codec 层在此基础上进行 Provider 特定的二次适配。
171    fn clean_schema(mut value: serde_json::Value) -> serde_json::Value {
172        if let Some(obj) = value.as_object_mut() {
173            // 去除标准 JSON Schema 根部的噪声元数据
174            obj.remove("$schema");
175            obj.remove("$id");
176            obj.remove("title");
177            obj.remove("description");
178        }
179        value
180    }
181}
182
183// ─── ToolFn ─────────────────────────────────────────────────────
184
185/// 异步工具执行函数类型 — 接受 JSON 参数,返回 boxed future。
186pub type ToolFn = Arc<
187    dyn Fn(&serde_json::Value) -> Pin<Box<dyn Future<Output = ToolResult> + Send>> + Send + Sync,
188>;
189
190/// 内部辅助 — 将 concrete future coerc 到 trait object。
191///
192/// 用于 `#[tool]` 宏生成的代码,解决 `Box::pin(async move { ... })`
193/// 无法自动 coerc 到 `Pin<Box<dyn Future>>` 的问题。
194#[doc(hidden)]
195pub fn __tool_box<F>(f: F) -> Pin<Box<dyn Future<Output = ToolResult> + Send>>
196where
197    F: Future<Output = ToolResult> + Send + 'static,
198{
199    Box::pin(f)
200}
201
202// ─── ExecutableTool ─────────────────────────────────────────────
203
204/// 可执行的工具 — 定义 + 安全元数据 + 执行器。
205///
206/// **与 `ToolDefinition` 的区别:**
207/// - `ToolDefinition`:纯数据,Provider 序列化发送给 LLM
208/// - `ExecutableTool`:可执行,Agent 调用时查找并执行
209///
210/// **与运行时(lellm-agent)的区别:**
211/// - `ExecutableTool`:描述"这个工具能做什么 + 怎么执行",但不负责调度
212/// - `ToolExecutor` / `ToolCatalog` / `ToolSnapshot`:负责 lookup, dispatch, retry, parallel
213///
214/// 用户通过 `ExecutableTool::safe()` 等工厂方法构造,
215/// 或由 `#[tool]` 宏自动生成。
216#[derive(Clone)]
217pub struct ExecutableTool {
218    /// 工具定义(纯元数据,可被 Provider 序列化)
219    pub definition: ToolDefinition,
220    /// 并行安全级别
221    pub safety: ParallelSafety,
222    /// 工具类别(仅 `CategoryExclusive` 时使用)
223    pub category: Option<ToolCategory>,
224    /// 执行函数(运行时,不被序列化)
225    executor: ToolFn,
226}
227
228impl ExecutableTool {
229    // ─── 访问器 ───────────────────────────────────────────────
230
231    /// 获取工具定义的引用。
232    pub fn definition(&self) -> &ToolDefinition {
233        &self.definition
234    }
235
236    /// 获取并行安全级别。
237    pub fn safety(&self) -> &ParallelSafety {
238        &self.safety
239    }
240
241    /// 获取工具类别(如果有)。
242    pub fn category(&self) -> Option<&ToolCategory> {
243        self.category.as_ref()
244    }
245
246    /// 执行工具调用,返回未来对象。
247    pub fn execute(
248        &self,
249        args: &serde_json::Value,
250    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send>> {
251        (self.executor)(args)
252    }
253
254    // ─── 低层构造 — 接受原始 ToolFn(用于 MCP bridge 等场景) ──
255
256    /// 从原始执行函数构造。
257    ///
258    /// 用于 MCP bridge 等需要直接控制执行函数的场景。
259    pub fn from_fn(
260        def: ToolDefinition,
261        safety: ParallelSafety,
262        category: Option<ToolCategory>,
263        f: ToolFn,
264    ) -> Self {
265        Self {
266            definition: def,
267            safety,
268            category,
269            executor: f,
270        }
271    }
272
273    // ─── 高层构造 — 原始 JSON 输入 ────────────────────────────
274
275    /// 并行安全(Safe)工具注册。
276    pub fn safe<F, Fut>(def: ToolDefinition, f: F) -> Self
277    where
278        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
279        Fut: Future<Output = ToolResult> + Send + 'static,
280    {
281        Self {
282            definition: def,
283            safety: ParallelSafety::Safe,
284            category: None,
285            executor: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
286        }
287    }
288
289    /// 分类内互斥(CategoryExclusive)工具注册。
290    pub fn category_exclusive<F, Fut>(def: ToolDefinition, category: ToolCategory, f: F) -> Self
291    where
292        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
293        Fut: Future<Output = ToolResult> + Send + 'static,
294    {
295        Self {
296            definition: def,
297            safety: ParallelSafety::CategoryExclusive,
298            category: Some(category),
299            executor: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
300        }
301    }
302
303    /// 全局互斥(Exclusive)工具注册。
304    pub fn exclusive<F, Fut>(def: ToolDefinition, f: F) -> Self
305    where
306        F: Fn(&serde_json::Value) -> Fut + Send + Sync + 'static,
307        Fut: Future<Output = ToolResult> + Send + 'static,
308    {
309        Self {
310            definition: def,
311            safety: ParallelSafety::Exclusive,
312            category: None,
313            executor: Arc::new(move |args: &serde_json::Value| Box::pin(f(args))),
314        }
315    }
316
317    // ─── 高层构造 — 强类型输入(自动反序列化) ─────────────────
318
319    /// 强类型便捷构造 — 自动反序列化参数(Safe)。
320    ///
321    /// 与 `safe()` 的区别:闭包接收反序列化后的 `T`,而非原始 `serde_json::Value`。
322    /// 反序列化失败时返回 `ToolErrorKind::InvalidInput`。
323    pub fn safe_fn<T, F, Fut>(def: ToolDefinition, f: F) -> Self
324    where
325        T: ToolArgParser + Send + 'static,
326        F: Fn(T) -> Fut + Send + Sync + 'static,
327        Fut: Future<Output = ToolResult> + Send + 'static,
328    {
329        let f = Arc::new(f);
330        Self::safe(def, move |value| {
331            let f = Arc::clone(&f);
332            let result = T::parse(value.clone());
333            async move {
334                match result {
335                    Ok(parsed) => f(parsed).await,
336                    Err(e) => Err(crate::ToolError::invalid_input(format!(
337                        "invalid tool arguments: {e}"
338                    ))),
339                }
340            }
341        })
342    }
343
344    /// 强类型便捷构造 — 自动反序列化参数(CategoryExclusive)。
345    pub fn category_exclusive_fn<T, F, Fut>(
346        def: ToolDefinition,
347        category: ToolCategory,
348        f: F,
349    ) -> Self
350    where
351        T: ToolArgParser + Send + 'static,
352        F: Fn(T) -> Fut + Send + Sync + 'static,
353        Fut: Future<Output = ToolResult> + Send + 'static,
354    {
355        let f = Arc::new(f);
356        Self::category_exclusive(def, category, move |value| {
357            let f = Arc::clone(&f);
358            let result = T::parse(value.clone());
359            async move {
360                match result {
361                    Ok(parsed) => f(parsed).await,
362                    Err(e) => Err(crate::ToolError::invalid_input(format!(
363                        "invalid tool arguments: {e}"
364                    ))),
365                }
366            }
367        })
368    }
369
370    /// 强类型便捷构造 — 自动反序列化参数(Exclusive)。
371    pub fn exclusive_fn<T, F, Fut>(def: ToolDefinition, f: F) -> Self
372    where
373        T: ToolArgParser + Send + 'static,
374        F: Fn(T) -> Fut + Send + Sync + 'static,
375        Fut: Future<Output = ToolResult> + Send + 'static,
376    {
377        let f = Arc::new(f);
378        Self::exclusive(def, move |value| {
379            let f = Arc::clone(&f);
380            let result = T::parse(value.clone());
381            async move {
382                match result {
383                    Ok(parsed) => f(parsed).await,
384                    Err(e) => Err(crate::ToolError::invalid_input(format!(
385                        "invalid tool arguments: {e}"
386                    ))),
387                }
388            }
389        })
390    }
391}
392
393/// 向后兼容别名 — `ToolRegistration` 已重命名为 `ExecutableTool`。
394#[deprecated(since = "0.5.0", note = "Use `ExecutableTool` instead")]
395pub type ToolRegistration = ExecutableTool;