Skip to main content

j_agent/infra/hook/
definition.rs

1use crate::infra::hook::types::*;
2use crate::permission::JcliConfig;
3use crate::util::log::{write_error_log, write_info_log};
4use serde::{Deserialize, Serialize};
5use std::fmt;
6use std::fs;
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9
10// ========== HookKind 枚举 ==========
11
12/// Hook 种类:Shell 命令(子进程)、LLM(prompt 模板调 LLM)、内置 Rust 闭包(进程内)
13#[derive(Clone)]
14pub enum HookKind {
15    /// Shell 命令,通过 `sh -c` 子进程执行(现有行为)
16    Shell(ShellHook),
17    /// LLM hook,通过 prompt 模板调用 LLM API,返回 HookResult JSON
18    Llm(LlmHook),
19    /// 内置 Rust 闭包,进程内零开销执行
20    Builtin(BuiltinHook),
21}
22
23/// Shell hook:一条命令 + 超时 + 失败策略 + 条件过滤
24#[derive(Debug, Clone)]
25pub struct ShellHook {
26    /// Hook 目录名(目录布局下有值,session hook 为 None)
27    pub name: Option<String>,
28    pub command: String,
29    pub timeout: u64,
30    pub retry: u32,
31    pub on_error: OnError,
32    pub filter: HookFilter,
33    /// Hook 目录路径(目录布局下有值,session hook 为 None)
34    pub dir_path: Option<PathBuf>,
35}
36
37/// LLM hook:prompt 模板 + 模型覆盖 + 超时 + 重试 + 失败策略 + 条件过滤
38#[derive(Debug, Clone)]
39pub struct LlmHook {
40    /// Hook 目录名(目录布局下有值,session hook 为 None)
41    pub name: Option<String>,
42    /// Prompt 模板,支持 {{variable}} 模板变量
43    pub prompt: String,
44    /// 模型名覆盖(空则使用当前活跃 provider 的模型)
45    pub model: Option<String>,
46    /// 超时秒数
47    pub timeout: u64,
48    /// 重试次数(仅 Err 路径生效)
49    pub retry: u32,
50    /// 失败策略
51    pub on_error: OnError,
52    /// 条件过滤
53    pub filter: HookFilter,
54    /// Hook 目录路径(目录布局下有值,session hook 为 None)
55    #[allow(dead_code)]
56    pub dir_path: Option<PathBuf>,
57}
58
59/// 内置 hook 的处理函数类型
60pub type BuiltinHookFn = Arc<dyn Fn(&HookContext) -> Option<HookResult> + Send + Sync>;
61
62/// 内置 hook:一个命名的 Rust 闭包
63pub struct BuiltinHook {
64    /// 唯一名称,用于列出/调试(如 "tasks_status"、"todo_nag")
65    pub name: String,
66    /// 实际执行的 Rust 闭包
67    pub handler: BuiltinHookFn,
68}
69
70impl Clone for BuiltinHook {
71    fn clone(&self) -> Self {
72        BuiltinHook {
73            name: self.name.clone(),
74            handler: Arc::clone(&self.handler),
75        }
76    }
77}
78
79impl fmt::Debug for HookKind {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        match self {
82            HookKind::Shell(shell) => f
83                .debug_struct("HookKind::Shell")
84                .field("name", &shell.name)
85                .field("command", &shell.command)
86                .field("timeout", &shell.timeout)
87                .field("on_error", &shell.on_error)
88                .finish(),
89            HookKind::Llm(llm) => f
90                .debug_struct("HookKind::Llm")
91                .field("name", &llm.name)
92                .field("prompt", &llm.prompt.len())
93                .field("model", &llm.model)
94                .field("timeout", &llm.timeout)
95                .field("retry", &llm.retry)
96                .finish(),
97            HookKind::Builtin(builtin) => f
98                .debug_struct("HookKind::Builtin")
99                .field("name", &builtin.name)
100                .finish(),
101        }
102    }
103}
104
105// ========== HookDef(YAML 兼容格式)==========
106
107/// Hook 定义(YAML 兼容):支持 bash 和 llm 两种类型
108///
109/// YAML 示例(bash):
110/// ```yaml
111/// - command: "echo '{\"user_input\": \"hooked\"}'"
112///   timeout: 10
113///   on_error: skip
114/// ```
115///
116/// YAML 示例(llm):
117/// ```yaml
118/// - type: llm
119///   prompt: |
120///     检查以下用户输入是否包含敏感信息:
121///     {{user_input}}
122///     如果包含,返回 action=stop + retry_feedback。
123///   timeout: 30
124///   retry: 1
125///   on_error: skip
126/// ```
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct HookDef {
129    /// Hook 类型:bash(默认)或 llm
130    #[serde(default)]
131    pub r#type: HookType,
132    /// Shell 命令(type=bash 时必填,通过 `sh -c` 执行)
133    #[serde(default, skip_serializing_if = "Option::is_none")]
134    pub command: Option<String>,
135    /// LLM prompt 模板(type=llm 时必填,支持 {{variable}} 模板变量)
136    #[serde(default, skip_serializing_if = "Option::is_none")]
137    pub prompt: Option<String>,
138    /// LLM 模型名覆盖(type=llm 时可选,空则使用当前活跃 provider 的模型)
139    #[serde(default, skip_serializing_if = "Option::is_none")]
140    pub model: Option<String>,
141    /// 超时秒数(bash 默认 10,llm 默认 30)
142    #[serde(default = "default_timeout")]
143    pub timeout: u64,
144    /// 重试次数(仅 Err 路径生效,默认 0 即不重试)
145    #[serde(default)]
146    pub retry: u32,
147    /// 脚本/LLM 失败时的处理策略(默认 skip)
148    #[serde(default)]
149    pub on_error: OnError,
150    /// 条件过滤:仅当条件匹配时执行(默认无过滤)
151    #[serde(default, skip_serializing_if = "HookFilter::is_empty")]
152    pub filter: HookFilter,
153}
154
155impl HookDef {
156    /// 转换为 HookKind(根据 type 字段分派)
157    pub fn into_hook_kind(self) -> Result<HookKind, String> {
158        match self.r#type {
159            HookType::Bash => {
160                let command = self.command.unwrap_or_default();
161                if command.is_empty() {
162                    return Err("bash hook 缺少 command 字段".to_string());
163                }
164                Ok(HookKind::Shell(ShellHook {
165                    name: None,
166                    command,
167                    timeout: self.timeout,
168                    retry: self.retry,
169                    on_error: self.on_error,
170                    filter: self.filter,
171                    dir_path: None,
172                }))
173            }
174            HookType::Llm => {
175                let prompt = self.prompt.unwrap_or_default();
176                if prompt.is_empty() {
177                    return Err("llm hook 缺少 prompt 字段".to_string());
178                }
179                Ok(HookKind::Llm(LlmHook {
180                    name: None,
181                    prompt,
182                    model: self.model,
183                    timeout: if self.timeout == default_timeout() {
184                        default_llm_timeout()
185                    } else {
186                        self.timeout
187                    },
188                    retry: if self.retry == 0 { 1 } else { self.retry },
189                    on_error: self.on_error,
190                    filter: self.filter,
191                    dir_path: None,
192                }))
193            }
194        }
195    }
196}
197
198impl From<HookDef> for HookKind {
199    fn from(def: HookDef) -> Self {
200        def.into_hook_kind().unwrap_or_else(|e| {
201            write_error_log("HookDef::into_hook_kind", &e);
202            // 回退到空 Shell hook(不会执行有效操作,但不会 panic)
203            HookKind::Shell(ShellHook {
204                name: None,
205                command: String::new(),
206                timeout: 0,
207                retry: 0,
208                on_error: OnError::Skip,
209                filter: HookFilter::default(),
210                dir_path: None,
211            })
212        })
213    }
214}
215
216// ========== HookDirDef(目录布局下的 HOOK.yaml / HOOK.yml 格式)==========
217
218/// HOOK.yaml / HOOK.yml 定义(目录布局下的格式)
219///
220/// 与 `HookDef` 的区别:`events` 为列表(一个 hook 可绑定多个事件),无 `command`/`prompt` 以外的不必要字段。
221/// 目录布局下 `command` 中的相对路径以 hook 目录为 cwd 解析。
222#[derive(Debug, Clone, Serialize, Deserialize)]
223pub struct HookDirDef {
224    /// 绑定的事件列表
225    pub events: Vec<HookEvent>,
226    /// Hook 类型
227    #[serde(default)]
228    pub r#type: HookType,
229    /// Shell 命令(type=bash 时必填,通过 `sh -c` 执行,cwd 为 hook 目录)
230    #[serde(default, skip_serializing_if = "Option::is_none")]
231    pub command: Option<String>,
232    /// LLM prompt 模板(type=llm 时必填,支持 {{variable}} 模板变量)
233    #[serde(default, skip_serializing_if = "Option::is_none")]
234    pub prompt: Option<String>,
235    /// LLM 模型名覆盖
236    #[serde(default, skip_serializing_if = "Option::is_none")]
237    pub model: Option<String>,
238    /// 超时秒数(bash 默认 10,llm 默认 30)
239    #[serde(default = "default_timeout")]
240    pub timeout: u64,
241    /// 重试次数(仅 Err 路径生效)
242    #[serde(default)]
243    pub retry: u32,
244    /// 失败策略
245    #[serde(default)]
246    pub on_error: OnError,
247    /// 条件过滤
248    #[serde(default, skip_serializing_if = "HookFilter::is_empty")]
249    pub filter: HookFilter,
250}
251
252impl HookDirDef {
253    /// 转换为 `Vec<(HookEvent, HookKind)>`(每个 event 一个条目)
254    pub fn into_hook_kinds(
255        self,
256        name: &str,
257        dir_path: &Path,
258    ) -> Result<Vec<(HookEvent, HookKind)>, String> {
259        if self.events.is_empty() {
260            return Err(format!("hook '{}' 的 events 为空", name));
261        }
262        let kind = match self.r#type {
263            HookType::Bash => {
264                let command = self.command.unwrap_or_default();
265                if command.is_empty() {
266                    return Err(format!("bash hook '{}' 缺少 command 字段", name));
267                }
268                HookKind::Shell(ShellHook {
269                    name: Some(name.to_string()),
270                    command,
271                    timeout: self.timeout,
272                    retry: self.retry,
273                    on_error: self.on_error,
274                    filter: self.filter,
275                    dir_path: Some(dir_path.to_path_buf()),
276                })
277            }
278            HookType::Llm => {
279                let prompt = self.prompt.unwrap_or_default();
280                if prompt.is_empty() {
281                    return Err(format!("llm hook '{}' 缺少 prompt 字段", name));
282                }
283                HookKind::Llm(LlmHook {
284                    name: Some(name.to_string()),
285                    prompt,
286                    model: self.model,
287                    timeout: if self.timeout == default_timeout() {
288                        default_llm_timeout()
289                    } else {
290                        self.timeout
291                    },
292                    retry: if self.retry == 0 { 1 } else { self.retry },
293                    on_error: self.on_error,
294                    filter: self.filter,
295                    dir_path: Some(dir_path.to_path_buf()),
296                })
297            }
298        };
299        Ok(self.events.into_iter().map(|e| (e, kind.clone())).collect())
300    }
301}
302
303// ========== 目录加载函数 ==========
304
305/// 返回用户级 hooks 目录: ~/.jdata/agent/hooks/
306pub fn hooks_dir() -> PathBuf {
307    let dir = crate::constants::data_root().join("agent").join("hooks");
308    let _ = fs::create_dir_all(&dir);
309    dir
310}
311
312/// 返回项目级 hooks 目录: .jcli/hooks/(如果存在)
313pub fn project_hooks_dir() -> Option<PathBuf> {
314    let config_dir = JcliConfig::find_config_dir()?;
315    let dir = config_dir.join("hooks");
316    if dir.is_dir() { Some(dir) } else { None }
317}
318
319/// 从指定目录加载 hooks(遍历子目录,解析 HOOK.yaml 或 HOOK.yml)
320pub(crate) fn load_hooks_from_dir(
321    dir: &Path,
322    source_name: &str,
323) -> Vec<(String, HookDirDef, PathBuf)> {
324    let mut hooks = Vec::new();
325    let entries = match fs::read_dir(dir) {
326        Ok(e) => e,
327        Err(_) => return hooks,
328    };
329    for entry in entries.flatten() {
330        let path = entry.path();
331        if !path.is_dir() {
332            continue;
333        }
334        let hook_name = path
335            .file_name()
336            .unwrap_or_default()
337            .to_string_lossy()
338            .to_string();
339
340        // 跳过 example 目录(模板示例,不是实际可执行的 hook)
341        if hook_name == "example" {
342            continue;
343        }
344
345        // 优先 HOOK.yaml,其次 HOOK.yml;两者共存时取 HOOK.yaml
346        let hook_yaml = if path.join("HOOK.yaml").exists() {
347            path.join("HOOK.yaml")
348        } else if path.join("HOOK.yml").exists() {
349            path.join("HOOK.yml")
350        } else {
351            continue;
352        };
353        let yaml_file_name = hook_yaml
354            .file_name()
355            .unwrap_or_default()
356            .to_string_lossy()
357            .to_string();
358        match fs::read_to_string(&hook_yaml) {
359            Ok(content) => match serde_yaml::from_str::<HookDirDef>(&content) {
360                Ok(def) => {
361                    if def.events.is_empty() {
362                        write_error_log(
363                            "load_hooks_from_dir",
364                            &format!("hook '{}' 的 events 为空,跳过", hook_name),
365                        );
366                        continue;
367                    }
368                    hooks.push((hook_name, def, path));
369                }
370                Err(e) => write_error_log(
371                    "load_hooks_from_dir",
372                    &format!("解析 {}/{} 失败: {}", hook_name, yaml_file_name, e),
373                ),
374            },
375            Err(e) => write_error_log(
376                "load_hooks_from_dir",
377                &format!("读取 {}/{} 失败: {}", hook_name, yaml_file_name, e),
378            ),
379        }
380    }
381    write_info_log(
382        "load_hooks_from_dir",
383        &format!("从 {} 加载了 {} 个 hook", source_name, hooks.len()),
384    );
385    hooks
386}