Skip to main content

aster/hooks/
loader.rs

1//! Hook 加载器
2//!
3//! 从配置文件加载 hooks
4
5use super::registry::{register_hook, register_legacy_hook, SharedHookRegistry};
6use super::types::{HookConfig, HookEvent, LegacyHookConfig};
7use std::collections::HashMap;
8use std::fs;
9use std::path::Path;
10use tracing::{error, warn};
11
12/// Hooks 配置文件结构(新格式)
13#[allow(dead_code)]
14#[derive(Debug, serde::Deserialize)]
15struct HooksConfigNew {
16    hooks: HashMap<String, serde_json::Value>,
17}
18
19/// Hooks 配置文件结构(旧格式)
20#[allow(dead_code)]
21#[derive(Debug, serde::Deserialize)]
22struct HooksConfigLegacy {
23    hooks: Vec<LegacyHookConfig>,
24}
25
/// Returns `true` if `event` is one of the hook event names recognised in config files.
///
/// The comparison is exact and case-sensitive (`"pretooluse"` is rejected). This list
/// mirrors the names accepted by `parse_event`; keep the two in sync when adding an event.
fn is_valid_hook_event(event: &str) -> bool {
    const KNOWN_EVENTS: &[&str] = &[
        "PreToolUse",
        "PostToolUse",
        "PostToolUseFailure",
        "Notification",
        "UserPromptSubmit",
        "SessionStart",
        "SessionEnd",
        "Stop",
        "SubagentStart",
        "SubagentStop",
        "PreCompact",
        "PermissionRequest",
        "BeforeSetup",
        "AfterSetup",
        "CommandsLoaded",
        "ToolsLoaded",
        "McpConfigsLoaded",
        "PluginsInitialized",
        "AfterHooks",
    ];

    KNOWN_EVENTS.contains(&event)
}
51
52/// 解析事件名称
53fn parse_event(event: &str) -> Option<HookEvent> {
54    match event {
55        "PreToolUse" => Some(HookEvent::PreToolUse),
56        "PostToolUse" => Some(HookEvent::PostToolUse),
57        "PostToolUseFailure" => Some(HookEvent::PostToolUseFailure),
58        "Notification" => Some(HookEvent::Notification),
59        "UserPromptSubmit" => Some(HookEvent::UserPromptSubmit),
60        "SessionStart" => Some(HookEvent::SessionStart),
61        "SessionEnd" => Some(HookEvent::SessionEnd),
62        "Stop" => Some(HookEvent::Stop),
63        "SubagentStart" => Some(HookEvent::SubagentStart),
64        "SubagentStop" => Some(HookEvent::SubagentStop),
65        "PreCompact" => Some(HookEvent::PreCompact),
66        "PermissionRequest" => Some(HookEvent::PermissionRequest),
67        "BeforeSetup" => Some(HookEvent::BeforeSetup),
68        "AfterSetup" => Some(HookEvent::AfterSetup),
69        "CommandsLoaded" => Some(HookEvent::CommandsLoaded),
70        "ToolsLoaded" => Some(HookEvent::ToolsLoaded),
71        "McpConfigsLoaded" => Some(HookEvent::McpConfigsLoaded),
72        "PluginsInitialized" => Some(HookEvent::PluginsInitialized),
73        "AfterHooks" => Some(HookEvent::AfterHooks),
74        _ => None,
75    }
76}
77
78/// 从配置文件加载 hooks
79pub fn load_hooks_from_file(config_path: &Path) -> Result<(), String> {
80    if !config_path.exists() {
81        return Ok(());
82    }
83
84    let content = fs::read_to_string(config_path)
85        .map_err(|e| format!("Failed to read {}: {}", config_path.display(), e))?;
86
87    let json: serde_json::Value = serde_json::from_str(&content)
88        .map_err(|e| format!("Failed to parse {}: {}", config_path.display(), e))?;
89
90    // 检查 hooks 字段
91    let hooks = match json.get("hooks") {
92        Some(h) => h,
93        None => return Ok(()),
94    };
95
96    // 新格式:{ "hooks": { "PreToolUse": [...] } }
97    if let Some(obj) = hooks.as_object() {
98        for (event_name, hook_value) in obj {
99            if !is_valid_hook_event(event_name) {
100                warn!("Unknown hook event: {}", event_name);
101                continue;
102            }
103
104            let event = match parse_event(event_name) {
105                Some(e) => e,
106                None => continue,
107            };
108
109            let hook_array = if hook_value.is_array() {
110                hook_value.as_array().unwrap().clone()
111            } else {
112                vec![hook_value.clone()]
113            };
114
115            for hook_json in hook_array {
116                match serde_json::from_value::<HookConfig>(hook_json.clone()) {
117                    Ok(config) => {
118                        register_hook(event, config);
119                    }
120                    Err(e) => {
121                        warn!("Invalid hook config for event {}: {}", event_name, e);
122                    }
123                }
124            }
125        }
126    }
127    // 旧格式:{ "hooks": [...] }
128    else if let Some(arr) = hooks.as_array() {
129        for hook_json in arr {
130            match serde_json::from_value::<LegacyHookConfig>(hook_json.clone()) {
131                Ok(config) => {
132                    register_legacy_hook(config);
133                }
134                Err(e) => {
135                    warn!("Invalid legacy hook config: {}", e);
136                }
137            }
138        }
139    }
140
141    Ok(())
142}
143
144/// 从项目目录加载 hooks
145pub fn load_project_hooks(project_dir: &Path) -> Result<(), String> {
146    // 检查 .claude/settings.json
147    let settings_path = project_dir.join(".claude").join("settings.json");
148    if let Err(e) = load_hooks_from_file(&settings_path) {
149        error!("Failed to load hooks from settings: {}", e);
150    }
151
152    // 检查 .claude/hooks/ 目录
153    let hooks_dir = project_dir.join(".claude").join("hooks");
154    if hooks_dir.exists() && hooks_dir.is_dir() {
155        if let Ok(entries) = fs::read_dir(&hooks_dir) {
156            for entry in entries.flatten() {
157                let path = entry.path();
158                if path.extension().map(|e| e == "json").unwrap_or(false) {
159                    if let Err(e) = load_hooks_from_file(&path) {
160                        error!("Failed to load hooks from {}: {}", path.display(), e);
161                    }
162                }
163            }
164        }
165    }
166
167    Ok(())
168}
169
170/// 从注册表加载 hooks
171pub fn load_hooks_to_registry(
172    config_path: &Path,
173    registry: &SharedHookRegistry,
174) -> Result<(), String> {
175    if !config_path.exists() {
176        return Ok(());
177    }
178
179    let content = fs::read_to_string(config_path)
180        .map_err(|e| format!("Failed to read {}: {}", config_path.display(), e))?;
181
182    let json: serde_json::Value = serde_json::from_str(&content)
183        .map_err(|e| format!("Failed to parse {}: {}", config_path.display(), e))?;
184
185    let hooks = match json.get("hooks") {
186        Some(h) => h,
187        None => return Ok(()),
188    };
189
190    if let Some(obj) = hooks.as_object() {
191        for (event_name, hook_value) in obj {
192            let event = match parse_event(event_name) {
193                Some(e) => e,
194                None => {
195                    warn!("Unknown hook event: {}", event_name);
196                    continue;
197                }
198            };
199
200            let hook_array = if hook_value.is_array() {
201                hook_value.as_array().unwrap().clone()
202            } else {
203                vec![hook_value.clone()]
204            };
205
206            for hook_json in hook_array {
207                match serde_json::from_value::<HookConfig>(hook_json) {
208                    Ok(config) => {
209                        registry.register(event, config);
210                    }
211                    Err(e) => {
212                        warn!("Invalid hook config: {}", e);
213                    }
214                }
215            }
216        }
217    } else if let Some(arr) = hooks.as_array() {
218        for hook_json in arr {
219            match serde_json::from_value::<LegacyHookConfig>(hook_json.clone()) {
220                Ok(config) => {
221                    registry.register_legacy(config);
222                }
223                Err(e) => {
224                    warn!("Invalid legacy hook config: {}", e);
225                }
226            }
227        }
228    }
229
230    Ok(())
231}