Skip to main content

j_agent/infra/hook/
definition.rs

1use crate::infra::hook::types::*;
2use crate::permission::JcliConfig;
3use crate::util::log::{write_error_log, write_info_log};
4use serde::{Deserialize, Serialize};
5use std::path::{Path, PathBuf};
6use std::sync::Arc;
7
8// ========== HookKind 枚举 ==========
9
10/// Hook 种类:Shell 命令(子进程)、LLM(prompt 模板调 LLM)、内置 Rust 闭包(进程内)
11#[derive(Clone)]
12pub enum HookKind {
13    /// Shell 命令,通过 `sh -c` 子进程执行(现有行为)
14    Shell(ShellHook),
15    /// LLM hook,通过 prompt 模板调用 LLM API,返回 HookResult JSON
16    Llm(LlmHook),
17    /// 内置 Rust 闭包,进程内零开销执行
18    Builtin(BuiltinHook),
19}
20
21/// Shell hook:一条命令 + 超时 + 失败策略 + 条件过滤
22#[derive(Debug, Clone)]
23pub struct ShellHook {
24    /// Hook 目录名(目录布局下有值,session hook 为 None)
25    pub name: Option<String>,
26    pub command: String,
27    pub timeout: u64,
28    pub retry: u32,
29    pub on_error: OnError,
30    pub filter: HookFilter,
31    /// Hook 目录路径(目录布局下有值,session hook 为 None)
32    pub dir_path: Option<PathBuf>,
33}
34
35/// LLM hook:prompt 模板 + 模型覆盖 + 超时 + 重试 + 失败策略 + 条件过滤
36#[derive(Debug, Clone)]
37pub struct LlmHook {
38    /// Hook 目录名(目录布局下有值,session hook 为 None)
39    pub name: Option<String>,
40    /// Prompt 模板,支持 {{variable}} 模板变量
41    pub prompt: String,
42    /// 模型名覆盖(空则使用当前活跃 provider 的模型)
43    pub model: Option<String>,
44    /// 超时秒数
45    pub timeout: u64,
46    /// 重试次数(仅 Err 路径生效)
47    pub retry: u32,
48    /// 失败策略
49    pub on_error: OnError,
50    /// 条件过滤
51    pub filter: HookFilter,
52    /// Hook 目录路径(目录布局下有值,session hook 为 None)
53    #[allow(dead_code)]
54    pub dir_path: Option<PathBuf>,
55}
56
57/// 内置 hook 的处理函数类型
58pub type BuiltinHookFn = Arc<dyn Fn(&HookContext) -> Option<HookResult> + Send + Sync>;
59
60/// 内置 hook:一个命名的 Rust 闭包
61pub struct BuiltinHook {
62    /// 唯一名称,用于列出/调试(如 "tasks_status"、"todo_nag")
63    pub name: String,
64    /// 实际执行的 Rust 闭包
65    pub handler: BuiltinHookFn,
66}
67
68impl Clone for BuiltinHook {
69    fn clone(&self) -> Self {
70        BuiltinHook {
71            name: self.name.clone(),
72            handler: Arc::clone(&self.handler),
73        }
74    }
75}
76
77impl std::fmt::Debug for HookKind {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        match self {
80            HookKind::Shell(shell) => f
81                .debug_struct("HookKind::Shell")
82                .field("name", &shell.name)
83                .field("command", &shell.command)
84                .field("timeout", &shell.timeout)
85                .field("on_error", &shell.on_error)
86                .finish(),
87            HookKind::Llm(llm) => f
88                .debug_struct("HookKind::Llm")
89                .field("name", &llm.name)
90                .field("prompt", &llm.prompt.len())
91                .field("model", &llm.model)
92                .field("timeout", &llm.timeout)
93                .field("retry", &llm.retry)
94                .finish(),
95            HookKind::Builtin(builtin) => f
96                .debug_struct("HookKind::Builtin")
97                .field("name", &builtin.name)
98                .finish(),
99        }
100    }
101}
102
103// ========== HookDef(YAML 兼容格式)==========
104
105/// Hook 定义(YAML 兼容):支持 bash 和 llm 两种类型
106///
107/// YAML 示例(bash):
108/// ```yaml
109/// - command: "echo '{\"user_input\": \"hooked\"}'"
110///   timeout: 10
111///   on_error: skip
112/// ```
113///
114/// YAML 示例(llm):
115/// ```yaml
116/// - type: llm
117///   prompt: |
118///     检查以下用户输入是否包含敏感信息:
119///     {{user_input}}
120///     如果包含,返回 action=stop + retry_feedback。
121///   timeout: 30
122///   retry: 1
123///   on_error: skip
124/// ```
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct HookDef {
127    /// Hook 类型:bash(默认)或 llm
128    #[serde(default)]
129    pub r#type: HookType,
130    /// Shell 命令(type=bash 时必填,通过 `sh -c` 执行)
131    #[serde(default, skip_serializing_if = "Option::is_none")]
132    pub command: Option<String>,
133    /// LLM prompt 模板(type=llm 时必填,支持 {{variable}} 模板变量)
134    #[serde(default, skip_serializing_if = "Option::is_none")]
135    pub prompt: Option<String>,
136    /// LLM 模型名覆盖(type=llm 时可选,空则使用当前活跃 provider 的模型)
137    #[serde(default, skip_serializing_if = "Option::is_none")]
138    pub model: Option<String>,
139    /// 超时秒数(bash 默认 10,llm 默认 30)
140    #[serde(default = "default_timeout")]
141    pub timeout: u64,
142    /// 重试次数(仅 Err 路径生效,默认 0 即不重试)
143    #[serde(default)]
144    pub retry: u32,
145    /// 脚本/LLM 失败时的处理策略(默认 skip)
146    #[serde(default)]
147    pub on_error: OnError,
148    /// 条件过滤:仅当条件匹配时执行(默认无过滤)
149    #[serde(default, skip_serializing_if = "HookFilter::is_empty")]
150    pub filter: HookFilter,
151}
152
153impl HookDef {
154    /// 转换为 HookKind(根据 type 字段分派)
155    pub fn into_hook_kind(self) -> Result<HookKind, String> {
156        match self.r#type {
157            HookType::Bash => {
158                let command = self.command.unwrap_or_default();
159                if command.is_empty() {
160                    return Err("bash hook 缺少 command 字段".to_string());
161                }
162                Ok(HookKind::Shell(ShellHook {
163                    name: None,
164                    command,
165                    timeout: self.timeout,
166                    retry: self.retry,
167                    on_error: self.on_error,
168                    filter: self.filter,
169                    dir_path: None,
170                }))
171            }
172            HookType::Llm => {
173                let prompt = self.prompt.unwrap_or_default();
174                if prompt.is_empty() {
175                    return Err("llm hook 缺少 prompt 字段".to_string());
176                }
177                Ok(HookKind::Llm(LlmHook {
178                    name: None,
179                    prompt,
180                    model: self.model,
181                    timeout: if self.timeout == default_timeout() {
182                        default_llm_timeout()
183                    } else {
184                        self.timeout
185                    },
186                    retry: if self.retry == 0 { 1 } else { self.retry },
187                    on_error: self.on_error,
188                    filter: self.filter,
189                    dir_path: None,
190                }))
191            }
192        }
193    }
194}
195
196impl From<HookDef> for HookKind {
197    fn from(def: HookDef) -> Self {
198        def.into_hook_kind().unwrap_or_else(|e| {
199            write_error_log("HookDef::into_hook_kind", &e);
200            // 回退到空 Shell hook(不会执行有效操作,但不会 panic)
201            HookKind::Shell(ShellHook {
202                name: None,
203                command: String::new(),
204                timeout: 0,
205                retry: 0,
206                on_error: OnError::Skip,
207                filter: HookFilter::default(),
208                dir_path: None,
209            })
210        })
211    }
212}
213
214// ========== HookDirDef(目录布局下的 HOOK.yaml / HOOK.yml 格式)==========
215
216/// HOOK.yaml / HOOK.yml 定义(目录布局下的格式)
217///
218/// 与 `HookDef` 的区别:`events` 为列表(一个 hook 可绑定多个事件),无 `command`/`prompt` 以外的不必要字段。
219/// 目录布局下 `command` 中的相对路径以 hook 目录为 cwd 解析。
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct HookDirDef {
222    /// 绑定的事件列表
223    pub events: Vec<HookEvent>,
224    /// Hook 类型
225    #[serde(default)]
226    pub r#type: HookType,
227    /// Shell 命令(type=bash 时必填,通过 `sh -c` 执行,cwd 为 hook 目录)
228    #[serde(default, skip_serializing_if = "Option::is_none")]
229    pub command: Option<String>,
230    /// LLM prompt 模板(type=llm 时必填,支持 {{variable}} 模板变量)
231    #[serde(default, skip_serializing_if = "Option::is_none")]
232    pub prompt: Option<String>,
233    /// LLM 模型名覆盖
234    #[serde(default, skip_serializing_if = "Option::is_none")]
235    pub model: Option<String>,
236    /// 超时秒数(bash 默认 10,llm 默认 30)
237    #[serde(default = "default_timeout")]
238    pub timeout: u64,
239    /// 重试次数(仅 Err 路径生效)
240    #[serde(default)]
241    pub retry: u32,
242    /// 失败策略
243    #[serde(default)]
244    pub on_error: OnError,
245    /// 条件过滤
246    #[serde(default, skip_serializing_if = "HookFilter::is_empty")]
247    pub filter: HookFilter,
248}
249
250impl HookDirDef {
251    /// 转换为 `Vec<(HookEvent, HookKind)>`(每个 event 一个条目)
252    pub fn into_hook_kinds(
253        self,
254        name: &str,
255        dir_path: &Path,
256    ) -> Result<Vec<(HookEvent, HookKind)>, String> {
257        if self.events.is_empty() {
258            return Err(format!("hook '{}' 的 events 为空", name));
259        }
260        let kind = match self.r#type {
261            HookType::Bash => {
262                let command = self.command.unwrap_or_default();
263                if command.is_empty() {
264                    return Err(format!("bash hook '{}' 缺少 command 字段", name));
265                }
266                HookKind::Shell(ShellHook {
267                    name: Some(name.to_string()),
268                    command,
269                    timeout: self.timeout,
270                    retry: self.retry,
271                    on_error: self.on_error,
272                    filter: self.filter,
273                    dir_path: Some(dir_path.to_path_buf()),
274                })
275            }
276            HookType::Llm => {
277                let prompt = self.prompt.unwrap_or_default();
278                if prompt.is_empty() {
279                    return Err(format!("llm hook '{}' 缺少 prompt 字段", name));
280                }
281                HookKind::Llm(LlmHook {
282                    name: Some(name.to_string()),
283                    prompt,
284                    model: self.model,
285                    timeout: if self.timeout == default_timeout() {
286                        default_llm_timeout()
287                    } else {
288                        self.timeout
289                    },
290                    retry: if self.retry == 0 { 1 } else { self.retry },
291                    on_error: self.on_error,
292                    filter: self.filter,
293                    dir_path: Some(dir_path.to_path_buf()),
294                })
295            }
296        };
297        Ok(self.events.into_iter().map(|e| (e, kind.clone())).collect())
298    }
299}
300
301// ========== 目录加载函数 ==========
302
303/// 返回用户级 hooks 目录: ~/.jdata/agent/hooks/
304pub fn hooks_dir() -> PathBuf {
305    let dir = crate::constants::data_root().join("agent").join("hooks");
306    let _ = std::fs::create_dir_all(&dir);
307    dir
308}
309
310/// 返回项目级 hooks 目录: .jcli/hooks/(如果存在)
311pub fn project_hooks_dir() -> Option<PathBuf> {
312    let config_dir = JcliConfig::find_config_dir()?;
313    let dir = config_dir.join("hooks");
314    if dir.is_dir() { Some(dir) } else { None }
315}
316
317/// 从指定目录加载 hooks(遍历子目录,解析 HOOK.yaml 或 HOOK.yml)
318pub(crate) fn load_hooks_from_dir(
319    dir: &Path,
320    source_name: &str,
321) -> Vec<(String, HookDirDef, PathBuf)> {
322    let mut hooks = Vec::new();
323    let entries = match std::fs::read_dir(dir) {
324        Ok(e) => e,
325        Err(_) => return hooks,
326    };
327    for entry in entries.flatten() {
328        let path = entry.path();
329        if !path.is_dir() {
330            continue;
331        }
332        let hook_name = path
333            .file_name()
334            .unwrap_or_default()
335            .to_string_lossy()
336            .to_string();
337
338        // 跳过 example 目录(模板示例,不是实际可执行的 hook)
339        if hook_name == "example" {
340            continue;
341        }
342
343        // 优先 HOOK.yaml,其次 HOOK.yml;两者共存时取 HOOK.yaml
344        let hook_yaml = if path.join("HOOK.yaml").exists() {
345            path.join("HOOK.yaml")
346        } else if path.join("HOOK.yml").exists() {
347            path.join("HOOK.yml")
348        } else {
349            continue;
350        };
351        let yaml_file_name = hook_yaml
352            .file_name()
353            .unwrap_or_default()
354            .to_string_lossy()
355            .to_string();
356        match std::fs::read_to_string(&hook_yaml) {
357            Ok(content) => match serde_yaml::from_str::<HookDirDef>(&content) {
358                Ok(def) => {
359                    if def.events.is_empty() {
360                        write_error_log(
361                            "load_hooks_from_dir",
362                            &format!("hook '{}' 的 events 为空,跳过", hook_name),
363                        );
364                        continue;
365                    }
366                    hooks.push((hook_name, def, path));
367                }
368                Err(e) => write_error_log(
369                    "load_hooks_from_dir",
370                    &format!("解析 {}/{} 失败: {}", hook_name, yaml_file_name, e),
371                ),
372            },
373            Err(e) => write_error_log(
374                "load_hooks_from_dir",
375                &format!("读取 {}/{} 失败: {}", hook_name, yaml_file_name, e),
376            ),
377        }
378    }
379    write_info_log(
380        "load_hooks_from_dir",
381        &format!("从 {} 加载了 {} 个 hook", source_name, hooks.len()),
382    );
383    hooks
384}