1use std::sync::Arc;
29use std::time::Duration;
30
31use serde::{Deserialize, Serialize};
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
37#[serde(rename_all = "snake_case")]
38pub enum HookEvent {
39 PreToolUse,
41 PostToolUse,
43 PostToolUseFailure,
45 UserPromptSubmit,
47 Stop,
49 SubagentStop,
51 PreCompact,
53 Notification,
55}
56
57pub struct HookMatcher {
61 pub event: HookEvent,
63 pub tool_name: Option<String>,
66 pub callback: HookCallback,
68 pub timeout: Option<Duration>,
70}
71
72impl std::fmt::Debug for HookMatcher {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("HookMatcher")
75 .field("event", &self.event)
76 .field("tool_name", &self.tool_name)
77 .field("timeout", &self.timeout)
78 .finish()
79 }
80}
81
82impl HookMatcher {
83 pub fn new(event: HookEvent, callback: HookCallback) -> Self {
85 Self {
86 event,
87 tool_name: None,
88 callback,
89 timeout: None,
90 }
91 }
92
93 #[must_use]
95 pub fn for_tool(mut self, name: impl Into<String>) -> Self {
96 self.tool_name = Some(name.into());
97 self
98 }
99
100 #[must_use]
102 pub fn with_timeout(mut self, timeout: Duration) -> Self {
103 self.timeout = Some(timeout);
104 self
105 }
106
107 #[must_use]
109 pub fn matches(&self, event: HookEvent, tool_name: Option<&str>) -> bool {
110 if self.event != event {
111 return false;
112 }
113 match (&self.tool_name, tool_name) {
114 (Some(filter), Some(name)) => filter == name,
115 (Some(_), None) => false,
116 (None, _) => true,
117 }
118 }
119}
120
121use crate::util::BoxFuture;
124
125pub type HookCallback =
134 Arc<dyn Fn(HookInput, Option<String>, HookContext) -> BoxFuture<HookOutput> + Send + Sync>;
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct HookInput {
139 pub hook_event: HookEvent,
141 #[serde(default, skip_serializing_if = "Option::is_none")]
143 pub tool_name: Option<String>,
144 #[serde(default, skip_serializing_if = "Option::is_none")]
146 pub tool_input: Option<serde_json::Value>,
147 #[serde(default, skip_serializing_if = "Option::is_none")]
149 pub tool_result: Option<serde_json::Value>,
150 #[serde(default, skip_serializing_if = "Option::is_none")]
152 pub tool_use_id: Option<String>,
153 #[serde(default, skip_serializing_if = "Option::is_none")]
155 pub extra: Option<serde_json::Value>,
156}
157
/// Per-invocation context handed to a [`HookCallback`].
///
/// Derives `Default` (no session) and `PartialEq`/`Eq` so callers and tests
/// can build and compare contexts without spelling out every field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HookContext {
    /// Session the hook fires in, if known; `dispatch_hook` fills it from its
    /// `session_id` argument.
    pub session_id: Option<String>,
}
164
165#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
167pub struct HookOutput {
168 pub decision: HookDecision,
170 #[serde(default, skip_serializing_if = "Option::is_none")]
172 pub reason: Option<String>,
173 #[serde(default, skip_serializing_if = "Option::is_none")]
175 pub updated_input: Option<serde_json::Value>,
176 #[serde(default, skip_serializing_if = "Option::is_none")]
178 pub extra: Option<serde_json::Value>,
179}
180
181#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
183#[serde(rename_all = "snake_case")]
184pub enum HookDecision {
185 Allow,
187 Block,
189 Modify,
191 Abort,
193}
194
195impl HookOutput {
196 #[must_use]
198 pub fn allow() -> Self {
199 Self {
200 decision: HookDecision::Allow,
201 reason: None,
202 updated_input: None,
203 extra: None,
204 }
205 }
206
207 #[must_use]
209 pub fn block(reason: impl Into<String>) -> Self {
210 Self {
211 decision: HookDecision::Block,
212 reason: Some(reason.into()),
213 updated_input: None,
214 extra: None,
215 }
216 }
217
218 #[must_use]
220 pub fn modify(updated_input: serde_json::Value) -> Self {
221 Self {
222 decision: HookDecision::Modify,
223 reason: None,
224 updated_input: Some(updated_input),
225 extra: None,
226 }
227 }
228
229 #[must_use]
231 pub fn abort(reason: impl Into<String>) -> Self {
232 Self {
233 decision: HookDecision::Abort,
234 reason: Some(reason.into()),
235 updated_input: None,
236 extra: None,
237 }
238 }
239}
240
241#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
245pub(crate) struct HookRequest {
246 pub request_id: String,
248 pub hook_event: HookEvent,
250 #[serde(default, skip_serializing_if = "Option::is_none")]
252 pub tool_name: Option<String>,
253 #[serde(default, skip_serializing_if = "Option::is_none")]
255 pub tool_input: Option<serde_json::Value>,
256 #[serde(default, skip_serializing_if = "Option::is_none")]
258 pub tool_result: Option<serde_json::Value>,
259 #[serde(default, skip_serializing_if = "Option::is_none")]
261 pub tool_use_id: Option<String>,
262}
263
264impl HookRequest {
265 #[cfg(test)]
267 pub fn into_hook_input(self) -> HookInput {
268 HookInput {
269 hook_event: self.hook_event,
270 tool_name: self.tool_name,
271 tool_input: self.tool_input,
272 tool_result: self.tool_result,
273 tool_use_id: self.tool_use_id,
274 extra: None,
275 }
276 }
277
278 pub(crate) fn to_hook_input(&self) -> HookInput {
280 HookInput {
281 hook_event: self.hook_event,
282 tool_name: self.tool_name.clone(),
283 tool_input: self.tool_input.clone(),
284 tool_result: self.tool_result.clone(),
285 tool_use_id: self.tool_use_id.clone(),
286 extra: None,
287 }
288 }
289}
290
291#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
300pub(crate) struct HookResponse {
301 pub kind: String,
303 pub request_id: String,
305 pub result: HookOutput,
307}
308
309impl HookResponse {
310 pub fn from_output(request_id: String, output: HookOutput) -> Self {
312 Self {
313 kind: "hook_response".into(),
314 request_id,
315 result: output,
316 }
317 }
318}
319
320pub(crate) async fn dispatch_hook(
328 req: &HookRequest,
329 hooks: &[HookMatcher],
330 default_hook_timeout: Duration,
331 session_id: Option<String>,
332) -> HookOutput {
333 let input = req.to_hook_input();
334
335 for matcher in hooks {
336 if !matcher.matches(req.hook_event, req.tool_name.as_deref()) {
337 continue;
338 }
339
340 let effective_timeout = matcher.timeout.unwrap_or(default_hook_timeout);
341 let ctx = HookContext {
342 session_id: session_id.clone(),
343 };
344
345 let fut = (matcher.callback)(input.clone(), session_id.clone(), ctx);
346 match tokio::time::timeout(effective_timeout, fut).await {
347 Ok(output) => return output,
348 Err(_) => {
349 tracing::warn!(
350 event = ?req.hook_event,
351 tool = ?req.tool_name,
352 timeout_secs = effective_timeout.as_secs_f64(),
353 "hook callback timed out, defaulting to allow (fail-open)"
354 );
355 return HookOutput::allow();
356 }
357 }
358 }
359
360 HookOutput::allow()
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn hook_event_round_trip() {
        let events = [
            HookEvent::PreToolUse,
            HookEvent::PostToolUse,
            HookEvent::PostToolUseFailure,
            HookEvent::UserPromptSubmit,
            HookEvent::Stop,
            HookEvent::SubagentStop,
            HookEvent::PreCompact,
            HookEvent::Notification,
        ];
        for event in events {
            let json = serde_json::to_string(&event).unwrap();
            let decoded: HookEvent = serde_json::from_str(&json).unwrap();
            assert_eq!(event, decoded, "round-trip failed for {event:?}");
        }
    }

    #[test]
    fn hook_matcher_matches_any_tool() {
        let cb: HookCallback = Arc::new(|_, _, _| Box::pin(async { HookOutput::allow() }));
        let matcher = HookMatcher::new(HookEvent::PreToolUse, cb);
        assert!(matcher.matches(HookEvent::PreToolUse, Some("bash")));
        assert!(matcher.matches(HookEvent::PreToolUse, Some("read_file")));
        assert!(matcher.matches(HookEvent::PreToolUse, None));
        assert!(!matcher.matches(HookEvent::PostToolUse, Some("bash")));
    }

    #[test]
    fn hook_matcher_matches_specific_tool() {
        let cb: HookCallback = Arc::new(|_, _, _| Box::pin(async { HookOutput::allow() }));
        let matcher = HookMatcher::new(HookEvent::PreToolUse, cb).for_tool("bash");
        assert!(matcher.matches(HookEvent::PreToolUse, Some("bash")));
        assert!(!matcher.matches(HookEvent::PreToolUse, Some("read_file")));
        assert!(!matcher.matches(HookEvent::PreToolUse, None));
    }

    #[test]
    fn hook_matcher_with_timeout() {
        let cb: HookCallback = Arc::new(|_, _, _| Box::pin(async { HookOutput::allow() }));
        let matcher = HookMatcher::new(HookEvent::Stop, cb).with_timeout(Duration::from_secs(5));
        assert_eq!(matcher.timeout, Some(Duration::from_secs(5)));
    }

    #[test]
    fn hook_output_allow() {
        let output = HookOutput::allow();
        assert_eq!(output.decision, HookDecision::Allow);
        assert!(output.reason.is_none());
    }

    #[test]
    fn hook_output_block() {
        let output = HookOutput::block("dangerous command");
        assert_eq!(output.decision, HookDecision::Block);
        assert_eq!(output.reason.as_deref(), Some("dangerous command"));
    }

    #[test]
    fn hook_output_modify() {
        let output = HookOutput::modify(serde_json::json!({"safe": true}));
        assert_eq!(output.decision, HookDecision::Modify);
        assert!(output.updated_input.is_some());
    }

    #[test]
    fn hook_output_abort() {
        let output = HookOutput::abort("critical failure");
        assert_eq!(output.decision, HookDecision::Abort);
        assert_eq!(output.reason.as_deref(), Some("critical failure"));
    }

    #[test]
    fn hook_output_round_trip() {
        // Populate every field (including `extra`) and compare whole structs,
        // so no field can silently drop out of the wire format.
        let output = HookOutput {
            decision: HookDecision::Modify,
            reason: Some("safety".into()),
            updated_input: Some(serde_json::json!({"command": "ls"})),
            extra: Some(serde_json::json!({"trace": 1})),
        };
        let json = serde_json::to_string(&output).unwrap();
        let decoded: HookOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(output, decoded);
    }

    #[test]
    fn hook_request_round_trip() {
        let req = HookRequest {
            request_id: "hr-1".into(),
            hook_event: HookEvent::PreToolUse,
            tool_name: Some("bash".into()),
            tool_input: Some(serde_json::json!({"command": "echo hello"})),
            tool_result: None,
            tool_use_id: Some("tu-1".into()),
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: HookRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, decoded);
    }

    #[test]
    fn hook_request_into_hook_input() {
        let req = HookRequest {
            request_id: "hr-1".into(),
            hook_event: HookEvent::PostToolUse,
            tool_name: Some("bash".into()),
            tool_input: None,
            tool_result: Some(serde_json::json!("output")),
            tool_use_id: Some("tu-1".into()),
        };
        let input = req.into_hook_input();
        assert_eq!(input.hook_event, HookEvent::PostToolUse);
        assert_eq!(input.tool_name.as_deref(), Some("bash"));
        assert!(input.tool_result.is_some());
    }

    #[test]
    fn hook_response_from_output() {
        let output = HookOutput::allow();
        let resp = HookResponse::from_output("req-1".into(), output);
        assert_eq!(resp.kind, "hook_response");
        assert_eq!(resp.request_id, "req-1");
        assert_eq!(resp.result.decision, HookDecision::Allow);
    }

    #[test]
    fn hook_response_round_trip() {
        let resp = HookResponse {
            kind: "hook_response".into(),
            request_id: "hr-1".into(),
            result: HookOutput::block("no"),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: HookResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, decoded);
    }

    #[test]
    fn hook_decision_serde() {
        let decisions = [
            (HookDecision::Allow, r#""allow""#),
            (HookDecision::Block, r#""block""#),
            (HookDecision::Modify, r#""modify""#),
            (HookDecision::Abort, r#""abort""#),
        ];
        for (decision, expected_json) in decisions {
            let json = serde_json::to_string(&decision).unwrap();
            assert_eq!(json, expected_json);
            let decoded: HookDecision = serde_json::from_str(&json).unwrap();
            assert_eq!(decision, decoded);
        }
    }

    #[test]
    fn hook_input_optional_fields() {
        let json = r#"{"hook_event":"stop"}"#;
        let input: HookInput = serde_json::from_str(json).unwrap();
        assert_eq!(input.hook_event, HookEvent::Stop);
        assert!(input.tool_name.is_none());
        assert!(input.tool_input.is_none());
        assert!(input.tool_result.is_none());
    }

    #[tokio::test]
    async fn hook_timeout_defaults_to_config_value() {
        let cb: HookCallback =
            Arc::new(|_, _, _| Box::pin(async { HookOutput::block("should arrive") }));
        let matchers = vec![HookMatcher::new(HookEvent::PreToolUse, cb)];

        let req = HookRequest {
            request_id: "r1".into(),
            hook_event: HookEvent::PreToolUse,
            tool_name: Some("Bash".into()),
            tool_input: None,
            tool_result: None,
            tool_use_id: None,
        };

        let output = dispatch_hook(&req, &matchers, Duration::from_secs(30), None).await;
        assert_eq!(output.decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn hook_timeout_override() {
        // The callback must stay pending longer than the 1 ms default.
        // `tokio::time::timeout` polls the inner future before checking its
        // deadline, so an immediately-ready callback would succeed even
        // under the default, and the override would never be exercised.
        let cb: HookCallback = Arc::new(|_, _, _| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_millis(20)).await;
                HookOutput::block("custom timeout")
            })
        });
        let matchers =
            vec![HookMatcher::new(HookEvent::PreToolUse, cb).with_timeout(Duration::from_secs(60))];

        let req = HookRequest {
            request_id: "r1".into(),
            hook_event: HookEvent::PreToolUse,
            tool_name: None,
            tool_input: None,
            tool_result: None,
            tool_use_id: None,
        };

        // With the 60 s override the 20 ms callback completes and blocks; if
        // the 1 ms default were used instead, it would time out to Allow.
        let output = dispatch_hook(&req, &matchers, Duration::from_millis(1), None).await;
        assert_eq!(output.decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn hook_timeout_fires_returns_allow() {
        // Sleeps far longer than the 10 ms default, so the timeout must fire.
        let cb: HookCallback = Arc::new(|_, _, _| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                HookOutput::block("never reached")
            })
        });
        let matchers = vec![HookMatcher::new(HookEvent::PreToolUse, cb)];

        let req = HookRequest {
            request_id: "r1".into(),
            hook_event: HookEvent::PreToolUse,
            tool_name: Some("Bash".into()),
            tool_input: None,
            tool_result: None,
            tool_use_id: None,
        };

        let output = dispatch_hook(&req, &matchers, Duration::from_millis(10), None).await;
        // Fail-open: a timed-out hook yields Allow.
        assert_eq!(output.decision, HookDecision::Allow);
    }

    #[tokio::test]
    async fn dispatch_without_matching_hook_allows() {
        let cb: HookCallback =
            Arc::new(|_, _, _| Box::pin(async { HookOutput::block("bash only") }));
        let matchers = vec![HookMatcher::new(HookEvent::PreToolUse, cb).for_tool("bash")];

        let req = HookRequest {
            request_id: "r1".into(),
            hook_event: HookEvent::PreToolUse,
            tool_name: Some("read_file".into()),
            tool_input: None,
            tool_result: None,
            tool_use_id: None,
        };

        let output = dispatch_hook(&req, &matchers, Duration::from_secs(30), None).await;
        assert_eq!(output.decision, HookDecision::Allow);
    }

    #[tokio::test]
    async fn dispatch_forwards_input_and_session() {
        let cb: HookCallback = Arc::new(
            |input: HookInput, session: Option<String>, ctx: HookContext| {
                Box::pin(async move {
                    let as_expected = input.tool_use_id.as_deref() == Some("tu-9")
                        && session.as_deref() == Some("sess-1")
                        && ctx.session_id.as_deref() == Some("sess-1");
                    if as_expected {
                        HookOutput::modify(serde_json::json!({"ok": true}))
                    } else {
                        HookOutput::block("unexpected arguments")
                    }
                })
            },
        );
        let matchers = vec![HookMatcher::new(HookEvent::PreToolUse, cb)];

        let req = HookRequest {
            request_id: "r1".into(),
            hook_event: HookEvent::PreToolUse,
            tool_name: Some("bash".into()),
            tool_input: None,
            tool_result: None,
            tool_use_id: Some("tu-9".into()),
        };

        let output = dispatch_hook(
            &req,
            &matchers,
            Duration::from_secs(30),
            Some("sess-1".into()),
        )
        .await;
        assert_eq!(output.decision, HookDecision::Modify);
    }
}