Skip to main content

katu_core/
hook.rs

1//! # katu_core::hook
2//!
3//! ## 职责
4//! 定义 Hook 系统的类型与执行契约 — Agent 生命周期的可拦截节点。
5//!
6//! ## 设计
7//! Hook 系统与 `AgentEvent` 互补:
8//! - `AgentEvent`(agent_event.rs)= 不可变的**已发生事实**,消费者只能观察
9//! - `Hook`(本模块)= 可拦截的**执行节点**,Hook 可以拦截、修改、阻止
10//!
11//! ```text
12//! Agent Loop ──节点──► HookRegistry.run() ──决策──► 继续/修改/阻止
13//!                                                        │
14//!                                                        ▼
15//!                                              AgentEvent(记录事实)
16//! ```
17//!
18//! ## 对外接口
19//! - `HookEvent` — 可拦截的生命周期节点(10 种)
20//! - `HookInput` — 各事件的类型化输入
21//! - `HookOutput` — Hook 的决策结果
22//! - `HookPermission` — 权限决策(allow / deny / ask)
23//! - `Hook` — 执行 trait(async, object-safe)
24//! - `HookSource` — Hook 来源标识
25//! - `HookRegistry` — 注册与匹配
26//! - `AggregatedHookOutput` — 多 Hook 结果聚合
27//!
28//! ## 调用者
29//! - `katu-agent` (future) — Agent loop 在各生命周期节点调用 Hook
30//! - 应用层 — 注册自定义 Hook 实现
31
32use std::sync::Arc;
33
34use async_trait::async_trait;
35use serde::{Deserialize, Serialize};
36
37use crate::tool::ToolOutput;
38use crate::types::{SessionId, ToolCallId};
39
40// ===========================================================================
41// HookEvent
42// ===========================================================================
43
44/// 可拦截的 Agent 生命周期节点。
45///
46/// 与 `AgentEvent`(28 种只读观察事件)互补,`HookEvent` 只覆盖
47/// **需要干预能力**的节点 — 拦截、修改输入/输出、权限决策。
48///
49/// # Examples
50///
51/// ```
52/// use katu_core::hook::HookEvent;
53///
54/// let event = HookEvent::PreToolUse;
55/// assert!(event.is_tool_event());
56/// assert!(!HookEvent::SessionStart.is_tool_event());
57/// ```
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum HookEvent {
61    // ── Tool 生命周期 ─────────────────────────────────────
62
63    /// 工具执行前 — 可以 allow/deny/ask、修改 input。
64    PreToolUse,
65    /// 工具执行成功后 — 可以注入上下文、修改输出。
66    PostToolUse,
67    /// 工具执行失败后 — 可以注入诊断上下文。
68    PostToolFailure,
69
70    // ── 用户交互 ──────────────────────────────────────────
71
72    /// 用户提交 prompt 前 — 可以注入上下文或拦截。
73    UserPromptSubmit,
74
75    // ── Session 生命周期 ──────────────────────────────────
76
77    /// Session 开始。
78    SessionStart,
79    /// Session 结束。
80    SessionEnd,
81
82    // ── Agent 生命周期 ────────────────────────────────────
83
84    /// Agent loop 单步结束判定 — 可以阻止停止、要求继续。
85    Stop,
86    /// SubAgent 启动前。
87    SubAgentStart,
88
89    // ── Compaction ────────────────────────────────────────
90
91    /// 上下文压缩前。
92    PreCompact,
93    /// 上下文压缩后。
94    PostCompact,
95}
96
97/// 所有 Hook 事件的完整列表,按定义顺序。
98pub const ALL_HOOK_EVENTS: &[HookEvent] = &[
99    HookEvent::PreToolUse,
100    HookEvent::PostToolUse,
101    HookEvent::PostToolFailure,
102    HookEvent::UserPromptSubmit,
103    HookEvent::SessionStart,
104    HookEvent::SessionEnd,
105    HookEvent::Stop,
106    HookEvent::SubAgentStart,
107    HookEvent::PreCompact,
108    HookEvent::PostCompact,
109];
110
111impl HookEvent {
112    /// 是否为 Tool 相关事件。
113    pub fn is_tool_event(&self) -> bool {
114        matches!(
115            self,
116            Self::PreToolUse | Self::PostToolUse | Self::PostToolFailure
117        )
118    }
119}
120
121impl std::fmt::Display for HookEvent {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        let s = match self {
124            Self::PreToolUse => "pre_tool_use",
125            Self::PostToolUse => "post_tool_use",
126            Self::PostToolFailure => "post_tool_failure",
127            Self::UserPromptSubmit => "user_prompt_submit",
128            Self::SessionStart => "session_start",
129            Self::SessionEnd => "session_end",
130            Self::Stop => "stop",
131            Self::SubAgentStart => "sub_agent_start",
132            Self::PreCompact => "pre_compact",
133            Self::PostCompact => "post_compact",
134        };
135        f.write_str(s)
136    }
137}
138
139// ===========================================================================
140// HookInput
141// ===========================================================================
142
143/// Hook 输入 — 每个 HookEvent 携带的上下文数据。
144///
145/// 使用 enum 确保类型安全,Hook 实现方通过 match 获取精确类型。
146/// 所有变体均可序列化,支持跨进程 Hook(如 shell command hook)。
147///
148/// # Examples
149///
150/// ```
151/// use katu_core::hook::{HookEvent, HookInput};
152/// use katu_core::ToolCallId;
153/// use serde_json::json;
154///
155/// let input = HookInput::PreToolUse {
156///     tool_name: "bash".into(),
157///     tool_input: json!({"command": "ls -la"}),
158///     call_id: ToolCallId::new("call_1"),
159/// };
160/// assert_eq!(input.event(), HookEvent::PreToolUse);
161/// assert_eq!(input.tool_name(), Some("bash"));
162/// ```
163#[derive(Debug, Clone, Serialize, Deserialize)]
164#[serde(tag = "hook_event", rename_all = "snake_case")]
165pub enum HookInput {
166    /// 工具执行前。
167    PreToolUse {
168        tool_name: String,
169        tool_input: serde_json::Value,
170        call_id: ToolCallId,
171    },
172
173    /// 工具执行成功后。
174    PostToolUse {
175        tool_name: String,
176        tool_input: serde_json::Value,
177        tool_output: ToolOutput,
178        call_id: ToolCallId,
179    },
180
181    /// 工具执行失败后。
182    PostToolFailure {
183        tool_name: String,
184        tool_input: serde_json::Value,
185        error: String,
186        call_id: ToolCallId,
187    },
188
189    /// 用户提交 prompt 前。
190    UserPromptSubmit {
191        prompt: String,
192    },
193
194    /// Session 开始。
195    SessionStart {
196        session_id: SessionId,
197    },
198
199    /// Session 结束。
200    SessionEnd {
201        session_id: SessionId,
202        reason: String,
203    },
204
205    /// Agent loop 单步结束判定。
206    Stop {
207        finish_reason: String,
208    },
209
210    /// SubAgent 启动前。
211    SubAgentStart {
212        agent_name: String,
213    },
214
215    /// 上下文压缩前。
216    PreCompact {
217        trigger: String,
218        tokens_before: u64,
219    },
220
221    /// 上下文压缩后。
222    PostCompact {
223        trigger: String,
224        tokens_after: u64,
225    },
226}
227
228impl HookInput {
229    /// 返回此输入对应的事件类型。
230    pub fn event(&self) -> HookEvent {
231        match self {
232            Self::PreToolUse { .. } => HookEvent::PreToolUse,
233            Self::PostToolUse { .. } => HookEvent::PostToolUse,
234            Self::PostToolFailure { .. } => HookEvent::PostToolFailure,
235            Self::UserPromptSubmit { .. } => HookEvent::UserPromptSubmit,
236            Self::SessionStart { .. } => HookEvent::SessionStart,
237            Self::SessionEnd { .. } => HookEvent::SessionEnd,
238            Self::Stop { .. } => HookEvent::Stop,
239            Self::SubAgentStart { .. } => HookEvent::SubAgentStart,
240            Self::PreCompact { .. } => HookEvent::PreCompact,
241            Self::PostCompact { .. } => HookEvent::PostCompact,
242        }
243    }
244
245    /// 如果是 Tool 相关事件,返回工具名。
246    pub fn tool_name(&self) -> Option<&str> {
247        match self {
248            Self::PreToolUse { tool_name, .. }
249            | Self::PostToolUse { tool_name, .. }
250            | Self::PostToolFailure { tool_name, .. } => Some(tool_name.as_str()),
251            _ => None,
252        }
253    }
254
255    /// 如果是 Tool 相关事件,返回 call_id。
256    pub fn call_id(&self) -> Option<&ToolCallId> {
257        match self {
258            Self::PreToolUse { call_id, .. }
259            | Self::PostToolUse { call_id, .. }
260            | Self::PostToolFailure { call_id, .. } => Some(call_id),
261            _ => None,
262        }
263    }
264}
265
266// ===========================================================================
267// HookPermission
268// ===========================================================================
269
270/// Hook 的权限决策 — 仅对 `PreToolUse` 事件有意义。
271///
272/// 多个 Hook 的权限按严格度聚合:`Deny > Ask > Allow`。
273///
274/// # 重要
275/// Hook 返回 `Allow` **不能**绕过 settings 中的 deny 规则。
276/// Agent loop 应在 Hook 决策之后再检查规则级权限。
277///
278/// # Examples
279///
280/// ```
281/// use katu_core::hook::HookPermission;
282///
283/// let deny = HookPermission::Deny { reason: Some("unsafe command".into()) };
284/// let allow = HookPermission::Allow;
285/// assert!(deny.is_deny());
286/// assert!(!allow.is_deny());
287/// ```
288#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
289#[serde(tag = "type", rename_all = "snake_case")]
290pub enum HookPermission {
291    /// 允许执行(可被 settings deny 规则覆盖)。
292    Allow,
293
294    /// 拒绝执行。
295    Deny {
296        #[serde(default, skip_serializing_if = "Option::is_none")]
297        reason: Option<String>,
298    },
299
300    /// 需要用户确认。
301    Ask {
302        #[serde(default, skip_serializing_if = "Option::is_none")]
303        message: Option<String>,
304    },
305}
306
307impl HookPermission {
308    /// 创建无原因的 Deny。
309    pub fn deny() -> Self {
310        Self::Deny { reason: None }
311    }
312
313    /// 创建带原因的 Deny。
314    pub fn deny_with_reason(reason: impl Into<String>) -> Self {
315        Self::Deny {
316            reason: Some(reason.into()),
317        }
318    }
319
320    /// 创建无消息的 Ask。
321    pub fn ask() -> Self {
322        Self::Ask { message: None }
323    }
324
325    /// 创建带消息的 Ask。
326    pub fn ask_with_message(message: impl Into<String>) -> Self {
327        Self::Ask {
328            message: Some(message.into()),
329        }
330    }
331
332    pub fn is_allow(&self) -> bool {
333        matches!(self, Self::Allow)
334    }
335
336    pub fn is_deny(&self) -> bool {
337        matches!(self, Self::Deny { .. })
338    }
339
340    pub fn is_ask(&self) -> bool {
341        matches!(self, Self::Ask { .. })
342    }
343
344    /// 返回严格度数值 — 用于聚合时比较优先级。
345    ///
346    /// Deny(2) > Ask(1) > Allow(0)。
347    fn strictness(&self) -> u8 {
348        match self {
349            Self::Allow => 0,
350            Self::Ask { .. } => 1,
351            Self::Deny { .. } => 2,
352        }
353    }
354}
355
356// ===========================================================================
357// HookOutput
358// ===========================================================================
359
360/// Hook 的执行结果 — 告诉 Agent loop 如何继续。
361///
362/// ## 设计要点
363/// - `Default` = 无操作(passthrough),不影响正常流程
364/// - 多个字段独立,可同时设置(如 allow + 注入 context)
365/// - 权限聚合优先级由 `AggregatedHookOutput` 处理
366///
367/// # Examples
368///
369/// ```
370/// use katu_core::hook::HookOutput;
371///
372/// // passthrough — 什么都不做
373/// let out = HookOutput::passthrough();
374/// assert!(!out.has_decision());
375///
376/// // deny + context
377/// let out = HookOutput::deny("dangerous")
378///     .with_context("This command modifies system files");
379/// assert!(out.has_decision());
380/// ```
381#[derive(Debug, Clone, Default, Serialize, Deserialize)]
382pub struct HookOutput {
383    /// 权限决策 — 仅对 PreToolUse 有意义。
384    #[serde(default, skip_serializing_if = "Option::is_none")]
385    pub permission: Option<HookPermission>,
386
387    /// 修改后的工具输入 — PreToolUse 时替换原始 input。
388    #[serde(default, skip_serializing_if = "Option::is_none")]
389    pub updated_input: Option<serde_json::Value>,
390
391    /// 修改后的工具输出 — PostToolUse 时替换原始 output。
392    #[serde(default, skip_serializing_if = "Option::is_none")]
393    pub updated_output: Option<ToolOutput>,
394
395    /// 注入给 LLM 的额外上下文。
396    #[serde(default, skip_serializing_if = "Vec::is_empty")]
397    pub additional_context: Vec<String>,
398
399    /// 是否阻止 Agent 继续执行。
400    #[serde(default)]
401    pub prevent_continuation: bool,
402
403    /// 阻止原因(`prevent_continuation = true` 时展示给用户)。
404    #[serde(default, skip_serializing_if = "Option::is_none")]
405    pub stop_reason: Option<String>,
406
407    /// 阻塞性错误 — 反馈给 model 的错误消息。
408    #[serde(default, skip_serializing_if = "Option::is_none")]
409    pub blocking_error: Option<String>,
410
411    /// 系统消息 — 展示给用户的提示/警告。
412    #[serde(default, skip_serializing_if = "Option::is_none")]
413    pub system_message: Option<String>,
414}
415
416impl HookOutput {
417    /// 无操作 — 不影响正常流程。
418    pub fn passthrough() -> Self {
419        Self::default()
420    }
421
422    /// 允许执行。
423    pub fn allow() -> Self {
424        Self {
425            permission: Some(HookPermission::Allow),
426            ..Default::default()
427        }
428    }
429
430    /// 拒绝执行。
431    pub fn deny(reason: impl Into<String>) -> Self {
432        Self {
433            permission: Some(HookPermission::deny_with_reason(reason)),
434            ..Default::default()
435        }
436    }
437
438    /// 需要用户确认。
439    pub fn ask(message: impl Into<String>) -> Self {
440        Self {
441            permission: Some(HookPermission::ask_with_message(message)),
442            ..Default::default()
443        }
444    }
445
446    /// 设置修改后的工具输入(builder 模式)。
447    pub fn with_updated_input(mut self, input: serde_json::Value) -> Self {
448        self.updated_input = Some(input);
449        self
450    }
451
452    /// 设置修改后的工具输出(builder 模式)。
453    pub fn with_updated_output(mut self, output: ToolOutput) -> Self {
454        self.updated_output = Some(output);
455        self
456    }
457
458    /// 追加额外上下文(builder 模式)。
459    pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
460        self.additional_context.push(ctx.into());
461        self
462    }
463
464    /// 阻止 Agent 继续执行(builder 模式)。
465    pub fn with_stop(mut self, reason: impl Into<String>) -> Self {
466        self.prevent_continuation = true;
467        self.stop_reason = Some(reason.into());
468        self
469    }
470
471    /// 设置阻塞性错误(builder 模式)。
472    pub fn with_blocking_error(mut self, error: impl Into<String>) -> Self {
473        self.blocking_error = Some(error.into());
474        self
475    }
476
477    /// 设置系统消息(builder 模式)。
478    pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
479        self.system_message = Some(message.into());
480        self
481    }
482
483    /// 是否做出了实质性决策(非 passthrough)。
484    pub fn has_decision(&self) -> bool {
485        self.permission.is_some()
486            || self.updated_input.is_some()
487            || self.updated_output.is_some()
488            || !self.additional_context.is_empty()
489            || self.prevent_continuation
490            || self.blocking_error.is_some()
491    }
492}
493
494// ===========================================================================
495// Hook trait
496// ===========================================================================
497
498/// Hook 执行 trait — 所有可注册到 Agent loop 的 Hook 必须实现。
499///
500/// ## 设计选择
501/// - **`on_event`** — 处理所有事件类型,Hook 自行 match 感兴趣的事件
502/// - **`HookOutput`** — 默认 passthrough,显式 opt-in 干预
503/// - **async** — 支持异步操作(网络请求、LLM 查询等)
504/// - **`&self`** — 无状态偏好;需要状态的 Hook 使用内部可变性
505///
506/// ## Object Safety
507/// 通过 `#[async_trait]` 实现 dyn dispatch,支持 `Arc<dyn Hook>` 存储。
508///
509/// # Examples
510///
511/// ```
512/// use async_trait::async_trait;
513/// use katu_core::hook::*;
514///
515/// struct DangerousCommandBlocker;
516///
517/// #[async_trait]
518/// impl Hook for DangerousCommandBlocker {
519///     fn name(&self) -> &str { "dangerous_command_blocker" }
520///
521///     fn events(&self) -> &[HookEvent] {
522///         &[HookEvent::PreToolUse]
523///     }
524///
525///     fn matcher(&self) -> Option<&str> {
526///         Some("bash")
527///     }
528///
529///     async fn on_event(&self, input: &HookInput) -> HookOutput {
530///         if let HookInput::PreToolUse { tool_input, .. } = input {
531///             let cmd = tool_input["command"].as_str().unwrap_or("");
532///             if cmd.contains("rm -rf /") {
533///                 return HookOutput::deny("Refusing to delete root filesystem");
534///             }
535///         }
536///         HookOutput::passthrough()
537///     }
538/// }
539/// ```
540#[async_trait]
541pub trait Hook: Send + Sync {
542    /// Hook 名称 — 用于日志、诊断和去重。
543    fn name(&self) -> &str;
544
545    /// 声明此 Hook 关注的事件列表。
546    ///
547    /// `HookRegistry` 据此过滤,只对匹配的事件调用 `on_event`。
548    /// 返回空切片表示关注所有事件。
549    fn events(&self) -> &[HookEvent] {
550        &[]
551    }
552
553    /// Matcher 模式 — 进一步过滤匹配条件。
554    ///
555    /// 对 Tool 相关事件匹配 `tool_name`,支持:
556    /// - 精确匹配:`"bash"`
557    /// - 管道分隔多选:`"bash|write_file"`
558    /// - 通配符:`"read_*"`
559    ///
560    /// `None` 表示不过滤(匹配所有)。
561    fn matcher(&self) -> Option<&str> {
562        None
563    }
564
565    /// 执行 Hook 逻辑。
566    ///
567    /// 返回 `HookOutput::passthrough()` 表示不干预。
568    async fn on_event(&self, input: &HookInput) -> HookOutput;
569}
570
571// ===========================================================================
572// HookSource
573// ===========================================================================
574
575/// Hook 来源 — 用于冲突解决、日志追踪和安全策略。
576///
577/// # Examples
578///
579/// ```
580/// use katu_core::hook::HookSource;
581///
582/// let src = HookSource::Plugin { name: "linter".into() };
583/// assert!(matches!(src, HookSource::Plugin { .. }));
584/// ```
585#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
586#[serde(tag = "type", rename_all = "snake_case")]
587pub enum HookSource {
588    /// 用户全局配置文件。
589    Settings,
590    /// 项目配置文件。
591    Project,
592    /// Plugin 注册。
593    Plugin { name: String },
594    /// SDK / 程序化注册。
595    Programmatic,
596    /// Session 级别临时注册。
597    Session,
598}
599
600// ===========================================================================
601// HookRegistry
602// ===========================================================================
603
604/// Hook 注册中心 — 管理已注册的 Hook 并按事件/matcher 分发。
605///
606/// ## 生命周期
607/// - Agent 启动时构建(从配置 + 程序化注册)
608/// - Agent loop 在各生命周期节点调用匹配的 Hook
609/// - Session 结束时销毁
610///
611/// ## 线程安全
612/// 使用 `Arc<dyn Hook>` 存储,支持跨 await 共享。
613/// `HookRegistry` 本身在构建后作为 `Arc<HookRegistry>` 或引用传入 agent loop。
614///
615/// # Examples
616///
617/// ```
618/// use std::sync::Arc;
619/// use async_trait::async_trait;
620/// use katu_core::hook::*;
621///
622/// struct MyHook;
623///
624/// #[async_trait]
625/// impl Hook for MyHook {
626///     fn name(&self) -> &str { "my_hook" }
627///     fn events(&self) -> &[HookEvent] { &[HookEvent::PreToolUse] }
628///     async fn on_event(&self, _input: &HookInput) -> HookOutput {
629///         HookOutput::passthrough()
630///     }
631/// }
632///
633/// let mut registry = HookRegistry::new();
634/// registry.register(Arc::new(MyHook), HookSource::Programmatic, 0);
635/// assert_eq!(registry.len(), 1);
636/// ```
637pub struct HookRegistry {
638    hooks: Vec<RegisteredHook>,
639}
640
641/// 已注册的 Hook 条目 — 包含 Hook 实例、来源和优先级。
642pub struct RegisteredHook {
643    /// Hook 实例。
644    pub hook: Arc<dyn Hook>,
645    /// 来源标识。
646    pub source: HookSource,
647    /// 优先级(数值越小越先执行)。
648    pub priority: i32,
649}
650
651impl HookRegistry {
652    /// 创建空的注册中心。
653    pub fn new() -> Self {
654        Self { hooks: Vec::new() }
655    }
656
657    /// 注册一个 Hook。
658    ///
659    /// Hook 按 `priority` 升序排列(数值越小越先执行)。
660    pub fn register(
661        &mut self,
662        hook: Arc<dyn Hook>,
663        source: HookSource,
664        priority: i32,
665    ) {
666        self.hooks.push(RegisteredHook {
667            hook,
668            source,
669            priority,
670        });
671        self.hooks.sort_by_key(|h| h.priority);
672    }
673
674    /// 移除指定名称的所有 Hook。
675    pub fn remove(&mut self, name: &str) {
676        self.hooks.retain(|h| h.hook.name() != name);
677    }
678
679    /// 已注册的 Hook 数量。
680    pub fn len(&self) -> usize {
681        self.hooks.len()
682    }
683
684    /// 是否为空。
685    pub fn is_empty(&self) -> bool {
686        self.hooks.is_empty()
687    }
688
689    /// 获取匹配指定输入的所有 Hook(按 priority 排序)。
690    ///
691    /// 匹配逻辑:
692    /// 1. 事件匹配 — `hook.events()` 为空(匹配所有)或包含 `input.event()`
693    /// 2. Matcher 匹配 — `hook.matcher()` 为 None(匹配所有)或模式匹配 tool_name
694    pub fn matching(&self, input: &HookInput) -> Vec<&RegisteredHook> {
695        let event = input.event();
696        let tool_name = input.tool_name();
697
698        self.hooks
699            .iter()
700            .filter(|h| {
701                let events = h.hook.events();
702                let event_match = events.is_empty() || events.contains(&event);
703                if !event_match {
704                    return false;
705                }
706
707                match (h.hook.matcher(), tool_name) {
708                    (Some(pattern), Some(name)) => matches_pattern(name, pattern),
709                    (Some(_), None) => false,
710                    (None, _) => true,
711                }
712            })
713            .collect()
714    }
715
716    /// 检查是否有任何 Hook 注册在指定事件上。
717    ///
718    /// 轻量级检查,不做 matcher 匹配,用于热路径的快速跳过。
719    pub fn has_hooks_for(&self, event: HookEvent) -> bool {
720        self.hooks.iter().any(|h| {
721            let events = h.hook.events();
722            events.is_empty() || events.contains(&event)
723        })
724    }
725}
726
727impl Default for HookRegistry {
728    fn default() -> Self {
729        Self::new()
730    }
731}
732
733// ===========================================================================
734// AggregatedHookOutput
735// ===========================================================================
736
737/// 多个 Hook 结果的聚合 — 并行执行后合并。
738///
739/// ## 聚合规则
740/// - **权限**:`Deny > Ask > Allow`(最严格的获胜)
741/// - **updated_input / updated_output**:最后一个设置者获胜
742/// - **additional_context / blocking_errors / system_messages**:合并
743/// - **prevent_continuation**:任一 Hook 设置即生效
744///
745/// # Examples
746///
747/// ```
748/// use katu_core::hook::{AggregatedHookOutput, HookOutput, HookPermission};
749///
750/// let mut agg = AggregatedHookOutput::default();
751///
752/// // Hook A: allow
753/// agg.merge(HookOutput::allow(), "hook_a");
754/// assert_eq!(agg.permission, Some(HookPermission::Allow));
755///
756/// // Hook B: deny(更严格,覆盖 allow)
757/// agg.merge(HookOutput::deny("unsafe"), "hook_b");
758/// assert!(agg.permission.as_ref().unwrap().is_deny());
759/// ```
760#[derive(Debug, Clone, Default)]
761pub struct AggregatedHookOutput {
762    /// 聚合后的权限决策(最严格的获胜)。
763    pub permission: Option<HookPermission>,
764
765    /// 最终的 updated_input(最后一个设置者获胜)。
766    pub updated_input: Option<serde_json::Value>,
767
768    /// 最终的 updated_output。
769    pub updated_output: Option<ToolOutput>,
770
771    /// 所有 Hook 注入的上下文(合并)。
772    pub additional_context: Vec<String>,
773
774    /// 任一 Hook 阻止继续。
775    pub prevent_continuation: bool,
776
777    /// 第一个 stop_reason。
778    pub stop_reason: Option<String>,
779
780    /// 所有 blocking_error(合并,带 hook 名前缀)。
781    pub blocking_errors: Vec<String>,
782
783    /// 所有 system_message(合并)。
784    pub system_messages: Vec<String>,
785}
786
787impl AggregatedHookOutput {
788    /// 合并单个 Hook 的输出。
789    pub fn merge(&mut self, output: HookOutput, hook_name: &str) {
790        // 权限聚合:Deny > Ask > Allow
791        if let Some(ref new_perm) = output.permission {
792            match &self.permission {
793                Some(existing) if existing.strictness() >= new_perm.strictness() => {
794                    // 已有更严格或同级的决策,保持不变
795                }
796                _ => {
797                    self.permission = output.permission.clone();
798                }
799            }
800        }
801
802        if output.updated_input.is_some() {
803            self.updated_input = output.updated_input;
804        }
805        if output.updated_output.is_some() {
806            self.updated_output = output.updated_output;
807        }
808        self.additional_context.extend(output.additional_context);
809
810        if output.prevent_continuation {
811            self.prevent_continuation = true;
812            if self.stop_reason.is_none() {
813                self.stop_reason = output.stop_reason;
814            }
815        }
816
817        if let Some(err) = output.blocking_error {
818            self.blocking_errors.push(format!("[{hook_name}] {err}"));
819        }
820        if let Some(msg) = output.system_message {
821            self.system_messages.push(msg);
822        }
823    }
824
825    /// 是否有任何 Hook 做出了实质性决策。
826    pub fn has_decision(&self) -> bool {
827        self.permission.is_some()
828            || self.updated_input.is_some()
829            || self.updated_output.is_some()
830            || !self.additional_context.is_empty()
831            || self.prevent_continuation
832            || !self.blocking_errors.is_empty()
833    }
834
835    /// 是否有阻塞性错误。
836    pub fn has_blocking_errors(&self) -> bool {
837        !self.blocking_errors.is_empty()
838    }
839
840    /// 是否被拒绝。
841    pub fn is_denied(&self) -> bool {
842        matches!(&self.permission, Some(p) if p.is_deny())
843    }
844}
845
846// ===========================================================================
847// 工具函数
848// ===========================================================================
849
850/// 模式匹配 — 判断 `value` 是否匹配 `pattern`。
851///
852/// 支持三种语法:
853/// - 精确匹配:`"bash"` 匹配 `"bash"`
854/// - 管道分隔多选:`"bash|write_file"` 匹配 `"bash"` 或 `"write_file"`
855/// - 通配符 `*`:`"read_*"` 匹配 `"read_file"`, `"read_dir"` 等
856///
857/// # Examples
858///
859/// ```
860/// use katu_core::hook::matches_pattern;
861///
862/// assert!(matches_pattern("bash", "bash"));
863/// assert!(matches_pattern("bash", "bash|write_file"));
864/// assert!(matches_pattern("read_file", "read_*"));
865/// assert!(!matches_pattern("write_file", "read_*"));
866/// ```
867pub fn matches_pattern(value: &str, pattern: &str) -> bool {
868    if pattern.contains('|') {
869        return pattern.split('|').any(|p| matches_single_pattern(value, p.trim()));
870    }
871    matches_single_pattern(value, pattern)
872}
873
874/// 单个模式匹配(支持 `*` 通配符)。
875fn matches_single_pattern(value: &str, pattern: &str) -> bool {
876    if !pattern.contains('*') {
877        return value == pattern;
878    }
879
880    let parts: Vec<&str> = pattern.split('*').collect();
881
882    // 单个 `*` → 匹配所有
883    if parts.len() == 2 && parts[0].is_empty() && parts[1].is_empty() {
884        return true;
885    }
886
887    // 前缀匹配:`read_*`
888    if parts.len() == 2 && parts[1].is_empty() {
889        return value.starts_with(parts[0]);
890    }
891
892    // 后缀匹配:`*_file`
893    if parts.len() == 2 && parts[0].is_empty() {
894        return value.ends_with(parts[1]);
895    }
896
897    // 前后匹配:`pre_*_use`
898    if parts.len() == 2 {
899        return value.starts_with(parts[0])
900            && value.ends_with(parts[1])
901            && value.len() >= parts[0].len() + parts[1].len();
902    }
903
904    // 多段通配符:逐段贪心匹配
905    let mut remaining = value;
906    for (i, part) in parts.iter().enumerate() {
907        if part.is_empty() {
908            continue;
909        }
910        if i == 0 {
911            if !remaining.starts_with(part) {
912                return false;
913            }
914            remaining = &remaining[part.len()..];
915        } else if let Some(pos) = remaining.find(part) {
916            remaining = &remaining[pos + part.len()..];
917        } else {
918            return false;
919        }
920    }
921    true
922}
923
924// ===========================================================================
925// Tests
926// ===========================================================================
927
928#[cfg(test)]
929mod tests {
930    use super::*;
931    use serde_json::json;
932
933    // -- HookEvent --
934
935    #[test]
936    fn test_hook_event_is_tool_event() {
937        assert!(HookEvent::PreToolUse.is_tool_event());
938        assert!(HookEvent::PostToolUse.is_tool_event());
939        assert!(HookEvent::PostToolFailure.is_tool_event());
940        assert!(!HookEvent::SessionStart.is_tool_event());
941        assert!(!HookEvent::Stop.is_tool_event());
942    }
943
944    #[test]
945    fn test_hook_event_display() {
946        assert_eq!(HookEvent::PreToolUse.to_string(), "pre_tool_use");
947        assert_eq!(HookEvent::PostToolUse.to_string(), "post_tool_use");
948        assert_eq!(HookEvent::SessionStart.to_string(), "session_start");
949    }
950
951    #[test]
952    fn test_hook_event_serde_roundtrip() {
953        for event in ALL_HOOK_EVENTS {
954            let json_str = serde_json::to_string(event).unwrap();
955            let restored: HookEvent = serde_json::from_str(&json_str).unwrap();
956            assert_eq!(*event, restored);
957        }
958    }
959
960    #[test]
961    fn test_all_hook_events_count() {
962        assert_eq!(ALL_HOOK_EVENTS.len(), 10);
963    }
964
965    // -- HookInput --
966
967    #[test]
968    fn test_hook_input_event() {
969        let input = HookInput::PreToolUse {
970            tool_name: "bash".into(),
971            tool_input: json!({}),
972            call_id: ToolCallId::new("c1"),
973        };
974        assert_eq!(input.event(), HookEvent::PreToolUse);
975    }
976
977    #[test]
978    fn test_hook_input_tool_name() {
979        let tool_input = HookInput::PreToolUse {
980            tool_name: "bash".into(),
981            tool_input: json!({}),
982            call_id: ToolCallId::new("c1"),
983        };
984        assert_eq!(tool_input.tool_name(), Some("bash"));
985
986        let non_tool = HookInput::SessionStart {
987            session_id: SessionId::new(),
988        };
989        assert_eq!(non_tool.tool_name(), None);
990    }
991
992    #[test]
993    fn test_hook_input_call_id() {
994        let input = HookInput::PostToolFailure {
995            tool_name: "bash".into(),
996            tool_input: json!({}),
997            error: "exit code 1".into(),
998            call_id: ToolCallId::new("c2"),
999        };
1000        assert_eq!(input.call_id().unwrap().as_str(), "c2");
1001
1002        let non_tool = HookInput::Stop {
1003            finish_reason: "completed".into(),
1004        };
1005        assert!(non_tool.call_id().is_none());
1006    }
1007
1008    #[test]
1009    fn test_hook_input_serde_roundtrip() {
1010        let input = HookInput::PreToolUse {
1011            tool_name: "read_file".into(),
1012            tool_input: json!({"path": "/tmp/test.txt"}),
1013            call_id: ToolCallId::new("call_42"),
1014        };
1015        let json_str = serde_json::to_string(&input).unwrap();
1016        assert!(json_str.contains("pre_tool_use"));
1017        let restored: HookInput = serde_json::from_str(&json_str).unwrap();
1018        assert_eq!(restored.event(), HookEvent::PreToolUse);
1019        assert_eq!(restored.tool_name(), Some("read_file"));
1020    }
1021
1022    // -- HookPermission --
1023
1024    #[test]
1025    fn test_hook_permission_variants() {
1026        assert!(HookPermission::Allow.is_allow());
1027        assert!(HookPermission::deny().is_deny());
1028        assert!(HookPermission::ask().is_ask());
1029    }
1030
1031    #[test]
1032    fn test_hook_permission_with_reason() {
1033        let deny = HookPermission::deny_with_reason("unsafe");
1034        match deny {
1035            HookPermission::Deny { reason } => assert_eq!(reason, Some("unsafe".into())),
1036            _ => panic!("expected Deny"),
1037        }
1038    }
1039
1040    #[test]
1041    fn test_hook_permission_strictness() {
1042        assert!(HookPermission::Allow.strictness() < HookPermission::ask().strictness());
1043        assert!(HookPermission::ask().strictness() < HookPermission::deny().strictness());
1044    }
1045
1046    #[test]
1047    fn test_hook_permission_serde_roundtrip() {
1048        for perm in [
1049            HookPermission::Allow,
1050            HookPermission::deny(),
1051            HookPermission::deny_with_reason("test"),
1052            HookPermission::ask(),
1053            HookPermission::ask_with_message("confirm?"),
1054        ] {
1055            let json_str = serde_json::to_string(&perm).unwrap();
1056            let restored: HookPermission = serde_json::from_str(&json_str).unwrap();
1057            assert_eq!(perm, restored);
1058        }
1059    }
1060
1061    // -- HookOutput --
1062
1063    #[test]
1064    fn test_hook_output_passthrough() {
1065        let out = HookOutput::passthrough();
1066        assert!(!out.has_decision());
1067        assert!(out.permission.is_none());
1068        assert!(out.additional_context.is_empty());
1069    }
1070
1071    #[test]
1072    fn test_hook_output_allow() {
1073        let out = HookOutput::allow();
1074        assert!(out.has_decision());
1075        assert!(out.permission.as_ref().unwrap().is_allow());
1076    }
1077
1078    #[test]
1079    fn test_hook_output_deny() {
1080        let out = HookOutput::deny("bad command");
1081        assert!(out.has_decision());
1082        assert!(out.permission.as_ref().unwrap().is_deny());
1083    }
1084
1085    #[test]
1086    fn test_hook_output_ask() {
1087        let out = HookOutput::ask("are you sure?");
1088        assert!(out.has_decision());
1089        assert!(out.permission.as_ref().unwrap().is_ask());
1090    }
1091
1092    #[test]
1093    fn test_hook_output_builder() {
1094        let out = HookOutput::allow()
1095            .with_updated_input(json!({"command": "ls"}))
1096            .with_context("working directory: /tmp")
1097            .with_system_message("Input sanitized");
1098
1099        assert!(out.permission.as_ref().unwrap().is_allow());
1100        assert_eq!(out.updated_input.as_ref().unwrap()["command"], "ls");
1101        assert_eq!(out.additional_context.len(), 1);
1102        assert_eq!(out.system_message, Some("Input sanitized".into()));
1103    }
1104
1105    #[test]
1106    fn test_hook_output_with_stop() {
1107        let out = HookOutput::passthrough().with_stop("loop detected");
1108        assert!(out.prevent_continuation);
1109        assert_eq!(out.stop_reason, Some("loop detected".into()));
1110        assert!(out.has_decision());
1111    }
1112
1113    #[test]
1114    fn test_hook_output_with_blocking_error() {
1115        let out = HookOutput::passthrough().with_blocking_error("lint failed");
1116        assert!(out.has_decision());
1117        assert_eq!(out.blocking_error, Some("lint failed".into()));
1118    }
1119
1120    #[test]
1121    fn test_hook_output_serde_roundtrip() {
1122        let out = HookOutput::deny("test")
1123            .with_context("ctx1")
1124            .with_system_message("msg1");
1125        let json_str = serde_json::to_string(&out).unwrap();
1126        let restored: HookOutput = serde_json::from_str(&json_str).unwrap();
1127        assert_eq!(restored.additional_context, vec!["ctx1"]);
1128        assert_eq!(restored.system_message, Some("msg1".into()));
1129    }
1130
1131    // -- AggregatedHookOutput --
1132
1133    #[test]
1134    fn test_aggregated_merge_permission_deny_wins() {
1135        let mut agg = AggregatedHookOutput::default();
1136
1137        agg.merge(HookOutput::allow(), "hook_a");
1138        assert!(agg.permission.as_ref().unwrap().is_allow());
1139
1140        agg.merge(HookOutput::deny("nope"), "hook_b");
1141        assert!(agg.permission.as_ref().unwrap().is_deny());
1142
1143        // allow after deny — deny 仍然获胜
1144        agg.merge(HookOutput::allow(), "hook_c");
1145        assert!(agg.permission.as_ref().unwrap().is_deny());
1146    }
1147
1148    #[test]
1149    fn test_aggregated_merge_permission_ask_beats_allow() {
1150        let mut agg = AggregatedHookOutput::default();
1151
1152        agg.merge(HookOutput::allow(), "hook_a");
1153        agg.merge(HookOutput::ask("confirm?"), "hook_b");
1154        assert!(agg.permission.as_ref().unwrap().is_ask());
1155
1156        // allow after ask — ask 仍然获胜
1157        agg.merge(HookOutput::allow(), "hook_c");
1158        assert!(agg.permission.as_ref().unwrap().is_ask());
1159    }
1160
1161    #[test]
1162    fn test_aggregated_merge_context() {
1163        let mut agg = AggregatedHookOutput::default();
1164
1165        agg.merge(
1166            HookOutput::passthrough().with_context("ctx1"),
1167            "hook_a",
1168        );
1169        agg.merge(
1170            HookOutput::passthrough().with_context("ctx2"),
1171            "hook_b",
1172        );
1173        assert_eq!(agg.additional_context, vec!["ctx1", "ctx2"]);
1174    }
1175
1176    #[test]
1177    fn test_aggregated_merge_blocking_errors() {
1178        let mut agg = AggregatedHookOutput::default();
1179
1180        agg.merge(
1181            HookOutput::passthrough().with_blocking_error("err1"),
1182            "linter",
1183        );
1184        agg.merge(
1185            HookOutput::passthrough().with_blocking_error("err2"),
1186            "validator",
1187        );
1188        assert_eq!(agg.blocking_errors.len(), 2);
1189        assert!(agg.blocking_errors[0].contains("[linter]"));
1190        assert!(agg.blocking_errors[1].contains("[validator]"));
1191        assert!(agg.has_blocking_errors());
1192    }
1193
1194    #[test]
1195    fn test_aggregated_merge_stop() {
1196        let mut agg = AggregatedHookOutput::default();
1197
1198        agg.merge(HookOutput::passthrough(), "hook_a");
1199        assert!(!agg.prevent_continuation);
1200
1201        agg.merge(
1202            HookOutput::passthrough().with_stop("first reason"),
1203            "hook_b",
1204        );
1205        assert!(agg.prevent_continuation);
1206        assert_eq!(agg.stop_reason, Some("first reason".into()));
1207
1208        // 第二个 stop — prevent_continuation 已为 true,stop_reason 保持第一个
1209        agg.merge(
1210            HookOutput::passthrough().with_stop("second reason"),
1211            "hook_c",
1212        );
1213        assert_eq!(agg.stop_reason, Some("first reason".into()));
1214    }
1215
1216    #[test]
1217    fn test_aggregated_merge_updated_input_last_wins() {
1218        let mut agg = AggregatedHookOutput::default();
1219
1220        agg.merge(
1221            HookOutput::allow().with_updated_input(json!({"a": 1})),
1222            "hook_a",
1223        );
1224        agg.merge(
1225            HookOutput::allow().with_updated_input(json!({"b": 2})),
1226            "hook_b",
1227        );
1228        assert_eq!(agg.updated_input, Some(json!({"b": 2})));
1229    }
1230
1231    #[test]
1232    fn test_aggregated_has_decision() {
1233        let agg = AggregatedHookOutput::default();
1234        assert!(!agg.has_decision());
1235
1236        let mut agg2 = AggregatedHookOutput::default();
1237        agg2.merge(HookOutput::allow(), "h");
1238        assert!(agg2.has_decision());
1239    }
1240
1241    #[test]
1242    fn test_aggregated_is_denied() {
1243        let mut agg = AggregatedHookOutput::default();
1244        assert!(!agg.is_denied());
1245
1246        agg.merge(HookOutput::deny("no"), "h");
1247        assert!(agg.is_denied());
1248    }
1249
1250    // -- matches_pattern --
1251
1252    #[test]
1253    fn test_matches_pattern_exact() {
1254        assert!(matches_pattern("bash", "bash"));
1255        assert!(!matches_pattern("bash", "write_file"));
1256    }
1257
1258    #[test]
1259    fn test_matches_pattern_pipe_separated() {
1260        assert!(matches_pattern("bash", "bash|write_file"));
1261        assert!(matches_pattern("write_file", "bash|write_file"));
1262        assert!(!matches_pattern("read_file", "bash|write_file"));
1263    }
1264
1265    #[test]
1266    fn test_matches_pattern_wildcard_star() {
1267        assert!(matches_pattern("read_file", "read_*"));
1268        assert!(matches_pattern("read_dir", "read_*"));
1269        assert!(!matches_pattern("write_file", "read_*"));
1270    }
1271
1272    #[test]
1273    fn test_matches_pattern_wildcard_suffix() {
1274        assert!(matches_pattern("read_file", "*_file"));
1275        assert!(matches_pattern("write_file", "*_file"));
1276        assert!(!matches_pattern("read_dir", "*_file"));
1277    }
1278
1279    #[test]
1280    fn test_matches_pattern_wildcard_middle() {
1281        assert!(matches_pattern("pre_tool_use", "pre_*_use"));
1282        assert!(matches_pattern("pre_compact_use", "pre_*_use"));
1283        assert!(!matches_pattern("pre_tool_fail", "pre_*_use"));
1284    }
1285
1286    #[test]
1287    fn test_matches_pattern_star_matches_all() {
1288        assert!(matches_pattern("anything", "*"));
1289        assert!(matches_pattern("", "*"));
1290    }
1291
1292    #[test]
1293    fn test_matches_pattern_pipe_with_wildcard() {
1294        assert!(matches_pattern("read_file", "bash|read_*"));
1295        assert!(matches_pattern("bash", "bash|read_*"));
1296        assert!(!matches_pattern("write_file", "bash|read_*"));
1297    }
1298
1299    // -- HookRegistry --
1300
1301    struct PassthroughHook {
1302        hook_name: String,
1303        hook_events: Vec<HookEvent>,
1304        hook_matcher: Option<String>,
1305    }
1306
1307    impl PassthroughHook {
1308        fn new(name: &str) -> Self {
1309            Self {
1310                hook_name: name.into(),
1311                hook_events: vec![],
1312                hook_matcher: None,
1313            }
1314        }
1315
1316        fn with_events(mut self, events: Vec<HookEvent>) -> Self {
1317            self.hook_events = events;
1318            self
1319        }
1320
1321        fn with_matcher(mut self, matcher: &str) -> Self {
1322            self.hook_matcher = Some(matcher.into());
1323            self
1324        }
1325    }
1326
1327    #[async_trait]
1328    impl Hook for PassthroughHook {
1329        fn name(&self) -> &str {
1330            &self.hook_name
1331        }
1332
1333        fn events(&self) -> &[HookEvent] {
1334            &self.hook_events
1335        }
1336
1337        fn matcher(&self) -> Option<&str> {
1338            self.hook_matcher.as_deref()
1339        }
1340
1341        async fn on_event(&self, _input: &HookInput) -> HookOutput {
1342            HookOutput::passthrough()
1343        }
1344    }
1345
1346    #[test]
1347    fn test_registry_new_empty() {
1348        let reg = HookRegistry::new();
1349        assert!(reg.is_empty());
1350        assert_eq!(reg.len(), 0);
1351    }
1352
1353    #[test]
1354    fn test_registry_register_and_len() {
1355        let mut reg = HookRegistry::new();
1356        reg.register(
1357            Arc::new(PassthroughHook::new("a")),
1358            HookSource::Programmatic,
1359            0,
1360        );
1361        reg.register(
1362            Arc::new(PassthroughHook::new("b")),
1363            HookSource::Programmatic,
1364            0,
1365        );
1366        assert_eq!(reg.len(), 2);
1367    }
1368
1369    #[test]
1370    fn test_registry_remove() {
1371        let mut reg = HookRegistry::new();
1372        reg.register(
1373            Arc::new(PassthroughHook::new("a")),
1374            HookSource::Programmatic,
1375            0,
1376        );
1377        reg.register(
1378            Arc::new(PassthroughHook::new("b")),
1379            HookSource::Programmatic,
1380            0,
1381        );
1382        reg.remove("a");
1383        assert_eq!(reg.len(), 1);
1384        assert_eq!(reg.hooks[0].hook.name(), "b");
1385    }
1386
1387    #[test]
1388    fn test_registry_matching_by_event() {
1389        let mut reg = HookRegistry::new();
1390        reg.register(
1391            Arc::new(PassthroughHook::new("pre_only").with_events(vec![HookEvent::PreToolUse])),
1392            HookSource::Programmatic,
1393            0,
1394        );
1395        reg.register(
1396            Arc::new(PassthroughHook::new("post_only").with_events(vec![HookEvent::PostToolUse])),
1397            HookSource::Programmatic,
1398            0,
1399        );
1400        reg.register(
1401            Arc::new(PassthroughHook::new("all_events")),
1402            HookSource::Programmatic,
1403            0,
1404        );
1405
1406        let input = HookInput::PreToolUse {
1407            tool_name: "bash".into(),
1408            tool_input: json!({}),
1409            call_id: ToolCallId::new("c1"),
1410        };
1411        let matched = reg.matching(&input);
1412        assert_eq!(matched.len(), 2);
1413
1414        let names: Vec<&str> = matched.iter().map(|h| h.hook.name()).collect();
1415        assert!(names.contains(&"pre_only"));
1416        assert!(names.contains(&"all_events"));
1417        assert!(!names.contains(&"post_only"));
1418    }
1419
1420    #[test]
1421    fn test_registry_matching_by_matcher() {
1422        let mut reg = HookRegistry::new();
1423        reg.register(
1424            Arc::new(
1425                PassthroughHook::new("bash_only")
1426                    .with_events(vec![HookEvent::PreToolUse])
1427                    .with_matcher("bash"),
1428            ),
1429            HookSource::Programmatic,
1430            0,
1431        );
1432        reg.register(
1433            Arc::new(
1434                PassthroughHook::new("write_family")
1435                    .with_events(vec![HookEvent::PreToolUse])
1436                    .with_matcher("write_*"),
1437            ),
1438            HookSource::Programmatic,
1439            0,
1440        );
1441
1442        // bash → 匹配 bash_only
1443        let input_bash = HookInput::PreToolUse {
1444            tool_name: "bash".into(),
1445            tool_input: json!({}),
1446            call_id: ToolCallId::new("c1"),
1447        };
1448        let matched = reg.matching(&input_bash);
1449        assert_eq!(matched.len(), 1);
1450        assert_eq!(matched[0].hook.name(), "bash_only");
1451
1452        // write_file → 匹配 write_family
1453        let input_write = HookInput::PreToolUse {
1454            tool_name: "write_file".into(),
1455            tool_input: json!({}),
1456            call_id: ToolCallId::new("c2"),
1457        };
1458        let matched = reg.matching(&input_write);
1459        assert_eq!(matched.len(), 1);
1460        assert_eq!(matched[0].hook.name(), "write_family");
1461
1462        // read_file → 无匹配
1463        let input_read = HookInput::PreToolUse {
1464            tool_name: "read_file".into(),
1465            tool_input: json!({}),
1466            call_id: ToolCallId::new("c3"),
1467        };
1468        let matched = reg.matching(&input_read);
1469        assert!(matched.is_empty());
1470    }
1471
1472    #[test]
1473    fn test_registry_matching_non_tool_event_with_matcher() {
1474        let mut reg = HookRegistry::new();
1475        // 有 matcher 但事件不是 tool 事件 → 不匹配
1476        reg.register(
1477            Arc::new(
1478                PassthroughHook::new("h")
1479                    .with_events(vec![HookEvent::SessionStart])
1480                    .with_matcher("bash"),
1481            ),
1482            HookSource::Programmatic,
1483            0,
1484        );
1485
1486        let input = HookInput::SessionStart {
1487            session_id: SessionId::new(),
1488        };
1489        let matched = reg.matching(&input);
1490        assert!(matched.is_empty());
1491    }
1492
1493    #[test]
1494    fn test_registry_priority_order() {
1495        let mut reg = HookRegistry::new();
1496        reg.register(
1497            Arc::new(PassthroughHook::new("low")),
1498            HookSource::Programmatic,
1499            10,
1500        );
1501        reg.register(
1502            Arc::new(PassthroughHook::new("high")),
1503            HookSource::Programmatic,
1504            -10,
1505        );
1506        reg.register(
1507            Arc::new(PassthroughHook::new("mid")),
1508            HookSource::Programmatic,
1509            0,
1510        );
1511
1512        let input = HookInput::SessionStart {
1513            session_id: SessionId::new(),
1514        };
1515        let matched = reg.matching(&input);
1516        assert_eq!(matched[0].hook.name(), "high");
1517        assert_eq!(matched[1].hook.name(), "mid");
1518        assert_eq!(matched[2].hook.name(), "low");
1519    }
1520
1521    #[test]
1522    fn test_registry_has_hooks_for() {
1523        let mut reg = HookRegistry::new();
1524        reg.register(
1525            Arc::new(PassthroughHook::new("pre_only").with_events(vec![HookEvent::PreToolUse])),
1526            HookSource::Programmatic,
1527            0,
1528        );
1529
1530        assert!(reg.has_hooks_for(HookEvent::PreToolUse));
1531        assert!(!reg.has_hooks_for(HookEvent::PostToolUse));
1532    }
1533
1534    #[test]
1535    fn test_registry_has_hooks_for_all_events() {
1536        let mut reg = HookRegistry::new();
1537        // events() = [] 表示关注所有事件
1538        reg.register(
1539            Arc::new(PassthroughHook::new("global")),
1540            HookSource::Programmatic,
1541            0,
1542        );
1543
1544        for event in ALL_HOOK_EVENTS {
1545            assert!(reg.has_hooks_for(*event));
1546        }
1547    }
1548
1549    // -- HookSource --
1550
1551    #[test]
1552    fn test_hook_source_serde_roundtrip() {
1553        for source in [
1554            HookSource::Settings,
1555            HookSource::Project,
1556            HookSource::Plugin {
1557                name: "linter".into(),
1558            },
1559            HookSource::Programmatic,
1560            HookSource::Session,
1561        ] {
1562            let json_str = serde_json::to_string(&source).unwrap();
1563            let restored: HookSource = serde_json::from_str(&json_str).unwrap();
1564            assert_eq!(source, restored);
1565        }
1566    }
1567
1568    // -- Hook trait async --
1569
1570    #[tokio::test]
1571    async fn test_hook_trait_async_execution() {
1572        struct DenyBashHook;
1573
1574        #[async_trait]
1575        impl Hook for DenyBashHook {
1576            fn name(&self) -> &str {
1577                "deny_bash"
1578            }
1579
1580            fn events(&self) -> &[HookEvent] {
1581                &[HookEvent::PreToolUse]
1582            }
1583
1584            fn matcher(&self) -> Option<&str> {
1585                Some("bash")
1586            }
1587
1588            async fn on_event(&self, input: &HookInput) -> HookOutput {
1589                if let HookInput::PreToolUse { tool_input, .. } = input {
1590                    let cmd = tool_input["command"].as_str().unwrap_or("");
1591                    if cmd.contains("rm -rf") {
1592                        return HookOutput::deny("dangerous command");
1593                    }
1594                }
1595                HookOutput::passthrough()
1596            }
1597        }
1598
1599        let hook: Arc<dyn Hook> = Arc::new(DenyBashHook);
1600
1601        // 安全命令 → passthrough
1602        let safe_input = HookInput::PreToolUse {
1603            tool_name: "bash".into(),
1604            tool_input: json!({"command": "ls -la"}),
1605            call_id: ToolCallId::new("c1"),
1606        };
1607        let output = hook.on_event(&safe_input).await;
1608        assert!(!output.has_decision());
1609
1610        // 危险命令 → deny
1611        let dangerous_input = HookInput::PreToolUse {
1612            tool_name: "bash".into(),
1613            tool_input: json!({"command": "rm -rf /"}),
1614            call_id: ToolCallId::new("c2"),
1615        };
1616        let output = hook.on_event(&dangerous_input).await;
1617        assert!(output.permission.as_ref().unwrap().is_deny());
1618    }
1619
1620    #[tokio::test]
1621    async fn test_hook_trait_dyn_dispatch() {
1622        let hook: Arc<dyn Hook> = Arc::new(PassthroughHook::new("test"));
1623        assert_eq!(hook.name(), "test");
1624
1625        let input = HookInput::SessionStart {
1626            session_id: SessionId::new(),
1627        };
1628        let output = hook.on_event(&input).await;
1629        assert!(!output.has_decision());
1630    }
1631}