Skip to main content

j_agent/infra/hook/
types.rs

1use crate::constants::{HOOK_DEFAULT_LLM_TIMEOUT_SECS, HOOK_DEFAULT_TIMEOUT_SECS};
2use crate::storage::ChatMessage;
3use serde::{Deserialize, Serialize};
4use std::env;
5
6// ========== 常量 ==========
7
8/// Hook 链总超时(秒):整条链执行超过此时间后,中止未执行的 hook
9pub(crate) const MAX_CHAIN_DURATION_SECS: u64 = 30;
10
11// ========== HookEvent ==========
12
13/// Hook 事件类型
14///
15/// 各事件的触发时机及可读/可写字段:
16///
17/// stop / skip 语义(统一规则):
18/// - `stop`:中止当前步骤及其所属子管线(不发送/不请求/不结束/不保存/中止 compact)
19/// - `skip`:跳过当前步骤,同级步骤继续(仅 PreToolExecution:跳过该工具,其他工具继续)
20///
21/// | 事件                          | 触发时机           | 可读字段                              | 可写字段(HookResult 中返回即生效)        |
22/// |-------------------------------|--------------------|-----------------------------------------|----------------------------------------------|
23/// | `PreSendMessage`              | 用户消息入队前     | `user_input`, `messages`               | `user_input`(修改发送内容), `action=stop`, `retry_feedback` |
24/// | `PostSendMessage`             | 用户消息入队后     | `user_input`, `messages`               | 仅通知,返回值被忽略                         |
25/// | `PreLlmRequest`               | LLM API 请求前     | `messages`, `system_prompt`, `model`   | `messages`, `system_prompt`, `inject_messages`, `additional_context`, `action=stop`, `retry_feedback` |
26/// | `PostLlmResponse`             | LLM 回复完成后     | `assistant_output`, `messages`, `model` | `assistant_output`(修改最终回复), `action=stop`, `retry_feedback`, `system_message` |
27/// | `PreToolExecution`            | 工具执行前         | `tool_name`, `tool_arguments`          | `tool_arguments`(修改参数), `action=skip`  |
28/// | `PostToolExecution`           | 工具执行后         | `tool_name`, `tool_result`             | `tool_result`(修改结果)                    |
29/// | `PostToolExecutionFailure`    | 工具执行失败后     | `tool_name`, `tool_error`              | `tool_error`(修改错误信息), `additional_context` |
30/// | `Stop`                        | LLM 即将结束回复   | `user_input`(回复文本), `messages`, `system_prompt`, `model` | `retry_feedback`(带反馈重试), `additional_context`, `action=stop` |
31/// | `PreMicroCompact`             | micro_compact 前   | `messages`, `model`                   | `action=stop`                               |
32/// | `PostMicroCompact`            | micro_compact 后   | `messages`                             | `messages`(修改压缩结果)                    |
33/// | `PreAutoCompact`              | auto_compact 前    | `messages`, `system_prompt`, `model`   | `additional_context`, `action=stop`         |
34/// | `PostAutoCompact`             | auto_compact 后    | `messages`                             | `messages`(修改压缩结果)                    |
35/// | `SessionStart`                | 会话启动时         | `messages`                             | 仅通知,返回值被忽略                         |
36/// | `SessionEnd`                  | 会话退出时         | `messages`                             | 仅通知,返回值被忽略                         |
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum HookEvent {
40    PreSendMessage,
41    PostSendMessage,
42    PreLlmRequest,
43    PostLlmResponse,
44    PreToolExecution,
45    PostToolExecution,
46    PostToolExecutionFailure,
47    Stop,
48    PreMicroCompact,
49    PostMicroCompact,
50    PreAutoCompact,
51    PostAutoCompact,
52    SessionStart,
53    SessionEnd,
54}
55
56impl std::str::FromStr for HookEvent {
57    type Err = ();
58
59    fn from_str(s: &str) -> Result<Self, Self::Err> {
60        match s {
61            "pre_send_message" => Ok(HookEvent::PreSendMessage),
62            "post_send_message" => Ok(HookEvent::PostSendMessage),
63            "pre_llm_request" => Ok(HookEvent::PreLlmRequest),
64            "post_llm_response" => Ok(HookEvent::PostLlmResponse),
65            "pre_tool_execution" => Ok(HookEvent::PreToolExecution),
66            "post_tool_execution" => Ok(HookEvent::PostToolExecution),
67            "post_tool_execution_failure" => Ok(HookEvent::PostToolExecutionFailure),
68            "stop" => Ok(HookEvent::Stop),
69            "pre_micro_compact" => Ok(HookEvent::PreMicroCompact),
70            "post_micro_compact" => Ok(HookEvent::PostMicroCompact),
71            "pre_auto_compact" => Ok(HookEvent::PreAutoCompact),
72            "post_auto_compact" => Ok(HookEvent::PostAutoCompact),
73            "session_start" => Ok(HookEvent::SessionStart),
74            "session_end" => Ok(HookEvent::SessionEnd),
75            _ => Err(()),
76        }
77    }
78}
79
80impl HookEvent {
81    /// 返回 Hook 事件的字符串标识(如 "pre_send_message")
82    pub fn as_str(&self) -> &'static str {
83        match self {
84            HookEvent::PreSendMessage => "pre_send_message",
85            HookEvent::PostSendMessage => "post_send_message",
86            HookEvent::PreLlmRequest => "pre_llm_request",
87            HookEvent::PostLlmResponse => "post_llm_response",
88            HookEvent::PreToolExecution => "pre_tool_execution",
89            HookEvent::PostToolExecution => "post_tool_execution",
90            HookEvent::PostToolExecutionFailure => "post_tool_execution_failure",
91            HookEvent::Stop => "stop",
92            HookEvent::PreMicroCompact => "pre_micro_compact",
93            HookEvent::PostMicroCompact => "post_micro_compact",
94            HookEvent::PreAutoCompact => "pre_auto_compact",
95            HookEvent::PostAutoCompact => "post_auto_compact",
96            HookEvent::SessionStart => "session_start",
97            HookEvent::SessionEnd => "session_end",
98        }
99    }
100
101    /// 返回所有 HookEvent 枚举值的静态切片,用于遍历/校验
102    pub fn all() -> &'static [HookEvent] {
103        &[
104            HookEvent::PreSendMessage,
105            HookEvent::PostSendMessage,
106            HookEvent::PreLlmRequest,
107            HookEvent::PostLlmResponse,
108            HookEvent::PreToolExecution,
109            HookEvent::PostToolExecution,
110            HookEvent::PostToolExecutionFailure,
111            HookEvent::Stop,
112            HookEvent::PreMicroCompact,
113            HookEvent::PostMicroCompact,
114            HookEvent::PreAutoCompact,
115            HookEvent::PostAutoCompact,
116            HookEvent::SessionStart,
117            HookEvent::SessionEnd,
118        ]
119    }
120
121    /// 从字符串解析,不匹配时返回 None
122    pub fn parse(s: &str) -> Option<HookEvent> {
123        s.parse().ok()
124    }
125}
126
127// ========== OnError ==========
128
129/// Shell hook 失败时的处理策略
130#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq)]
131#[serde(rename_all = "snake_case")]
132pub enum OnError {
133    /// 记录错误日志后继续执行后续 hook(默认)
134    #[default]
135    Skip,
136    /// 中止整条 hook 链
137    Stop,
138}
139
140// ========== HookFilter ==========
141
142/// Hook 条件过滤:仅当条件匹配时才执行该 hook
143///
144/// 所有字段为可选,未设置的字段不参与过滤(即视为匹配)。
145/// 多个字段同时设置时取 AND 关系。
146#[derive(Debug, Clone, Serialize, Deserialize, Default)]
147pub struct HookFilter {
148    /// 工具名过滤(精确匹配,仅对工具相关事件生效)
149    #[serde(default, skip_serializing_if = "Option::is_none")]
150    pub tool_name: Option<String>,
151    /// 工具名模式匹配(管道分隔,如 "Write|Edit|Bash",仅对工具相关事件生效)
152    /// 优先级低于 tool_name:当 tool_name 设置时忽略此字段
153    #[serde(default, skip_serializing_if = "Option::is_none")]
154    pub tool_matcher: Option<String>,
155    /// 模型名前缀过滤(如 "gpt-4" 匹配 "gpt-4o"、"gpt-4-turbo")
156    #[serde(default, skip_serializing_if = "Option::is_none")]
157    pub model_prefix: Option<String>,
158}
159
160impl HookFilter {
161    /// 是否为空过滤器(无任何条件,始终匹配)
162    pub fn is_empty(&self) -> bool {
163        self.tool_name.is_none() && self.tool_matcher.is_none() && self.model_prefix.is_none()
164    }
165
166    /// 根据 HookContext 判断是否匹配
167    pub fn matches(&self, context: &HookContext) -> bool {
168        // 精确匹配 tool_name(优先级最高)
169        if let Some(ref expected_tool) = self.tool_name {
170            match &context.tool_name {
171                Some(actual) if actual == expected_tool => {}
172                Some(_) => return false,
173                None => return false,
174            }
175        } else if let Some(ref pattern) = self.tool_matcher {
176            // 管道分隔模式匹配(如 "Write|Edit|Bash")
177            let actual = match &context.tool_name {
178                Some(a) => a,
179                None => return false,
180            };
181            let matched = pattern.split('|').any(|p| p.trim() == actual);
182            if !matched {
183                return false;
184            }
185        }
186        if let Some(ref prefix) = self.model_prefix {
187            match &context.model {
188                Some(actual) if actual.starts_with(prefix.as_str()) => {}
189                Some(_) => return false,
190                None => return false,
191            }
192        }
193        true
194    }
195}
196
197// ========== HookType ==========
198
199/// Hook 类型(YAML `type` 字段)
200#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
201#[serde(rename_all = "snake_case")]
202pub enum HookType {
203    /// Shell 命令 hook(默认,通过 `sh -c` 子进程执行)
204    #[default]
205    Bash,
206    /// LLM hook(通过 prompt 模板调用 LLM,返回 HookResult JSON)
207    Llm,
208}
209
210impl std::fmt::Display for HookType {
211    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212        match self {
213            HookType::Bash => write!(f, "bash"),
214            HookType::Llm => write!(f, "llm"),
215        }
216    }
217}
218
219// ========== HookContext ==========
220
221/// Hook 执行上下文(通过 stdin JSON 传给脚本)
222///
223/// 各字段按事件类型有选择性地填充,未填充的字段序列化时会被跳过(`skip_serializing_if`)。
224/// 脚本可通过 stdin 读取此 JSON 来获取当前事件的上下文信息。
225#[derive(Debug, Serialize)]
226pub struct HookContext {
227    /// 当前触发的事件类型
228    pub event: HookEvent,
229
230    /// 当前对话的完整消息列表
231    /// - 可读事件:PreSendMessage, PostSendMessage, PreLlmRequest, PostLlmResponse, SessionStart, SessionEnd
232    #[serde(skip_serializing_if = "Option::is_none")]
233    pub messages: Option<Vec<ChatMessage>>,
234
235    /// 当前系统提示词
236    /// - 可读事件:PreLlmRequest
237    #[serde(skip_serializing_if = "Option::is_none")]
238    pub system_prompt: Option<String>,
239
240    /// 当前使用的模型名称
241    /// - 可读事件:PreLlmRequest
242    #[serde(skip_serializing_if = "Option::is_none")]
243    pub model: Option<String>,
244
245    /// 本轮用户输入的消息文本
246    /// - 可读事件:PreSendMessage(发送前,可通过 HookResult 修改)、PostSendMessage(发送后,只读)
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub user_input: Option<String>,
249
250    /// 本轮 AI 回复的完整文本
251    /// - 可读事件:PostLlmResponse(可通过 HookResult 修改最终展示内容)
252    #[serde(skip_serializing_if = "Option::is_none")]
253    pub assistant_output: Option<String>,
254
255    /// 当前工具调用的工具名
256    /// - 可读事件:PreToolExecution, PostToolExecution
257    #[serde(skip_serializing_if = "Option::is_none")]
258    pub tool_name: Option<String>,
259
260    /// 当前工具调用的参数 JSON 字符串
261    /// - 可读事件:PreToolExecution(可通过 HookResult 修改)
262    #[serde(skip_serializing_if = "Option::is_none")]
263    pub tool_arguments: Option<String>,
264
265    /// 工具执行的结果内容
266    /// - 可读事件:PostToolExecution(可通过 HookResult 修改)
267    #[serde(skip_serializing_if = "Option::is_none")]
268    pub tool_result: Option<String>,
269
270    /// 工具执行失败原因
271    /// - 可读事件:PostToolExecutionFailure(可通过 HookResult 修改)
272    #[serde(skip_serializing_if = "Option::is_none")]
273    pub tool_error: Option<String>,
274
275    /// 当前会话 ID
276    /// - 可读事件:所有事件
277    #[serde(skip_serializing_if = "Option::is_none")]
278    pub session_id: Option<String>,
279
280    /// 当前工作目录
281    pub cwd: String,
282}
283
284impl Default for HookContext {
285    fn default() -> Self {
286        Self {
287            event: HookEvent::SessionStart,
288            messages: None,
289            system_prompt: None,
290            model: None,
291            user_input: None,
292            assistant_output: None,
293            tool_name: None,
294            tool_arguments: None,
295            tool_result: None,
296            tool_error: None,
297            session_id: None,
298            cwd: env::current_dir()
299                .map(|p| p.display().to_string())
300                .unwrap_or_else(|_| ".".to_string()),
301        }
302    }
303}
304
305// ========== HookAction / HookResult / HookOutcome ==========
306
307/// Hook 脚本返回结果中的控制流动作
308///
309/// - `action: "stop"` 中止当前步骤及其所属子管线
310/// - `action: "skip"` 跳过当前步骤,同级步骤继续
311/// - 旧字段 `abort=true` 等价于 `action="stop"`
312#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
313#[serde(rename_all = "snake_case")]
314pub enum HookAction {
315    /// 中止当前步骤及其所属子管线
316    Stop,
317    /// 跳过当前步骤,同级步骤继续
318    Skip,
319}
320
321/// Hook 执行结果:允许替换消息列表、系统提示词、用户输入、工具参数等
322#[derive(Debug, Deserialize, Default)]
323pub struct HookResult {
324    /// 替换消息列表(PreLlmRequest)
325    #[serde(default)]
326    pub messages: Option<Vec<ChatMessage>>,
327    /// 替换系统提示词(PreLlmRequest)
328    #[serde(default)]
329    pub system_prompt: Option<String>,
330    /// 替换用户输入文本(PreSendMessage)
331    #[serde(default)]
332    pub user_input: Option<String>,
333    /// 替换 AI 回复文本(PostLlmResponse)
334    #[serde(default)]
335    pub assistant_output: Option<String>,
336    /// 替换工具调用参数(PreToolExecution)
337    #[serde(default)]
338    pub tool_arguments: Option<String>,
339    /// 替换工具执行结果(PostToolExecution)
340    #[serde(default)]
341    pub tool_result: Option<String>,
342    /// 替换工具执行失败原因(PostToolExecutionFailure)
343    #[serde(default)]
344    pub tool_error: Option<String>,
345    /// 追加消息到消息列表末尾(PreLlmRequest)
346    #[serde(default)]
347    pub inject_messages: Option<Vec<ChatMessage>>,
348    /// 审查反馈(Pre*/Stop/PostLlmResponse):中止时附带反馈文本,触发 LLM 带反馈重试
349    #[serde(default)]
350    pub retry_feedback: Option<String>,
351    /// 注入到模型上下文的额外信息(PreLlmRequest/Stop/PreAutoCompact):纯文本追加到 system_prompt 末尾
352    #[serde(default)]
353    pub additional_context: Option<String>,
354    /// 展示给用户的系统消息(所有事件:UI 上以 toast/提示形式显示)
355    #[serde(default)]
356    pub system_message: Option<String>,
357    /// 控制流动作:`stop` = 中止当前步骤及其所属子管线,`skip` = 跳过当前步骤(同级继续)
358    #[serde(default)]
359    pub action: Option<HookAction>,
360}
361
362impl HookResult {
363    /// 是否请求 stop(中止当前步骤及其所属子管线)
364    pub fn is_stop(&self) -> bool {
365        self.action == Some(HookAction::Stop)
366    }
367
368    /// 是否请求 skip(跳过当前步骤,同级继续)
369    pub fn is_skip(&self) -> bool {
370        self.action == Some(HookAction::Skip)
371    }
372
373    /// 是否请求 stop 或 skip(任何控制流中断)
374    pub fn is_halt(&self) -> bool {
375        self.is_stop() || self.is_skip()
376    }
377}
378
379/// Hook 执行的三态结果
380///
381/// - `Success`:执行成功,可能包含修改
382/// - `Retry`:执行失败但还有重试机会
383/// - `Err`:执行失败(重试耗尽或不可重试)
384#[derive(Debug)]
385#[allow(dead_code, clippy::large_enum_variant)]
386pub(crate) enum HookOutcome {
387    Success(HookResult),
388    Retry {
389        error: String,
390        #[allow(dead_code)]
391        attempts_left: u32,
392    },
393    Err(String),
394}
395
396// ========== 辅助常量函数 ==========
397
398pub(crate) fn default_timeout() -> u64 {
399    HOOK_DEFAULT_TIMEOUT_SECS
400}
401
402pub(crate) fn default_llm_timeout() -> u64 {
403    HOOK_DEFAULT_LLM_TIMEOUT_SECS
404}