Skip to main content

albert_runtime/
hooks.rs

1use std::ffi::OsStr;
2use std::process::Command;
3
4use serde_json::json;
5
6use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
7
/// Lifecycle point at which a hook command is invoked.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookEvent {
    /// Fired before a tool runs.
    PreToolUse,
    /// Fired after a tool has produced its output.
    PostToolUse,
}

impl HookEvent {
    /// Returns the event's textual name, which is identical to the variant name.
    fn as_str(self) -> &'static str {
        match self {
            HookEvent::PreToolUse => "PreToolUse",
            HookEvent::PostToolUse => "PostToolUse",
        }
    }
}
22
/// Aggregated outcome of running every hook configured for one event.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HookRunResult {
    /// `true` when a hook vetoed the tool invocation.
    denied: bool,
    /// Messages gathered from the hooks, in the order they were produced.
    messages: Vec<String>,
}

impl HookRunResult {
    /// Creates a non-denying result that carries the given messages.
    #[must_use]
    pub fn allow(messages: Vec<String>) -> Self {
        let denied = false;
        HookRunResult { denied, messages }
    }

    /// Reports whether a hook denied the tool invocation.
    #[must_use]
    pub fn is_denied(&self) -> bool {
        self.denied
    }

    /// Borrows the collected hook messages.
    #[must_use]
    pub fn messages(&self) -> &[String] {
        self.messages.as_slice()
    }
}
48
49#[derive(Debug, Clone, PartialEq, Eq, Default)]
50pub struct HookRunner {
51    config: RuntimeHookConfig,
52}
53
54impl HookRunner {
55    #[must_use]
56    pub fn new(config: RuntimeHookConfig) -> Self {
57        Self { config }
58    }
59
60    #[must_use]
61    pub fn from_feature_config(feature_config: &RuntimeFeatureConfig) -> Self {
62        Self::new(feature_config.hooks().clone())
63    }
64
65    #[must_use]
66    pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
67        self.run_commands(
68            HookEvent::PreToolUse,
69            self.config.pre_tool_use(),
70            tool_name,
71            tool_input,
72            None,
73            false,
74        )
75    }
76
77    #[must_use]
78    pub fn run_post_tool_use(
79        &self,
80        tool_name: &str,
81        tool_input: &str,
82        tool_output: &str,
83        is_error: bool,
84    ) -> HookRunResult {
85        self.run_commands(
86            HookEvent::PostToolUse,
87            self.config.post_tool_use(),
88            tool_name,
89            tool_input,
90            Some(tool_output),
91            is_error,
92        )
93    }
94
95    fn run_commands(
96        &self,
97        event: HookEvent,
98        commands: &[String],
99        tool_name: &str,
100        tool_input: &str,
101        tool_output: Option<&str>,
102        is_error: bool,
103    ) -> HookRunResult {
104        if commands.is_empty() {
105            return HookRunResult::allow(Vec::new());
106        }
107
108        let payload = json!({
109            "hook_event_name": event.as_str(),
110            "tool_name": tool_name,
111            "tool_input": parse_tool_input(tool_input),
112            "tool_input_json": tool_input,
113            "tool_output": tool_output,
114            "tool_result_is_error": is_error,
115        })
116        .to_string();
117
118        let mut messages = Vec::new();
119
120        for command in commands {
121            match self.run_command(
122                command,
123                event,
124                tool_name,
125                tool_input,
126                tool_output,
127                is_error,
128                &payload,
129            ) {
130                HookCommandOutcome::Allow { message } => {
131                    if let Some(message) = message {
132                        messages.push(message);
133                    }
134                }
135                HookCommandOutcome::Deny { message } => {
136                    let message = message.unwrap_or_else(|| {
137                        format!("{} hook denied tool `{tool_name}`", event.as_str())
138                    });
139                    messages.push(message);
140                    return HookRunResult {
141                        denied: true,
142                        messages,
143                    };
144                }
145                HookCommandOutcome::Warn { message } => messages.push(message),
146            }
147        }
148
149        HookRunResult::allow(messages)
150    }
151
152    fn run_command(
153        &self,
154        command: &str,
155        event: HookEvent,
156        tool_name: &str,
157        tool_input: &str,
158        tool_output: Option<&str>,
159        is_error: bool,
160        payload: &str,
161    ) -> HookCommandOutcome {
162        let mut child = shell_command(command);
163        child.stdin(std::process::Stdio::piped());
164        child.stdout(std::process::Stdio::piped());
165        child.stderr(std::process::Stdio::piped());
166        child.env("HOOK_EVENT", event.as_str());
167        child.env("HOOK_TOOL_NAME", tool_name);
168        child.env("HOOK_TOOL_INPUT", tool_input);
169        child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" });
170        if let Some(tool_output) = tool_output {
171            child.env("HOOK_TOOL_OUTPUT", tool_output);
172        }
173
174        match child.output_with_stdin(payload.as_bytes()) {
175            Ok(output) => {
176                let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
177                let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
178                let message = (!stdout.is_empty()).then_some(stdout);
179                match output.status.code() {
180                    Some(0) => HookCommandOutcome::Allow { message },
181                    Some(2) => HookCommandOutcome::Deny { message },
182                    Some(code) => HookCommandOutcome::Warn {
183                        message: format_hook_warning(
184                            command,
185                            code,
186                            message.as_deref(),
187                            stderr.as_str(),
188                        ),
189                    },
190                    None => HookCommandOutcome::Warn {
191                        message: format!(
192                            "{} hook `{command}` terminated by signal while handling `{tool_name}`",
193                            event.as_str()
194                        ),
195                    },
196                }
197            }
198            Err(error) => HookCommandOutcome::Warn {
199                message: format!(
200                    "{} hook `{command}` failed to start for `{tool_name}`: {error}",
201                    event.as_str()
202                ),
203            },
204        }
205    }
206}
207
/// Classification of a single hook command's result.
enum HookCommandOutcome {
    /// The hook permits the tool call; `message` holds its stdout when non-empty.
    Allow {
        message: Option<String>,
    },
    /// The hook vetoes the tool call; `message` holds its stdout when non-empty.
    Deny {
        message: Option<String>,
    },
    /// The hook did not finish cleanly; the tool call proceeds with this warning.
    Warn {
        message: String,
    },
}
213
214fn parse_tool_input(tool_input: &str) -> serde_json::Value {
215    serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input }))
216}
217
/// Builds the warning shown when a hook exits with a status other than 0 or 2.
///
/// The hook's stdout is appended when present and non-empty; otherwise its
/// stderr is appended when non-empty; otherwise only the headline is returned.
fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String {
    let headline =
        format!("Hook `{command}` exited with status {code}; allowing tool execution to continue");

    let detail = match stdout {
        Some(text) if !text.is_empty() => text,
        _ if !stderr.is_empty() => stderr,
        _ => return headline,
    };

    format!("{headline}: {detail}")
}
230
231fn shell_command(command: &str) -> CommandWithStdin {
232    #[cfg(windows)]
233    let mut command_builder = {
234        let mut command_builder = Command::new("cmd");
235        command_builder.arg("/C").arg(command);
236        CommandWithStdin::new(command_builder)
237    };
238
239    #[cfg(not(windows))]
240    let command_builder = {
241        let mut command_builder = Command::new("sh");
242        command_builder.arg("-lc").arg(command);
243        CommandWithStdin::new(command_builder)
244    };
245
246    command_builder
247}
248
/// Thin wrapper around [`Command`] that can feed a byte payload to the child's
/// stdin before collecting its output.
struct CommandWithStdin {
    command: Command,
}

impl CommandWithStdin {
    /// Wraps an already-configured command.
    fn new(command: Command) -> Self {
        Self { command }
    }

    /// Configures the child's stdin; must be piped for the payload to be delivered.
    fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stdin(cfg);
        self
    }

    /// Configures the child's stdout.
    fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stdout(cfg);
        self
    }

    /// Configures the child's stderr.
    fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stderr(cfg);
        self
    }

    /// Sets an environment variable for the child process.
    fn env<K, V>(&mut self, key: K, value: V) -> &mut Self
    where
        K: AsRef<OsStr>,
        V: AsRef<OsStr>,
    {
        self.command.env(key, value);
        self
    }

    /// Spawns the command, writes `stdin` to the child's stdin (when piped),
    /// closes the pipe, and waits for the child to finish.
    ///
    /// A `BrokenPipe` error while writing is tolerated: it only means the child
    /// exited or closed its stdin without consuming the whole payload, and its
    /// exit status and output are still meaningful.
    ///
    /// # Errors
    ///
    /// Returns an error if the process cannot be spawned, if writing the
    /// payload fails for a reason other than a broken pipe, or if waiting for
    /// the child fails.
    ///
    /// NOTE(review): the payload is written in full before any output is read,
    /// so a hook that fills its stdout/stderr pipe before draining a very
    /// large payload could still block here.
    fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result<std::process::Output> {
        use std::io::{ErrorKind, Write};

        let mut child = self.command.spawn()?;

        // The handle is dropped at the end of the match arm, which closes the
        // pipe so the child sees end-of-file.
        let write_result = match child.stdin.take() {
            Some(mut child_stdin) => child_stdin.write_all(stdin),
            None => Ok(()),
        };

        match write_result {
            Ok(()) => {}
            Err(error) if error.kind() == ErrorKind::BrokenPipe => {}
            Err(error) => {
                // Do not leave an unreaped child behind on unexpected failures.
                let _ = child.kill();
                let _ = child.wait();
                return Err(error);
            }
        }

        child.wait_with_output()
    }
}
291
#[cfg(test)]
mod tests {
    use super::{HookRunResult, HookRunner};
    use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};

    /// Adapts a POSIX-flavoured snippet for `cmd` by swapping single quotes for
    /// double quotes.
    #[cfg(windows)]
    fn shell_snippet(script: &str) -> String {
        script.replace('\'', "\"")
    }

    /// On non-Windows hosts the snippet is used verbatim.
    #[cfg(not(windows))]
    fn shell_snippet(script: &str) -> String {
        script.to_string()
    }

    /// Builds a configuration whose only hook is one pre-tool-use command.
    fn single_pre_hook(script: &str) -> RuntimeHookConfig {
        let pre_tool_use = vec![shell_snippet(script)];
        RuntimeHookConfig::new(pre_tool_use, Vec::new())
    }

    #[test]
    fn allows_exit_code_zero_and_captures_stdout() {
        let runner = HookRunner::new(single_pre_hook("printf 'pre ok'"));

        let outcome = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#);

        let expected = HookRunResult::allow(vec!["pre ok".to_string()]);
        assert_eq!(outcome, expected);
    }

    #[test]
    fn denies_exit_code_two() {
        let runner = HookRunner::new(single_pre_hook("printf 'blocked by hook'; exit 2"));

        let outcome = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);

        assert!(outcome.is_denied());
        assert_eq!(outcome.messages(), &["blocked by hook".to_string()]);
    }

    #[test]
    fn warns_for_other_non_zero_statuses() {
        let features = RuntimeFeatureConfig::default()
            .with_hooks(single_pre_hook("printf 'warning hook'; exit 1"));
        let runner = HookRunner::from_feature_config(&features);

        let outcome = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#);

        assert!(!outcome.is_denied());
        let mentions_continue = outcome
            .messages()
            .iter()
            .any(|message| message.contains("allowing tool execution to continue"));
        assert!(mentions_continue);
    }
}