1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Duration;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
/// Lifecycle events that a [`HookMatcher`] can be registered for.
///
/// Serialized with serde in `snake_case`, e.g. `PostToolUseFailure` <-> `"post_tool_use_failure"`.
/// Only some variants are flagged as supported; see `HookEvent::is_supported`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookEvent {
    /// Serialized as `"pre_tool_use"`. Supported.
    PreToolUse,
    /// Serialized as `"post_tool_use"`. Supported.
    PostToolUse,
    /// Serialized as `"post_tool_use_failure"`. Supported.
    PostToolUseFailure,
    /// Serialized as `"user_prompt_submit"`. Supported.
    UserPromptSubmit,
    /// Serialized as `"stop"`. Supported.
    Stop,
    /// Serialized as `"subagent_stop"`. Not flagged as supported.
    SubagentStop,
    /// Serialized as `"pre_compact"`. Not flagged as supported.
    PreCompact,
    /// Serialized as `"notification"`. Not flagged as supported.
    Notification,
}
29
30impl HookEvent {
31 pub fn is_supported(&self) -> bool {
33 matches!(
34 self,
35 HookEvent::PreToolUse
36 | HookEvent::PostToolUse
37 | HookEvent::PostToolUseFailure
38 | HookEvent::UserPromptSubmit
39 | HookEvent::Stop
40 )
41 }
42}
43
44#[derive(Clone)]
46pub struct HookMatcher {
47 pub event: HookEvent,
48 pub tool_name: Option<String>,
49 pub callback: HookCallback,
50 pub timeout: Option<Duration>,
51}
52
/// Payload handed to hook callbacks, describing the event being dispatched.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookInput {
    /// Event being dispatched; `execute_hooks` only runs hooks registered for an equal event.
    pub event: HookEvent,
    /// Name of the tool involved, if any. It is matched against `HookMatcher::tool_name`;
    /// when this is `None`, hooks that set a tool-name filter are not run.
    pub tool_name: Option<String>,
    /// Tool input as arbitrary JSON, if any.
    // NOTE(review): presumably only populated for tool-related events — confirm with callers.
    pub tool_input: Option<Value>,
    /// Tool output as arbitrary JSON, if any.
    pub tool_output: Option<Value>,
    /// Prompt text, if any.
    pub prompt: Option<String>,
    /// Identifier of the session this event belongs to.
    pub session_id: String,
    /// Additional top-level JSON fields, flattened by serde. When deserializing, it collects
    /// the keys not matched by the fields above. When serializing, its entries are emitted
    /// alongside those fields.
    // NOTE(review): serde's `flatten` expects a map here; tests build this as `Value::Null`,
    // so confirm that serializing a non-object `extra` behaves as intended with the pinned
    // serde version.
    #[serde(flatten)]
    pub extra: Value,
}
65
/// Per-call context handed to hook callbacks alongside the [`HookInput`].
#[derive(Debug, Clone)]
pub struct HookContext {
    /// Identifier of the current session.
    pub session_id: String,
    /// Working directory, stored as a string.
    // NOTE(review): presumably the session's working directory — confirm with whoever builds this.
    pub cwd: String,
}
72
/// Result produced by a hook callback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookOutput {
    /// How the dispatcher treats this result; a `Skip` makes `execute_hooks` move on to the
    /// next matching hook, any other decision is returned immediately.
    pub decision: HookDecision,
    /// Optional replacement input as arbitrary JSON; may be omitted when deserializing.
    // NOTE(review): `execute_hooks` returns this untouched; how it is applied is up to the
    // caller — confirm.
    #[serde(default)]
    pub updated_input: Option<Value>,
    /// Optional message text; may be omitted when deserializing.
    #[serde(default)]
    pub message: Option<String>,
}
82
/// Decision carried by a [`HookOutput`], serialized in `snake_case`
/// (`"continue"`, `"block"`, `"skip"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookDecision {
    /// The decision used by `HookOutput::default()`. When a hook returns it, `execute_hooks`
    /// stops and returns that output without running later hooks.
    Continue,
    /// When a hook returns it, `execute_hooks` stops and returns that output without running
    /// later hooks.
    // NOTE(review): presumably tells the caller to block the pending action — confirm.
    Block,
    /// The hook abstains: `execute_hooks` discards this output and tries the next hook.
    Skip,
}
94
95impl Default for HookOutput {
96 fn default() -> Self {
97 Self {
98 decision: HookDecision::Continue,
99 updated_input: None,
100 message: None,
101 }
102 }
103}
104
/// Async hook callback. It takes the input and context by value and returns a boxed, `Send`
/// future that resolves to a [`HookOutput`].
///
/// Held in an `Arc` and bounded by `Send + Sync`, so cloning a `HookMatcher` only bumps a
/// reference count and the callback may be shared across threads.
pub type HookCallback = Arc<
    dyn Fn(HookInput, HookContext) -> Pin<Box<dyn Future<Output = HookOutput> + Send>>
        + Send
        + Sync,
>;
111
112pub(crate) async fn execute_hooks(
114 hooks: &[HookMatcher],
115 input: HookInput,
116 context: &HookContext,
117 default_timeout: Duration,
118) -> HookOutput {
119 for hook in hooks {
120 if hook.event != input.event {
121 continue;
122 }
123
124 if let Some(pattern) = &hook.tool_name {
126 if let Some(tool_name) = &input.tool_name {
127 if !tool_name_matches(tool_name, pattern) {
128 continue;
129 }
130 } else {
131 continue;
133 }
134 }
135
136 let timeout = hook.timeout.unwrap_or(default_timeout);
137 let result = tokio::time::timeout(
138 timeout,
139 (hook.callback)(input.clone(), context.clone()),
140 )
141 .await;
142
143 match result {
144 Ok(output) => {
145 if output.decision != HookDecision::Skip {
146 return output;
147 }
148 }
149 Err(_) => {
150 tracing::warn!("Hook timed out for event {:?}", input.event);
151 }
152 }
153 }
154
155 HookOutput::default()
156}
157
/// Returns `true` when the tool `name` matches `pattern`.
///
/// A pattern ending in `*` is a prefix match on everything before that final `*`, so
/// `"Edit*"` matches `"EditFile"` and `"*"` matches any name. Any other pattern must equal
/// `name` exactly. Only one trailing `*` is stripped; any other `*` is compared literally.
// NOTE(review): this allow is unexplained; presumably it silences dead-code warnings while the
// dispatcher has no in-crate caller — confirm and remove it if no longer needed.
#[allow(dead_code)]
fn tool_name_matches(name: &str, pattern: &str) -> bool {
    match pattern.strip_suffix('*') {
        Some(prefix) => name.starts_with(prefix),
        None => name == pattern,
    }
}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Callback that immediately resolves to `decision`, with no updated input or message.
    fn respond_with(decision: HookDecision) -> HookCallback {
        Arc::new(move |_input, _ctx| {
            // The callback is `Fn`, so each invocation needs its own copy of the decision.
            let decision = decision.clone();
            Box::pin(async move {
                HookOutput {
                    decision,
                    updated_input: None,
                    message: None,
                }
            })
        })
    }

    /// Callback that sleeps far longer than any timeout used by these tests.
    fn hang() -> HookCallback {
        Arc::new(|_input, _ctx| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                HookOutput::default()
            })
        })
    }

    /// Matcher for `event` with an optional tool-name pattern and no per-hook timeout.
    fn matcher(event: HookEvent, tool_name: Option<&str>, callback: HookCallback) -> HookMatcher {
        HookMatcher {
            event,
            tool_name: tool_name.map(String::from),
            callback,
            timeout: None,
        }
    }

    fn make_input(event: HookEvent) -> HookInput {
        HookInput {
            event,
            tool_name: None,
            tool_input: None,
            tool_output: None,
            prompt: None,
            session_id: "test-session".to_string(),
            extra: serde_json::Value::Null,
        }
    }

    fn make_context() -> HookContext {
        HookContext {
            session_id: "test-session".to_string(),
            cwd: "/tmp".to_string(),
        }
    }

    /// Runs `hooks` against `input` with the shared test context and a 5-second default timeout.
    async fn run(hooks: &[HookMatcher], input: HookInput) -> HookOutput {
        execute_hooks(hooks, input, &make_context(), Duration::from_secs(5)).await
    }

    #[test]
    fn test_hook_event_is_supported() {
        assert!(HookEvent::PreToolUse.is_supported());
        assert!(HookEvent::PostToolUse.is_supported());
        assert!(HookEvent::PostToolUseFailure.is_supported());
        assert!(HookEvent::UserPromptSubmit.is_supported());
        assert!(HookEvent::Stop.is_supported());
        assert!(!HookEvent::SubagentStop.is_supported());
        assert!(!HookEvent::PreCompact.is_supported());
        assert!(!HookEvent::Notification.is_supported());
    }

    #[test]
    fn test_tool_name_exact_match() {
        assert!(tool_name_matches("EditFile", "EditFile"));
        assert!(!tool_name_matches("EditFile", "ReadFile"));
    }

    #[test]
    fn test_tool_name_glob_match() {
        assert!(tool_name_matches("EditFile", "Edit*"));
        assert!(tool_name_matches("EditBlock", "Edit*"));
        assert!(!tool_name_matches("ReadFile", "Edit*"));
    }

    #[test]
    fn test_tool_name_glob_edge_cases() {
        // A bare `*` has an empty prefix and therefore matches every name.
        assert!(tool_name_matches("AnyTool", "*"));
        assert!(tool_name_matches("", "*"));
        // The prefix itself matches.
        assert!(tool_name_matches("Edit", "Edit*"));
        // Only a trailing `*` is special; an interior one is compared literally.
        assert!(!tool_name_matches("EditFile", "E*File"));
    }

    #[test]
    fn test_serde_snake_case_names() {
        assert_eq!(
            serde_json::to_string(&HookEvent::PostToolUseFailure).unwrap(),
            "\"post_tool_use_failure\""
        );
        assert_eq!(
            serde_json::from_str::<HookDecision>("\"skip\"").unwrap(),
            HookDecision::Skip
        );
    }

    #[test]
    fn test_hook_output_optional_fields_default() {
        let output: HookOutput = serde_json::from_str(r#"{"decision":"block"}"#).unwrap();
        assert_eq!(output.decision, HookDecision::Block);
        assert!(output.updated_input.is_none());
        assert!(output.message.is_none());
    }

    #[tokio::test]
    async fn test_execute_hooks_no_match() {
        let output = run(&[], make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Continue);
    }

    #[tokio::test]
    async fn test_execute_hooks_matching() {
        let hooks = vec![matcher(
            HookEvent::PreToolUse,
            None,
            Arc::new(|_input, _ctx| {
                Box::pin(async {
                    HookOutput {
                        decision: HookDecision::Block,
                        updated_input: None,
                        message: Some("blocked".to_string()),
                    }
                })
            }),
        )];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Block);
        assert_eq!(output.message.as_deref(), Some("blocked"));
    }

    #[tokio::test]
    async fn test_execute_hooks_wrong_event() {
        let hooks = vec![matcher(
            HookEvent::PostToolUse,
            None,
            respond_with(HookDecision::Block),
        )];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Continue);
    }

    #[tokio::test]
    async fn test_execute_hooks_tool_name_filter() {
        let hooks = vec![matcher(
            HookEvent::PreToolUse,
            Some("EditFile"),
            respond_with(HookDecision::Block),
        )];

        let mut input = make_input(HookEvent::PreToolUse);
        input.tool_name = Some("EditFile".to_string());
        let output = run(&hooks, input).await;
        assert_eq!(output.decision, HookDecision::Block);

        let mut input2 = make_input(HookEvent::PreToolUse);
        input2.tool_name = Some("ReadFile".to_string());
        let output2 = run(&hooks, input2).await;
        assert_eq!(output2.decision, HookDecision::Continue);
    }

    #[tokio::test]
    async fn test_execute_hooks_pattern_requires_tool_name() {
        // A hook with a tool-name pattern must not fire for inputs that name no tool.
        let hooks = vec![matcher(
            HookEvent::PreToolUse,
            Some("Edit*"),
            respond_with(HookDecision::Block),
        )];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Continue);
    }

    #[tokio::test]
    async fn test_execute_hooks_glob_filter() {
        let hooks = vec![matcher(
            HookEvent::PreToolUse,
            Some("Edit*"),
            respond_with(HookDecision::Block),
        )];

        let mut input = make_input(HookEvent::PreToolUse);
        input.tool_name = Some("EditBlock".to_string());
        let output = run(&hooks, input).await;
        assert_eq!(output.decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn test_execute_hooks_timeout() {
        let hooks = vec![HookMatcher {
            timeout: Some(Duration::from_millis(10)),
            ..matcher(HookEvent::PreToolUse, None, hang())
        }];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Continue);
    }

    #[tokio::test]
    async fn test_execute_hooks_default_timeout_advances() {
        // The first hook has no per-hook timeout, so the 10 ms default applies; once it
        // expires, evaluation must continue with the next hook.
        let hooks = vec![
            matcher(HookEvent::PreToolUse, None, hang()),
            matcher(HookEvent::PreToolUse, None, respond_with(HookDecision::Block)),
        ];
        let input = make_input(HookEvent::PreToolUse);
        let output =
            execute_hooks(&hooks, input, &make_context(), Duration::from_millis(10)).await;
        assert_eq!(output.decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn test_execute_hooks_skip_advances() {
        let hooks = vec![
            matcher(HookEvent::PreToolUse, None, respond_with(HookDecision::Skip)),
            matcher(HookEvent::PreToolUse, None, respond_with(HookDecision::Block)),
        ];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn test_execute_hooks_continue_short_circuits() {
        // Unlike `Skip`, an explicit `Continue` is returned immediately and later hooks
        // (here one that would block) are not consulted.
        let hooks = vec![
            matcher(HookEvent::PreToolUse, None, respond_with(HookDecision::Continue)),
            matcher(HookEvent::PreToolUse, None, respond_with(HookDecision::Block)),
        ];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Continue);
    }
}