Skip to main content

batuta/agent/
hooks.rs

1//! Hook system for `apr code` — mirrors Claude Code's hook events.
2//!
3//! Hooks let a user intercept the agent loop at well-defined moments and
4//! run an external command. The exit code of the command decides what
5//! happens next:
6//!
7//! * `0` → allow (hook did nothing interesting, or approved the action)
8//! * `1` → warn (hook emitted a message; agent continues)
//! * `2` → block (hook vetoed the action; agent aborts that step)
//! * any other exit code is currently treated as block as well
10//!
//! This mirrors the hook surface of Claude Code's `settings.json`; APR declares the same hooks as `[[hooks]]` tables in `manifest.toml`.
12//!
13//! PMAT-CODE-HOOKS-001. This file introduces the event enum, the config
14//! struct that deserializes from `manifest.toml`'s `[[hooks]]` table,
15//! and the registry + runner used by the agent loop. Runtime integration
16//! ships incrementally:
17//!
18//! * SessionStart — wired from `cmd_code` (see `agent/code.rs`)
19//! * PreToolUse / PostToolUse / UserPromptSubmit / Stop / SubagentStop —
20//!   surface ships now; call sites land in PMAT-CODE-HOOKS-RUNTIME-001.
21
use std::collections::HashMap;
use std::io::Read;
use std::path::Path;
use std::process::Command;
use std::process::Stdio;
use std::thread;
use std::time::{Duration, Instant};

use serde::{Deserialize, Serialize};
27
28/// Canonical hook events mirroring Claude Code.
29///
30/// Claude's official list (as of 2026-04) is exactly the six below. Any
31/// additional events APR-specific work wants to surface should land as
32/// a separate enum variant + status_history entry in the parity contract.
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
34#[serde(rename_all = "PascalCase")]
35pub enum HookEvent {
36    /// Fires once when the agent session starts.
37    SessionStart,
38    /// Fires before every tool call; exit code 2 vetoes the call.
39    PreToolUse,
40    /// Fires after every tool call regardless of success.
41    PostToolUse,
42    /// Fires whenever the user submits a prompt.
43    UserPromptSubmit,
44    /// Fires when the top-level agent is about to terminate.
45    Stop,
46    /// Fires when a sub-agent (spawned via the Task tool) terminates.
47    SubagentStop,
48}
49
50/// A single hook entry deserialized from `manifest.hooks[]`.
51///
52/// ```toml
53/// [[hooks]]
54/// event = "PreToolUse"
55/// matcher = "shell"         # optional — tool name glob/regex
56/// command = "echo 'pre-tool' >> ~/.apr/hooks.log"
57/// ```
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct HookConfig {
60    /// Which event triggers this hook.
61    pub event: HookEvent,
62    /// Optional matcher (tool name, prompt text, etc.). When `None` the
63    /// hook always fires for its event.
64    #[serde(default)]
65    pub matcher: Option<String>,
66    /// Shell command to execute. Ran via `sh -c`.
67    pub command: String,
68    /// Optional timeout in seconds (default 30).
69    #[serde(default = "default_timeout_secs")]
70    pub timeout_secs: u64,
71}
72
/// Serde default for `HookConfig::timeout_secs`: thirty seconds.
fn default_timeout_secs() -> u64 {
    const THIRTY_SECONDS: u64 = 30;
    THIRTY_SECONDS
}
76
/// What the agent should do after a hook ran, derived from the hook
/// command's exit code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookDecision {
    /// Exit code `0`: carry on as though no hook had run.
    Allow,
    /// Exit code `1`: carry on, but surface the message (the hook's stderr)
    /// to the user.
    Warn(String),
    /// Exit code `2` (and, via [`HookDecision::from_exit_code`], every other
    /// code except `0` and `1`): abort the current step; the message is the
    /// reason.
    Block(String),
}

impl HookDecision {
    /// Translate a process exit `code` and its captured `stderr` into a
    /// decision.
    ///
    /// `0` → [`Allow`](Self::Allow) (the stderr text is dropped),
    /// `1` → [`Warn`](Self::Warn), and any other value — including codes
    /// above `2` and negative codes — → [`Block`](Self::Block).
    pub fn from_exit_code(code: i32, stderr: String) -> Self {
        if code == 0 {
            return Self::Allow;
        }
        if code == 1 {
            Self::Warn(stderr)
        } else {
            Self::Block(stderr)
        }
    }

    /// Returns `true` exactly when this decision is a [`HookDecision::Block`].
    pub fn is_blocking(&self) -> bool {
        matches!(*self, HookDecision::Block(..))
    }
}
101
102/// Indexed collection of hooks keyed by event.
103#[derive(Debug, Default)]
104pub struct HookRegistry {
105    by_event: HashMap<HookEvent, Vec<HookConfig>>,
106}
107
108impl HookRegistry {
109    pub fn new() -> Self {
110        Self::default()
111    }
112
113    /// Build a registry from a raw list (typically `manifest.hooks`).
114    pub fn from_configs(configs: impl IntoIterator<Item = HookConfig>) -> Self {
115        let mut reg = Self::new();
116        for cfg in configs {
117            reg.register(cfg);
118        }
119        reg
120    }
121
122    pub fn register(&mut self, cfg: HookConfig) {
123        self.by_event.entry(cfg.event).or_default().push(cfg);
124    }
125
126    pub fn hooks_for(&self, event: HookEvent) -> &[HookConfig] {
127        self.by_event.get(&event).map_or(&[], |v| v.as_slice())
128    }
129
130    pub fn len(&self) -> usize {
131        self.by_event.values().map(Vec::len).sum()
132    }
133
134    pub fn is_empty(&self) -> bool {
135        self.len() == 0
136    }
137
138    /// Fire every hook registered for `event` whose `matcher` (if any) is
139    /// contained in `context`. Returns the first blocking decision if
140    /// any, else the collected warnings, else Allow.
141    pub fn run(&self, event: HookEvent, context: &str, cwd: &Path) -> HookDecision {
142        let mut warnings = Vec::new();
143        for cfg in self.hooks_for(event) {
144            if let Some(m) = cfg.matcher.as_deref() {
145                if !context.contains(m) {
146                    continue;
147                }
148            }
149            match run_single(cfg, cwd) {
150                HookDecision::Allow => {}
151                HookDecision::Warn(msg) => warnings.push(msg),
152                block @ HookDecision::Block(_) => return block,
153            }
154        }
155        if warnings.is_empty() {
156            HookDecision::Allow
157        } else {
158            HookDecision::Warn(warnings.join("\n"))
159        }
160    }
161}
162
163fn run_single(cfg: &HookConfig, cwd: &Path) -> HookDecision {
164    let output = match Command::new("sh").arg("-c").arg(&cfg.command).current_dir(cwd).output() {
165        Ok(o) => o,
166        Err(e) => {
167            return HookDecision::Warn(format!("hook '{}' failed to spawn: {e}", cfg.command));
168        }
169    };
170    let code = output.status.code().unwrap_or(0);
171    let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
172    HookDecision::from_exit_code(code, stderr)
173}
174
// Unit tests for this module live in an out-of-line `tests` submodule file;
// they are compiled only for test builds.
#[cfg(test)]
mod tests;