1use enact_config::{resolve_config_file, HookConfig, HookEvent, HooksConfig};
8use enact_plugins::load_plugins;
9use std::path::Path;
10use std::process::Stdio;
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::process::Command;
13use tracing::warn;
14
/// Outcome of running a command hook script through `sh -c`.
#[derive(Clone, Debug)]
pub struct CommandHookResult {
    /// `true` when the script's process exited with a successful status.
    pub success: bool,
    /// Captured standard output, decoded lossily as UTF-8 and trimmed.
    pub stdout: String,
}
21
22#[derive(Debug, Clone, Default)]
24pub struct HookRegistry {
25 hooks: Vec<HookConfig>,
26}
27
28impl HookRegistry {
29 pub fn new() -> Self {
31 Self { hooks: Vec::new() }
32 }
33
34 pub fn load_global() -> Self {
37 let hooks = load_global_hooks();
38 Self { hooks }
39 }
40
41 pub fn with_plugin_hooks(mut self, project_dir: Option<&Path>) -> Self {
44 self.hooks.extend(load_plugin_hooks(project_dir));
45 self
46 }
47
48 pub fn with_agent_hooks(mut self, agent_hooks: Option<&[HookConfig]>) -> Self {
51 if let Some(hooks) = agent_hooks {
52 self.hooks.extend(hooks.iter().cloned());
53 }
54 self
55 }
56
57 pub fn load_global_and_agent(
59 project_dir: Option<&Path>,
60 agent_hooks: Option<&[HookConfig]>,
61 ) -> Self {
62 Self::load_global()
63 .with_plugin_hooks(project_dir)
64 .with_agent_hooks(agent_hooks)
65 }
66
67 pub fn hooks_for_event(&self, event: HookEvent, tool_name: Option<&str>) -> Vec<&HookConfig> {
69 self.hooks
70 .iter()
71 .filter(|h| h.event == event)
72 .filter(|h| match (&h.matcher, tool_name) {
73 (None, _) => true,
74 (Some(_), None) => true,
75 (Some(pattern), Some(name)) => regex_match_pattern(pattern, name).unwrap_or(false),
76 })
77 .collect()
78 }
79
80 pub async fn run_command_handler(
82 &self,
83 script: &str,
84 context_json: &serde_json::Value,
85 ) -> std::io::Result<CommandHookResult> {
86 run_command_shell(script, context_json).await
87 }
88}
89
90pub fn load_global_hooks() -> Vec<HookConfig> {
92 match resolve_config_file("hooks.yaml", "ENACT_HOOKS_CONFIG_PATH") {
93 Some(path) => {
94 let content = match std::fs::read_to_string(&path) {
95 Ok(c) => c,
96 Err(_) => return Vec::new(),
97 };
98 let config: HooksConfig = match serde_yaml::from_str(&content) {
99 Ok(c) => c,
100 Err(_) => return Vec::new(),
101 };
102 config.hooks
103 }
104 None => Vec::new(),
105 }
106}
107
108pub fn load_plugin_hooks(project_dir: Option<&Path>) -> Vec<HookConfig> {
110 let mut out = Vec::new();
111 for plugin in load_plugins(project_dir) {
112 let hooks_path = plugin.hooks_dir().join("hooks.yaml");
113 if !hooks_path.exists() {
114 continue;
115 }
116 let content = match std::fs::read_to_string(&hooks_path) {
117 Ok(c) => c,
118 Err(e) => {
119 warn!(
120 "Failed to read plugin hooks from {}: {}",
121 hooks_path.display(),
122 e
123 );
124 continue;
125 }
126 };
127 match serde_yaml::from_str::<HooksConfig>(&content) {
128 Ok(cfg) => out.extend(cfg.hooks),
129 Err(e) => warn!(
130 "Failed to parse plugin hooks from {}: {}",
131 hooks_path.display(),
132 e
133 ),
134 }
135 }
136 out
137}
138
139async fn run_command_shell(
141 script: &str,
142 context_json: &serde_json::Value,
143) -> std::io::Result<CommandHookResult> {
144 let json_str = serde_json::to_string(context_json).unwrap_or_default();
145 let mut child = Command::new("sh")
146 .arg("-c")
147 .arg(script)
148 .stdin(Stdio::piped())
149 .stdout(Stdio::piped())
150 .stderr(Stdio::null())
151 .spawn()?;
152 if let Some(mut stdin) = child.stdin.take() {
153 stdin.write_all(json_str.as_bytes()).await?;
154 stdin.flush().await?;
155 }
156 let mut stdout = String::new();
157 if let Some(mut out) = child.stdout.take() {
158 let mut buf = Vec::new();
159 out.read_to_end(&mut buf).await?;
160 stdout = String::from_utf8_lossy(&buf).to_string();
161 }
162 let status = child.wait().await?;
163 Ok(CommandHookResult {
164 success: status.success(),
165 stdout: stdout.trim().to_string(),
166 })
167}
168
169fn regex_match_pattern(pattern: &str, name: &str) -> Option<bool> {
171 regex::Regex::new(pattern).ok().map(|re| re.is_match(name))
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Environment variable the global-hooks test sets to point at its fixture file.
    const HOOKS_PATH_VAR: &str = "ENACT_HOOKS_CONFIG_PATH";

    /// Removes an environment variable when dropped.
    ///
    /// The variable is cleared even if the code under test panics, so it cannot
    /// leak into other tests running in the same process.
    struct EnvVarGuard(&'static str);

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.0);
        }
    }

    /// Builds a `hooks.yaml` document with one `SessionStart` command hook running `script`.
    ///
    /// Uses YAML flow style (JSON-compatible). Its meaning does not depend on
    /// indentation, unlike block-style YAML. `script` must not contain `"` or `\`.
    fn session_start_hooks_yaml(script: &str) -> String {
        format!(
            "{{\"hooks\": [{{\"event\": \"SessionStart\", \"handler\": {{\"type\": \"command\", \"script\": \"{}\"}}}}]}}\n",
            script
        )
    }

    /// Writes a `demo-plugin` under `<root>/.enact/plugins/`.
    ///
    /// The plugin gets a `.enact-plugin/plugin.json` manifest and a `hooks/hooks.yaml`
    /// containing `hooks_yaml`.
    fn write_demo_plugin(root: &Path, hooks_yaml: &str) {
        let plugin_root = root.join(".enact").join("plugins").join("demo-plugin");
        fs::create_dir_all(plugin_root.join(".enact-plugin")).unwrap();
        fs::create_dir_all(plugin_root.join("hooks")).unwrap();
        fs::write(
            plugin_root.join(".enact-plugin").join("plugin.json"),
            r#"{"name":"demo-plugin","version":"0.1.0"}"#,
        )
        .unwrap();
        fs::write(plugin_root.join("hooks").join("hooks.yaml"), hooks_yaml).unwrap();
    }

    #[test]
    fn hook_registry_empty() {
        let reg = HookRegistry::new();
        assert!(reg
            .hooks_for_event(HookEvent::SessionStart, None)
            .is_empty());
    }

    #[test]
    fn load_global_returns_vec() {
        // Smoke test: loading must not panic whether or not a global file exists.
        let _ = load_global_hooks();
    }

    #[test]
    fn load_global_plugin_agent_order() {
        let temp = tempfile::tempdir().unwrap();
        let global_hooks = temp.path().join("hooks.yaml");
        fs::write(&global_hooks, session_start_hooks_yaml("echo global")).unwrap();
        write_demo_plugin(temp.path(), &session_start_hooks_yaml("echo plugin"));

        let agent = vec![HookConfig {
            event: HookEvent::SessionStart,
            matcher: None,
            handler: enact_config::HookHandler::Command {
                script: "echo agent".to_string(),
            },
            async_mode: false,
        }];

        let registry = {
            std::env::set_var(HOOKS_PATH_VAR, &global_hooks);
            // Cleared at the end of this scope, even if loading panics.
            let _unset = EnvVarGuard(HOOKS_PATH_VAR);
            HookRegistry::load_global_and_agent(Some(temp.path()), Some(&agent))
        };

        let hooks = registry.hooks_for_event(HookEvent::SessionStart, None);
        assert_eq!(hooks.len(), 3);
        let scripts: Vec<String> = hooks
            .into_iter()
            .map(|h| match &h.handler {
                enact_config::HookHandler::Command { script } => script.clone(),
                _ => String::new(),
            })
            .collect();
        assert_eq!(scripts, vec!["echo global", "echo plugin", "echo agent"]);
    }

    #[test]
    fn invalid_plugin_hook_yaml_is_ignored() {
        let temp = tempfile::tempdir().unwrap();
        write_demo_plugin(temp.path(), "not: [valid");

        let hooks = load_plugin_hooks(Some(temp.path()));
        assert!(hooks.is_empty());
    }
}