Skip to main content

enact_runner/
hooks.rs

1//! Hook configuration and registry — load hooks from ~/.enact/hooks.yaml, plugins, and agent YAML.
2//!
3//! Hooks fire at lifecycle points (SessionStart, PreToolUse, PostToolUse, etc.).
4//! Handler types: command (shell script), prompt (LLM), agent (sub-agent).
//! Global hooks are merged with plugin hooks and per-agent hooks from agent.yaml, in that order.
6
7use enact_config::{resolve_config_file, HookConfig, HookEvent, HooksConfig};
8use enact_plugins::load_plugins;
9use std::path::Path;
10use std::process::Stdio;
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::process::Command;
13use tracing::warn;
14
/// Outcome of running a command (shell script) hook.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly, e.g. with
/// `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandHookResult {
    /// `true` when the hook process exited with a success status.
    pub success: bool,
    /// Captured standard output of the hook, with surrounding whitespace trimmed.
    pub stdout: String,
}
21
22/// Registry of hooks (global + per-agent merged).
23#[derive(Debug, Clone, Default)]
24pub struct HookRegistry {
25    hooks: Vec<HookConfig>,
26}
27
28impl HookRegistry {
29    /// Create an empty registry.
30    pub fn new() -> Self {
31        Self { hooks: Vec::new() }
32    }
33
34    /// Load global hooks from ~/.enact/hooks.yaml (or ENACT_HOOKS_CONFIG_PATH).
35    /// Returns a new registry; does not merge with existing.
36    pub fn load_global() -> Self {
37        let hooks = load_global_hooks();
38        Self { hooks }
39    }
40
41    /// Merge in plugin hooks loaded from `<plugin>/hooks/hooks.yaml`.
42    /// Plugin hooks run after global hooks and before agent hooks.
43    pub fn with_plugin_hooks(mut self, project_dir: Option<&Path>) -> Self {
44        self.hooks.extend(load_plugin_hooks(project_dir));
45        self
46    }
47
48    /// Merge in agent-level hooks (from agent.yaml `hooks:`).
49    /// Agent hooks are appended so they run after global hooks for the same event.
50    pub fn with_agent_hooks(mut self, agent_hooks: Option<&[HookConfig]>) -> Self {
51        if let Some(hooks) = agent_hooks {
52            self.hooks.extend(hooks.iter().cloned());
53        }
54        self
55    }
56
57    /// Build a registry with global hooks loaded and optional agent hooks merged.
58    pub fn load_global_and_agent(
59        project_dir: Option<&Path>,
60        agent_hooks: Option<&[HookConfig]>,
61    ) -> Self {
62        Self::load_global()
63            .with_plugin_hooks(project_dir)
64            .with_agent_hooks(agent_hooks)
65    }
66
67    /// Get all hooks for a given event. Optional tool name for matcher filter (e.g. PreToolUse).
68    pub fn hooks_for_event(&self, event: HookEvent, tool_name: Option<&str>) -> Vec<&HookConfig> {
69        self.hooks
70            .iter()
71            .filter(|h| h.event == event)
72            .filter(|h| match (&h.matcher, tool_name) {
73                (None, _) => true,
74                (Some(_), None) => true,
75                (Some(pattern), Some(name)) => regex_match_pattern(pattern, name).unwrap_or(false),
76            })
77            .collect()
78    }
79
80    /// Run a command handler: pass JSON context on stdin and capture stdout.
81    pub async fn run_command_handler(
82        &self,
83        script: &str,
84        context_json: &serde_json::Value,
85    ) -> std::io::Result<CommandHookResult> {
86        run_command_shell(script, context_json).await
87    }
88}
89
90/// Load global hooks from hooks.yaml (ENACT_HOOKS_CONFIG_PATH or ~/.enact/hooks.yaml).
91pub fn load_global_hooks() -> Vec<HookConfig> {
92    match resolve_config_file("hooks.yaml", "ENACT_HOOKS_CONFIG_PATH") {
93        Some(path) => {
94            let content = match std::fs::read_to_string(&path) {
95                Ok(c) => c,
96                Err(_) => return Vec::new(),
97            };
98            let config: HooksConfig = match serde_yaml::from_str(&content) {
99                Ok(c) => c,
100                Err(_) => return Vec::new(),
101            };
102            config.hooks
103        }
104        None => Vec::new(),
105    }
106}
107
108/// Load plugin hooks from `<plugin>/hooks/hooks.yaml`.
109pub fn load_plugin_hooks(project_dir: Option<&Path>) -> Vec<HookConfig> {
110    let mut out = Vec::new();
111    for plugin in load_plugins(project_dir) {
112        let hooks_path = plugin.hooks_dir().join("hooks.yaml");
113        if !hooks_path.exists() {
114            continue;
115        }
116        let content = match std::fs::read_to_string(&hooks_path) {
117            Ok(c) => c,
118            Err(e) => {
119                warn!(
120                    "Failed to read plugin hooks from {}: {}",
121                    hooks_path.display(),
122                    e
123                );
124                continue;
125            }
126        };
127        match serde_yaml::from_str::<HooksConfig>(&content) {
128            Ok(cfg) => out.extend(cfg.hooks),
129            Err(e) => warn!(
130                "Failed to parse plugin hooks from {}: {}",
131                hooks_path.display(),
132                e
133            ),
134        }
135    }
136    out
137}
138
139/// Run a shell script with JSON context on stdin.
140async fn run_command_shell(
141    script: &str,
142    context_json: &serde_json::Value,
143) -> std::io::Result<CommandHookResult> {
144    let json_str = serde_json::to_string(context_json).unwrap_or_default();
145    let mut child = Command::new("sh")
146        .arg("-c")
147        .arg(script)
148        .stdin(Stdio::piped())
149        .stdout(Stdio::piped())
150        .stderr(Stdio::null())
151        .spawn()?;
152    if let Some(mut stdin) = child.stdin.take() {
153        stdin.write_all(json_str.as_bytes()).await?;
154        stdin.flush().await?;
155    }
156    let mut stdout = String::new();
157    if let Some(mut out) = child.stdout.take() {
158        let mut buf = Vec::new();
159        out.read_to_end(&mut buf).await?;
160        stdout = String::from_utf8_lossy(&buf).to_string();
161    }
162    let status = child.wait().await?;
163    Ok(CommandHookResult {
164        success: status.success(),
165        stdout: stdout.trim().to_string(),
166    })
167}
168
169/// Match tool name against hook matcher pattern (regex string).
170fn regex_match_pattern(pattern: &str, name: &str) -> Option<bool> {
171    regex::Regex::new(pattern).ok().map(|re| re.is_match(name))
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Lay out `<root>/.enact/plugins/demo-plugin` with a plugin manifest and a
    /// `hooks/hooks.yaml` containing `hooks_yaml`.
    fn write_demo_plugin(root: &std::path::Path, hooks_yaml: &str) {
        let plugin_root = root.join(".enact").join("plugins").join("demo-plugin");
        let manifest_dir = plugin_root.join(".enact-plugin");
        let hooks_dir = plugin_root.join("hooks");
        fs::create_dir_all(&manifest_dir).unwrap();
        fs::create_dir_all(&hooks_dir).unwrap();
        fs::write(
            manifest_dir.join("plugin.json"),
            r#"{"name":"demo-plugin","version":"0.1.0"}"#,
        )
        .unwrap();
        fs::write(hooks_dir.join("hooks.yaml"), hooks_yaml).unwrap();
    }

    /// A fresh registry has no hooks for any event.
    #[test]
    fn hook_registry_empty() {
        let registry = HookRegistry::new();
        let found = registry.hooks_for_event(HookEvent::SessionStart, None);
        assert!(found.is_empty());
    }

    /// Loading global hooks must not panic, whatever the environment holds.
    #[test]
    fn load_global_returns_vec() {
        let _ = load_global_hooks();
    }

    /// Hooks are ordered global, then plugin, then agent.
    #[test]
    fn load_global_plugin_agent_order() {
        let temp = tempfile::tempdir().unwrap();
        let global_hooks = temp.path().join("hooks.yaml");
        fs::write(
            &global_hooks,
            "hooks:\n  - event: SessionStart\n    handler:\n      type: command\n      script: \"echo global\"\n",
        )
        .unwrap();

        write_demo_plugin(
            temp.path(),
            "hooks:\n  - event: SessionStart\n    handler:\n      type: command\n      script: \"echo plugin\"\n",
        );

        let agent = vec![HookConfig {
            event: HookEvent::SessionStart,
            matcher: None,
            handler: enact_config::HookHandler::Command {
                script: "echo agent".to_string(),
            },
            async_mode: false,
        }];

        // The override is only needed while the registry is being built.
        std::env::set_var(
            "ENACT_HOOKS_CONFIG_PATH",
            global_hooks.to_string_lossy().as_ref(),
        );
        let registry = HookRegistry::load_global_and_agent(Some(temp.path()), Some(&agent));
        std::env::remove_var("ENACT_HOOKS_CONFIG_PATH");

        let hooks = registry.hooks_for_event(HookEvent::SessionStart, None);
        assert_eq!(hooks.len(), 3);

        let mut scripts: Vec<String> = Vec::new();
        for hook in hooks {
            let script = match &hook.handler {
                enact_config::HookHandler::Command { script } => script.clone(),
                _ => String::new(),
            };
            scripts.push(script);
        }
        assert_eq!(scripts, vec!["echo global", "echo plugin", "echo agent"]);
    }

    /// A plugin whose hooks.yaml does not parse contributes no hooks.
    #[test]
    fn invalid_plugin_hook_yaml_is_ignored() {
        let temp = tempfile::tempdir().unwrap();
        write_demo_plugin(temp.path(), "not: [valid");

        let hooks = load_plugin_hooks(Some(temp.path()));
        assert!(hooks.is_empty());
    }
}