Skip to main content

katu_core/
tool.rs

1//! # katu_core::tool
2//!
3//! ## 职责
4//! 定义工具系统的数据类型与执行契约。
5//!
6//! ## 设计原则
7//! - **Provider 无关** — 参数使用 JSON Schema(`serde_json::Value`),适配所有 LLM provider
8//! - **Serde 友好** — 数据类型可序列化/反序列化
9//! - **最小 trait** — `Tool` trait 只含必要方法,扩展能力留给上层
10//!
11//! ## 对外接口
12//! - `ToolDefinition` — 发送给 LLM 的工具 schema(name + description + parameters JSON Schema)
13//! - `ToolOutput` — 工具执行结果(content + metadata + is_error)
14//! - `ToolChoice` — 工具选择策略(auto / none / required / specific)
15//! - `Tool` — 工具执行 trait(definition + validate + execute + concurrency_mode)
16//! - `ToolCallContext` — 执行上下文(call_id + cancellation + extra)
17//! - `CancellationToken` — 取消令牌(re-export `tokio_util::sync::CancellationToken`)
18//! - `ConcurrencyMode` — 并发调度标记
19//!
20//! ## 调用者
21//! - `katu-llm` — `LlmRequest` 持有 `Vec<ToolDefinition>` + `ToolChoice`
22//! - `katu-agent` (future) — Agent loop 通过 `Tool` trait 调用工具
23//! - `katu-core::event` — `StreamEvent::ToolResult` 可从 `ToolOutput` 构造
24
25use async_trait::async_trait;
26use serde::{Deserialize, Serialize};
27
28use crate::error::Result;
29use crate::types::ToolCallId;
30
31// ===========================================================================
32// ToolDefinition
33// ===========================================================================
34
35/// 工具定义 — 发送给 LLM 的 schema。
36///
37/// 对应 LLM API 中 `tools` 数组的每个元素,包含名称、描述和参数 JSON Schema。
38/// LLM 据此决定何时调用工具以及如何构造参数。
39///
40/// # Examples
41///
42/// ```
43/// use katu_core::ToolDefinition;
44/// use serde_json::json;
45///
46/// let tool = ToolDefinition::new(
47///     "read_file",
48///     "Read the contents of a file at the given path",
49///     json!({
50///         "type": "object",
51///         "properties": {
52///             "path": { "type": "string", "description": "File path" }
53///         },
54///         "required": ["path"]
55///     }),
56/// );
57/// assert_eq!(tool.name, "read_file");
58/// ```
59#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
60pub struct ToolDefinition {
61    /// 工具名称 — LLM 在 tool_call 中引用的唯一标识。
62    ///
63    /// 命名约定:`snake_case`,如 `"read_file"`, `"bash"`, `"web_search"`。
64    pub name: String,
65
66    /// 工具描述 — LLM 据此决定何时以及为什么调用此工具。
67    ///
68    /// 应该清晰描述工具的功能、适用场景和限制。
69    pub description: String,
70
71    /// 参数的 JSON Schema — 定义工具接受的输入格式。
72    ///
73    /// 必须是一个 `{"type": "object", "properties": {...}}` 形式的 JSON Schema。
74    /// LLM 根据此 schema 构造 `tool_call.arguments`。
75    pub parameters: serde_json::Value,
76}
77
78impl ToolDefinition {
79    /// 创建新的工具定义。
80    pub fn new(
81        name: impl Into<String>,
82        description: impl Into<String>,
83        parameters: serde_json::Value,
84    ) -> Self {
85        Self {
86            name: name.into(),
87            description: description.into(),
88            parameters,
89        }
90    }
91
92    /// 创建无参数的工具定义。
93    ///
94    /// 等价于 `parameters: {"type": "object", "properties": {}}`。
95    pub fn no_params(name: impl Into<String>, description: impl Into<String>) -> Self {
96        Self {
97            name: name.into(),
98            description: description.into(),
99            parameters: serde_json::json!({
100                "type": "object",
101                "properties": {}
102            }),
103        }
104    }
105}
106
107// ===========================================================================
108// ToolOutput
109// ===========================================================================
110
111/// 工具执行结果 — 返回给 agent loop 的输出。
112///
113/// 设计参考 OpenCode 的 `ExecuteResult`:
114/// - `content` 总是 string(LLM 只理解文本)
115/// - `metadata` 为结构化数据(UI/日志/遥测用,不发送给 LLM)
116/// - `is_error` 标记非抛异常的失败(如工具内部捕获的错误)
117///
118/// # Examples
119///
120/// ```
121/// use katu_core::ToolOutput;
122/// use serde_json::json;
123///
124/// // 成功结果
125/// let output = ToolOutput::success("File contents here");
126/// assert!(!output.is_error);
127///
128/// // 错误结果
129/// let output = ToolOutput::error("Permission denied: /etc/shadow");
130/// assert!(output.is_error);
131/// ```
132#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
133pub struct ToolOutput {
134    /// 标题 — UI 显示用的简短描述。
135    ///
136    /// 例如 `"Read file: src/main.rs"`, `"Bash: ls -la"`。
137    #[serde(default)]
138    pub title: String,
139
140    /// 主输出内容 — 发送回 LLM 的文本。
141    ///
142    /// 这是 LLM 在下一轮推理中看到的 tool result 内容。
143    pub content: String,
144
145    /// 结构化元数据 — UI/日志/遥测用,**不**发送给 LLM。
146    ///
147    /// 例如执行耗时、文件路径、diff 统计等。
148    #[serde(default = "default_metadata")]
149    pub metadata: serde_json::Value,
150
151    /// 是否为错误结果。
152    ///
153    /// `true` 时 agent loop 将 `content` 作为错误信息反馈给 LLM,
154    /// LLM 可据此修正策略。区别于 `Err(...)` 的不可恢复错误,
155    /// `is_error = true` 表示工具执行完成但结果是失败的。
156    #[serde(default)]
157    pub is_error: bool,
158}
159
160fn default_metadata() -> serde_json::Value {
161    serde_json::Value::Object(serde_json::Map::new())
162}
163
164impl ToolOutput {
165    /// 创建成功的工具输出。
166    pub fn success(content: impl Into<String>) -> Self {
167        Self {
168            title: String::new(),
169            content: content.into(),
170            metadata: default_metadata(),
171            is_error: false,
172        }
173    }
174
175    /// 创建带标题的成功工具输出。
176    pub fn success_with_title(title: impl Into<String>, content: impl Into<String>) -> Self {
177        Self {
178            title: title.into(),
179            content: content.into(),
180            metadata: default_metadata(),
181            is_error: false,
182        }
183    }
184
185    /// 创建错误的工具输出。
186    ///
187    /// 不同于 `Err(...)` — 这里工具执行完成了,但结果是失败的。
188    /// LLM 会看到错误信息并可以据此调整策略。
189    pub fn error(content: impl Into<String>) -> Self {
190        Self {
191            title: String::new(),
192            content: content.into(),
193            metadata: default_metadata(),
194            is_error: true,
195        }
196    }
197
198    /// 设置元数据(builder 模式)。
199    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
200        self.metadata = metadata;
201        self
202    }
203
204    /// 设置标题(builder 模式)。
205    pub fn with_title(mut self, title: impl Into<String>) -> Self {
206        self.title = title.into();
207        self
208    }
209}
210
211// ===========================================================================
212// ToolChoice
213// ===========================================================================
214
215/// 工具选择策略 — 控制 LLM 是否以及如何使用工具。
216///
217/// 发送给 LLM 的 `tool_choice` 参数,不同 provider 有不同的映射方式,
218/// 由 provider adapter 负责转换。
219///
220/// # Examples
221///
222/// ```
223/// use katu_core::ToolChoice;
224///
225/// let choice = ToolChoice::Auto;
226/// assert!(choice.allows_tools());
227///
228/// let choice = ToolChoice::None;
229/// assert!(!choice.allows_tools());
230///
231/// let choice = ToolChoice::specific("bash");
232/// assert!(choice.allows_tools());
233/// assert_eq!(choice.required_tool(), Some("bash"));
234/// ```
235#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
236#[serde(tag = "type", rename_all = "snake_case")]
237pub enum ToolChoice {
238    /// 模型自行决定是否使用工具。
239    #[default]
240    Auto,
241
242    /// 禁止使用工具 — 模型只生成文本。
243    None,
244
245    /// 必须使用工具 — 模型必须至少调用一个工具。
246    Required,
247
248    /// 强制使用指定工具。
249    Specific {
250        /// 必须调用的工具名称。
251        name: String,
252    },
253}
254
255impl ToolChoice {
256    /// 创建强制使用指定工具的选择策略。
257    pub fn specific(name: impl Into<String>) -> Self {
258        Self::Specific { name: name.into() }
259    }
260
261    /// 是否允许工具调用(`None` 时不允许)。
262    pub fn allows_tools(&self) -> bool {
263        !matches!(self, Self::None)
264    }
265
266    /// 如果是 `Specific`,返回指定的工具名称。
267    pub fn required_tool(&self) -> Option<&str> {
268        match self {
269            Self::Specific { name } => Some(name.as_str()),
270            _ => Option::None,
271        }
272    }
273
274    /// 是否强制使用工具(`Required` 或 `Specific`)。
275    pub fn is_forced(&self) -> bool {
276        matches!(self, Self::Required | Self::Specific { .. })
277    }
278}
279
280
281// ===========================================================================
282// CancellationToken (re-export)
283// ===========================================================================
284
285/// 取消令牌 — re-export `tokio_util::sync::CancellationToken`。
286///
287/// Agent loop 在需要取消工具时调用 `cancel()`,
288/// 工具通过 `.cancelled().await` 异步等待,或 `is_cancelled()` 同步轮询。
289///
290/// ## 核心 API
291/// - `CancellationToken::new()` — 创建
292/// - `.cancel()` — 触发取消
293/// - `.is_cancelled()` — 同步检查
294/// - `.cancelled().await` — 异步等待
295/// - `.child_token()` — 层级取消(Agent → Runner → Tool)
296///
297/// # Examples
298///
299/// ```
300/// use katu_core::CancellationToken;
301///
302/// let token = CancellationToken::new();
303/// let child = token.child_token();
304///
305/// assert!(!token.is_cancelled());
306/// token.cancel();
307/// assert!(child.is_cancelled());
308/// ```
309pub use tokio_util::sync::CancellationToken;
310
311// ===========================================================================
312// ConcurrencyMode
313// ===========================================================================
314
315/// 工具并发模式 — Agent loop 调度同批次 tool_call 时的行为。
316///
317/// 当 LLM 在一次响应中请求多个 tool_call 时:
318/// - 所有 `Shared` 工具可以并行执行
319/// - 遇到 `Exclusive` 工具时,等待前面的工具完成后独占执行
320///
321/// # Examples
322///
323/// ```
324/// use katu_core::ConcurrencyMode;
325///
326/// let mode = ConcurrencyMode::default();
327/// assert_eq!(mode, ConcurrencyMode::Shared);
328/// ```
329#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
330#[serde(rename_all = "snake_case")]
331pub enum ConcurrencyMode {
332    /// 可与其他 Shared 工具并行执行(如 read_file, grep)。
333    #[default]
334    Shared,
335    /// 独占执行,不与其他工具并行(如 write_file, bash)。
336    Exclusive,
337}
338
339// ===========================================================================
340// ToolCallContext
341// ===========================================================================
342
343/// 工具执行上下文 — Agent loop 构造后传入 `Tool::execute`。
344///
345/// ## Builder 模式
346/// 必填字段通过 `new(call_id)` 提供,可选字段通过 `with_*` 链式设置。
347///
348/// # Examples
349///
350/// ```
351/// use katu_core::{ToolCallContext, ToolCallId, CancellationToken};
352/// use serde_json::json;
353///
354/// let token = CancellationToken::new();
355/// let ctx = ToolCallContext::new(ToolCallId::new("call_1"))
356///     .with_cancellation(token.clone())
357///     .with_extra(json!({"cwd": "/home/user/project"}));
358///
359/// assert_eq!(ctx.call_id.as_str(), "call_1");
360/// assert!(!ctx.cancellation.is_cancelled());
361/// ```
362pub struct ToolCallContext {
363    /// 本次 tool_call 的唯一标识(由 LLM 或 agent loop 分配)。
364    pub call_id: ToolCallId,
365
366    /// 取消令牌 — 工具在长循环中检查是否需要提前退出。
367    pub cancellation: CancellationToken,
368
369    /// 扩展数据 — 上层应用注入的额外上下文。
370    ///
371    /// 如当前工作目录、环境变量、session 信息等。
372    /// 基座不预设上层需求,工具按 key 取值。
373    pub extra: serde_json::Value,
374}
375
376impl ToolCallContext {
377    /// 创建上下文 — 只需 call_id,其余使用默认值。
378    pub fn new(call_id: ToolCallId) -> Self {
379        Self {
380            call_id,
381            cancellation: CancellationToken::new(),
382            extra: serde_json::Value::Object(serde_json::Map::new()),
383        }
384    }
385
386    /// 注入已有的取消令牌(Agent loop 持有另一端用于触发取消)。
387    pub fn with_cancellation(mut self, token: CancellationToken) -> Self {
388        self.cancellation = token;
389        self
390    }
391
392    /// 注入扩展数据。
393    pub fn with_extra(mut self, extra: serde_json::Value) -> Self {
394        self.extra = extra;
395        self
396    }
397}
398
399// ===========================================================================
400// Tool trait
401// ===========================================================================
402
403/// 工具执行 trait — 所有可被 Agent 调用的工具必须实现此 trait。
404///
405/// ## 方法职责
406/// - `definition()` — "我是谁"(schema 信息,注册给 LLM)
407/// - `validate()` — "参数合法吗"(可选的补充校验,默认通过)
408/// - `execute()` — "执行动作"(核心业务逻辑)
409/// - `concurrency_mode()` — "我的调度约束"(给 agent loop 的调度提示)
410/// - `permission_request()` — "细粒度权限"(动态构建权限请求,覆盖默认逻辑)
411///
412/// ## 返回值约定
413/// - `Ok(ToolOutput { is_error: false })` — 工具成功
414/// - `Ok(ToolOutput { is_error: true })` — 业务失败(如"文件不存在"),
415///   agent loop 回传 LLM 让模型调整
416/// - `Err(Error::Cancelled)` — 被取消
417/// - `Err(Error::Internal(..))` — 工具崩溃,agent loop 决定重试或终止
418///
419/// ## Object Safety
420/// 通过 `#[async_trait]` 实现 dyn dispatch,支持 `Arc<dyn Tool>` 在 registry 中存储。
421///
422/// # Examples
423///
424/// ```
425/// use async_trait::async_trait;
426/// use katu_core::{Tool, ToolDefinition, ToolOutput, ToolCallContext, ConcurrencyMode, Result};
427/// use serde_json::{json, Value};
428///
429/// struct GetTimeTool;
430///
431/// static GET_TIME_DEF: std::sync::LazyLock<ToolDefinition> = std::sync::LazyLock::new(|| {
432///     ToolDefinition::no_params("get_time", "Get current UTC time")
433/// });
434///
435/// #[async_trait]
436/// impl Tool for GetTimeTool {
437///     fn definition(&self) -> &ToolDefinition {
438///         &GET_TIME_DEF
439///     }
440///
441///     async fn execute(&self, _args: Value, _ctx: &ToolCallContext) -> Result<ToolOutput> {
442///         Ok(ToolOutput::success("2025-01-01T00:00:00Z"))
443///     }
444/// }
445/// ```
446#[async_trait]
447pub trait Tool: Send + Sync {
448    /// 返回工具定义 — 名称、描述、参数 JSON Schema。
449    fn definition(&self) -> &ToolDefinition;
450
451    /// 参数补充验证 — 在 execute 前调用。
452    ///
453    /// 默认实现返回 `Ok(())` — 信任 LLM 已按 JSON Schema 生成参数。
454    /// 需要额外校验的工具(如路径安全检查)可覆盖此方法。
455    ///
456    /// 返回 `Err` 时,agent loop 将错误信息作为 tool result 返回给 LLM,
457    /// 不会调用 `execute`。
458    async fn validate(&self, _args: &serde_json::Value, _ctx: &ToolCallContext) -> Result<()> {
459        Ok(())
460    }
461
462    /// 执行工具 — 核心业务逻辑。
463    ///
464    /// ## 取消约定
465    /// 长运行工具应周期性检查 `ctx.cancellation.is_cancelled()`,
466    /// 检测到取消后返回 `Err(Error::Cancelled)` 并清理资源。
467    async fn execute(&self, args: serde_json::Value, ctx: &ToolCallContext) -> Result<ToolOutput>;
468
469    /// 并发模式 — 告知 agent loop 此工具的调度约束。
470    ///
471    /// 默认 `Shared` — 可与其他工具并行。
472    /// 写入类工具应返回 `Exclusive`。
473    fn concurrency_mode(&self) -> ConcurrencyMode {
474        ConcurrencyMode::Shared
475    }
476
477    /// 权限 key — 权限规则匹配时使用的标识。
478    ///
479    /// 默认返回工具名称。某些工具可能细分权限:
480    /// 如 bash 工具根据子命令前缀返回 `"bash"` 或 `"bash:git"`。
481    fn permission_key(&self) -> &str {
482        &self.definition().name
483    }
484
485    /// 工具级权限检查 — 在规则引擎求值后、用户交互前调用。
486    ///
487    /// ## 用途
488    /// 工具实现可据此检查:
489    /// - 路径安全性(如禁止写入 `.git/`、`.katu/` 目录)
490    /// - 命令安全性(如禁止 `rm -rf /`)
491    /// - URL 白名单等
492    ///
493    /// ## 返回值
494    /// - `Passthrough` — 不做判断,交由规则引擎(**默认**)
495    /// - `Allow` — 工具认为此操作安全
496    /// - `Deny { message }` — 工具明确拒绝
497    /// - `Ask { message }` — 工具建议询问用户
498    ///
499    /// ## 与 validate 的区别
500    /// - `validate()` = 参数**格式**是否合法(类型检查)
501    /// - `check_permissions()` = 操作**是否被允许**(授权检查)
502    ///
503    /// ## 调用顺序
504    /// ```text
505    /// Hook(PreToolUse) → check_permissions() → 规则引擎 → 用户交互 → validate() → execute()
506    /// ```
507    fn check_permissions(
508        &self,
509        _args: &serde_json::Value,
510        _ctx: &ToolCallContext,
511    ) -> crate::permission::PermissionResult {
512        crate::permission::PermissionResult::Passthrough
513    }
514
515    /// 构建细粒度权限请求 — 工具可据此定制 permission key 和 pattern。
516    ///
517    /// 默认返回 `None` — 框架使用 `permission_key()` + `args.to_string()` 的默认逻辑。
518    ///
519    /// ## 用途
520    /// 需要细粒度权限控制的工具(如 bash)可覆盖此方法,提供:
521    /// - 更精确的 permission key(如 `"bash:git"` 而非 `"bash"`)
522    /// - 有意义的 pattern(如 `"git push origin main"` 而非序列化后的 JSON)
523    /// - always-allow 模式(如 `"git push *"`)
524    /// - UI 展示用的元数据
525    ///
526    /// ## 与 permission_key() 的关系
527    /// - `permission_key()` 返回固定的 `&str`,适合简单工具
528    /// - `permission_request()` 返回动态构造的 `PermissionRequest`,适合需要
529    ///   根据参数内容变化 key 和 pattern 的复杂工具
530    /// - 如果两者都实现,`permission_request()` 优先
531    ///
532    /// ## 调用时机
533    /// 在 `check_permissions()` 返回 `Passthrough` 后、Ruleset 求值前调用。
534    ///
535    /// # Examples
536    ///
537    /// ```ignore
538    /// fn permission_request(
539    ///     &self,
540    ///     args: &serde_json::Value,
541    ///     _ctx: &ToolCallContext,
542    /// ) -> Option<crate::permission::PermissionRequest> {
543    ///     let command = args["command"].as_str()?;
544    ///     Some(crate::permission::PermissionRequest::new("bash:git", command)
545    ///         .with_tool_name("bash")
546    ///         .with_always_allow(vec!["git push *"]))
547    /// }
548    /// ```
549    fn permission_request(
550        &self,
551        _args: &serde_json::Value,
552        _ctx: &ToolCallContext,
553    ) -> Option<crate::permission::PermissionRequest> {
554        None
555    }
556}
557
558// ===========================================================================
559// Tests
560// ===========================================================================
561
562#[cfg(test)]
563mod tests {
564    use super::*;
565    use std::sync::Arc;
566    use serde_json::json;
567
568    // -- ToolDefinition --
569
570    #[test]
571    fn test_tool_definition_new() {
572        let def = ToolDefinition::new(
573            "read_file",
574            "Read file contents",
575            json!({
576                "type": "object",
577                "properties": {
578                    "path": { "type": "string" }
579                },
580                "required": ["path"]
581            }),
582        );
583        assert_eq!(def.name, "read_file");
584        assert_eq!(def.description, "Read file contents");
585        assert!(def.parameters["properties"]["path"]["type"]
586            .as_str()
587            .unwrap()
588            == "string");
589    }
590
591    #[test]
592    fn test_tool_definition_no_params() {
593        let def = ToolDefinition::no_params("get_time", "Get current time");
594        assert_eq!(def.name, "get_time");
595        assert_eq!(def.parameters["type"], "object");
596        assert!(def.parameters["properties"]
597            .as_object()
598            .unwrap()
599            .is_empty());
600    }
601
602    #[test]
603    fn test_tool_definition_serde_roundtrip() {
604        let def = ToolDefinition::new(
605            "bash",
606            "Run a shell command",
607            json!({
608                "type": "object",
609                "properties": {
610                    "command": { "type": "string" }
611                },
612                "required": ["command"]
613            }),
614        );
615        let json_str = serde_json::to_string(&def).unwrap();
616        let restored: ToolDefinition = serde_json::from_str(&json_str).unwrap();
617        assert_eq!(def, restored);
618    }
619
620    // -- ToolOutput --
621
622    #[test]
623    fn test_tool_output_success() {
624        let out = ToolOutput::success("hello world");
625        assert_eq!(out.content, "hello world");
626        assert!(!out.is_error);
627        assert!(out.title.is_empty());
628    }
629
630    #[test]
631    fn test_tool_output_success_with_title() {
632        let out = ToolOutput::success_with_title("Read file", "file contents");
633        assert_eq!(out.title, "Read file");
634        assert_eq!(out.content, "file contents");
635        assert!(!out.is_error);
636    }
637
638    #[test]
639    fn test_tool_output_error() {
640        let out = ToolOutput::error("not found");
641        assert_eq!(out.content, "not found");
642        assert!(out.is_error);
643    }
644
645    #[test]
646    fn test_tool_output_builder() {
647        let out = ToolOutput::success("ok")
648            .with_title("Done")
649            .with_metadata(json!({"elapsed_ms": 42}));
650        assert_eq!(out.title, "Done");
651        assert_eq!(out.metadata["elapsed_ms"], 42);
652        assert!(!out.is_error);
653    }
654
655    #[test]
656    fn test_tool_output_serde_roundtrip() {
657        let out = ToolOutput::success_with_title("Read", "contents")
658            .with_metadata(json!({"lines": 100}));
659        let json_str = serde_json::to_string(&out).unwrap();
660        let restored: ToolOutput = serde_json::from_str(&json_str).unwrap();
661        assert_eq!(out, restored);
662    }
663
664    #[test]
665    fn test_tool_output_serde_defaults() {
666        // 反序列化时缺少可选字段应该使用默认值
667        let json_str = r#"{"content":"hello"}"#;
668        let out: ToolOutput = serde_json::from_str(json_str).unwrap();
669        assert_eq!(out.content, "hello");
670        assert!(!out.is_error);
671        assert!(out.title.is_empty());
672        assert!(out.metadata.is_object());
673    }
674
675    // -- ToolChoice --
676
677    #[test]
678    fn test_tool_choice_auto_default() {
679        let choice = ToolChoice::default();
680        assert_eq!(choice, ToolChoice::Auto);
681    }
682
683    #[test]
684    fn test_tool_choice_allows_tools() {
685        assert!(ToolChoice::Auto.allows_tools());
686        assert!(!ToolChoice::None.allows_tools());
687        assert!(ToolChoice::Required.allows_tools());
688        assert!(ToolChoice::specific("bash").allows_tools());
689    }
690
691    #[test]
692    fn test_tool_choice_required_tool() {
693        assert_eq!(ToolChoice::Auto.required_tool(), Option::None);
694        assert_eq!(ToolChoice::None.required_tool(), Option::None);
695        assert_eq!(ToolChoice::Required.required_tool(), Option::None);
696        assert_eq!(ToolChoice::specific("bash").required_tool(), Some("bash"));
697    }
698
699    #[test]
700    fn test_tool_choice_is_forced() {
701        assert!(!ToolChoice::Auto.is_forced());
702        assert!(!ToolChoice::None.is_forced());
703        assert!(ToolChoice::Required.is_forced());
704        assert!(ToolChoice::specific("bash").is_forced());
705    }
706
707    #[test]
708    fn test_tool_choice_serde_roundtrip() {
709        for choice in [
710            ToolChoice::Auto,
711            ToolChoice::None,
712            ToolChoice::Required,
713            ToolChoice::specific("read_file"),
714        ] {
715            let json_str = serde_json::to_string(&choice).unwrap();
716            let restored: ToolChoice = serde_json::from_str(&json_str).unwrap();
717            assert_eq!(choice, restored);
718        }
719    }
720
721    #[test]
722    fn test_tool_choice_serde_format() {
723        let json_str = serde_json::to_string(&ToolChoice::Auto).unwrap();
724        assert!(json_str.contains(r#""type":"auto""#));
725
726        let json_str = serde_json::to_string(&ToolChoice::specific("bash")).unwrap();
727        assert!(json_str.contains(r#""type":"specific""#));
728        assert!(json_str.contains(r#""name":"bash""#));
729    }
730
731    // -- CancellationToken --
732
733    #[test]
734    fn test_cancellation_token_new() {
735        let token = CancellationToken::new();
736        assert!(!token.is_cancelled());
737    }
738
739    #[test]
740    fn test_cancellation_token_cancel() {
741        let token = CancellationToken::new();
742        token.cancel();
743        assert!(token.is_cancelled());
744    }
745
746    #[test]
747    fn test_cancellation_token_clone_shares_state() {
748        let token = CancellationToken::new();
749        let token2 = token.clone();
750
751        assert!(!token.is_cancelled());
752        assert!(!token2.is_cancelled());
753
754        token.cancel();
755
756        assert!(token.is_cancelled());
757        assert!(token2.is_cancelled());
758    }
759
760    #[test]
761    fn test_cancellation_token_child() {
762        let parent = CancellationToken::new();
763        let child = parent.child_token();
764        assert!(!child.is_cancelled());
765        parent.cancel();
766        assert!(child.is_cancelled());
767    }
768
769    // -- ConcurrencyMode --
770
771    #[test]
772    fn test_concurrency_mode_default() {
773        assert_eq!(ConcurrencyMode::default(), ConcurrencyMode::Shared);
774    }
775
776    #[test]
777    fn test_concurrency_mode_serde_roundtrip() {
778        for mode in [ConcurrencyMode::Shared, ConcurrencyMode::Exclusive] {
779            let json_str = serde_json::to_string(&mode).unwrap();
780            let restored: ConcurrencyMode = serde_json::from_str(&json_str).unwrap();
781            assert_eq!(mode, restored);
782        }
783    }
784
785    #[test]
786    fn test_concurrency_mode_serde_format() {
787        assert_eq!(
788            serde_json::to_string(&ConcurrencyMode::Shared).unwrap(),
789            r#""shared""#
790        );
791        assert_eq!(
792            serde_json::to_string(&ConcurrencyMode::Exclusive).unwrap(),
793            r#""exclusive""#
794        );
795    }
796
797    // -- ToolCallContext --
798
799    #[test]
800    fn test_tool_call_context_new() {
801        let ctx = ToolCallContext::new(ToolCallId::new("call_1"));
802        assert_eq!(ctx.call_id.as_str(), "call_1");
803        assert!(!ctx.cancellation.is_cancelled());
804        assert!(ctx.extra.is_object());
805    }
806
807    #[test]
808    fn test_tool_call_context_builder() {
809        let token = CancellationToken::new();
810        let ctx = ToolCallContext::new(ToolCallId::new("call_2"))
811            .with_cancellation(token.clone())
812            .with_extra(json!({"cwd": "/tmp", "env": {"DEBUG": "1"}}));
813
814        assert_eq!(ctx.call_id.as_str(), "call_2");
815        assert_eq!(ctx.extra["cwd"], "/tmp");
816        assert_eq!(ctx.extra["env"]["DEBUG"], "1");
817
818        // 共享 token
819        token.cancel();
820        assert!(ctx.cancellation.is_cancelled());
821    }
822
823    // -- Tool trait --
824
825    struct EchoTool;
826
827    static ECHO_DEF: std::sync::LazyLock<ToolDefinition> = std::sync::LazyLock::new(|| {
828        ToolDefinition::new(
829            "echo",
830            "Echoes the input message",
831            json!({
832                "type": "object",
833                "properties": {
834                    "message": { "type": "string" }
835                },
836                "required": ["message"]
837            }),
838        )
839    });
840
841    #[async_trait]
842    impl Tool for EchoTool {
843        fn definition(&self) -> &ToolDefinition {
844            &ECHO_DEF
845        }
846
847        async fn execute(
848            &self,
849            args: serde_json::Value,
850            _ctx: &ToolCallContext,
851        ) -> Result<ToolOutput> {
852            let message = args["message"]
853                .as_str()
854                .unwrap_or("(no message)");
855            Ok(ToolOutput::success(message))
856        }
857    }
858
859    struct ExclusiveTool;
860
861    static EXCLUSIVE_DEF: std::sync::LazyLock<ToolDefinition> = std::sync::LazyLock::new(|| {
862        ToolDefinition::new(
863            "write_file",
864            "Write content to a file",
865            json!({
866                "type": "object",
867                "properties": {
868                    "path": { "type": "string" },
869                    "content": { "type": "string" }
870                },
871                "required": ["path", "content"]
872            }),
873        )
874    });
875
876    #[async_trait]
877    impl Tool for ExclusiveTool {
878        fn definition(&self) -> &ToolDefinition {
879            &EXCLUSIVE_DEF
880        }
881
882        async fn validate(
883            &self,
884            args: &serde_json::Value,
885            _ctx: &ToolCallContext,
886        ) -> Result<()> {
887            let path = args["path"].as_str().unwrap_or("");
888            if path.starts_with("/etc/") {
889                return Err(crate::Error::tool(
890                    "write_file",
891                    _ctx.call_id.clone(),
892                    "cannot write to /etc/",
893                ));
894            }
895            Ok(())
896        }
897
898        async fn execute(
899            &self,
900            args: serde_json::Value,
901            ctx: &ToolCallContext,
902        ) -> Result<ToolOutput> {
903            if ctx.cancellation.is_cancelled() {
904                return Err(crate::Error::Cancelled);
905            }
906            let path = args["path"].as_str().unwrap_or("?");
907            Ok(ToolOutput::success(format!("wrote to {path}"))
908                .with_title(format!("Write: {path}")))
909        }
910
911        fn concurrency_mode(&self) -> ConcurrencyMode {
912            ConcurrencyMode::Exclusive
913        }
914    }
915
916    #[tokio::test]
917    async fn test_tool_echo_execute() {
918        let tool = EchoTool;
919        let ctx = ToolCallContext::new(ToolCallId::new("c1"));
920        let result = tool
921            .execute(json!({"message": "hello"}), &ctx)
922            .await
923            .unwrap();
924        assert_eq!(result.content, "hello");
925        assert!(!result.is_error);
926    }
927
928    #[tokio::test]
929    async fn test_tool_echo_default_concurrency() {
930        let tool = EchoTool;
931        assert_eq!(tool.concurrency_mode(), ConcurrencyMode::Shared);
932    }
933
934    #[tokio::test]
935    async fn test_tool_echo_default_validate() {
936        let tool = EchoTool;
937        let ctx = ToolCallContext::new(ToolCallId::new("c1"));
938        // 默认 validate 应通过
939        assert!(tool.validate(&json!({}), &ctx).await.is_ok());
940    }
941
942    #[tokio::test]
943    async fn test_tool_definition_matches() {
944        let tool = EchoTool;
945        assert_eq!(tool.definition().name, "echo");
946        assert!(!tool.definition().description.is_empty());
947    }
948
949    #[tokio::test]
950    async fn test_tool_exclusive_concurrency() {
951        let tool = ExclusiveTool;
952        assert_eq!(tool.concurrency_mode(), ConcurrencyMode::Exclusive);
953    }
954
955    #[tokio::test]
956    async fn test_tool_validate_rejects_invalid() {
957        let tool = ExclusiveTool;
958        let ctx = ToolCallContext::new(ToolCallId::new("c2"));
959        let result = tool
960            .validate(&json!({"path": "/etc/shadow", "content": "x"}), &ctx)
961            .await;
962        assert!(result.is_err());
963    }
964
965    #[tokio::test]
966    async fn test_tool_validate_accepts_valid() {
967        let tool = ExclusiveTool;
968        let ctx = ToolCallContext::new(ToolCallId::new("c3"));
969        let result = tool
970            .validate(&json!({"path": "/tmp/test.txt", "content": "x"}), &ctx)
971            .await;
972        assert!(result.is_ok());
973    }
974
975    #[tokio::test]
976    async fn test_tool_execute_with_cancellation() {
977        let tool = ExclusiveTool;
978        let token = CancellationToken::new();
979        let ctx = ToolCallContext::new(ToolCallId::new("c4"))
980            .with_cancellation(token.clone());
981
982        // 未取消 → 成功
983        let result = tool
984            .execute(json!({"path": "/tmp/a.txt", "content": "hi"}), &ctx)
985            .await
986            .unwrap();
987        assert_eq!(result.content, "wrote to /tmp/a.txt");
988
989        // 取消后 → 错误
990        token.cancel();
991        let ctx2 = ToolCallContext::new(ToolCallId::new("c5"))
992            .with_cancellation(token.clone());
993        let result = tool
994            .execute(json!({"path": "/tmp/b.txt", "content": "hi"}), &ctx2)
995            .await;
996        assert!(matches!(result, Err(crate::Error::Cancelled)));
997    }
998
999    #[tokio::test]
1000    async fn test_tool_dyn_dispatch() {
1001        // 验证 Tool trait 支持 dyn dispatch(object safety)
1002        let tool: Arc<dyn Tool> = Arc::new(EchoTool);
1003        let ctx = ToolCallContext::new(ToolCallId::new("c6"));
1004        let result = tool
1005            .execute(json!({"message": "dynamic"}), &ctx)
1006            .await
1007            .unwrap();
1008        assert_eq!(result.content, "dynamic");
1009    }
1010}