1use crate::types::ToolOutput;
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tokio_util::sync::CancellationToken;
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13#[non_exhaustive]
14pub enum HookEvent {
15 PreToolUse,
16 PostToolUse,
17 PostToolUseFailure,
18 UserPromptSubmit,
19 Stop,
20 SubagentStart,
21 SubagentStop,
22 PreCompact,
23 SessionStart,
24 SessionEnd,
25}
26
27impl HookEvent {
28 pub fn can_block(&self) -> bool {
40 matches!(
41 self,
42 Self::PreToolUse
43 | Self::UserPromptSubmit
44 | Self::SessionStart
45 | Self::PreCompact
46 | Self::SubagentStart
47 )
48 }
49
50 pub fn from_pascal_case(s: &str) -> Option<Self> {
52 match s {
53 "PreToolUse" => Some(Self::PreToolUse),
54 "PostToolUse" => Some(Self::PostToolUse),
55 "PostToolUseFailure" => Some(Self::PostToolUseFailure),
56 "UserPromptSubmit" => Some(Self::UserPromptSubmit),
57 "Stop" => Some(Self::Stop),
58 "SubagentStart" => Some(Self::SubagentStart),
59 "SubagentStop" => Some(Self::SubagentStop),
60 "PreCompact" => Some(Self::PreCompact),
61 "SessionStart" => Some(Self::SessionStart),
62 "SessionEnd" => Some(Self::SessionEnd),
63 _ => None,
64 }
65 }
66
67 pub fn all() -> &'static [HookEvent] {
68 &[
69 Self::PreToolUse,
70 Self::PostToolUse,
71 Self::PostToolUseFailure,
72 Self::UserPromptSubmit,
73 Self::Stop,
74 Self::SubagentStart,
75 Self::SubagentStop,
76 Self::PreCompact,
77 Self::SessionStart,
78 Self::SessionEnd,
79 ]
80 }
81}
82
83impl std::fmt::Display for HookEvent {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 let s = match self {
86 Self::PreToolUse => "pre_tool_use",
87 Self::PostToolUse => "post_tool_use",
88 Self::PostToolUseFailure => "post_tool_use_failure",
89 Self::UserPromptSubmit => "user_prompt_submit",
90 Self::Stop => "stop",
91 Self::SubagentStart => "subagent_start",
92 Self::SubagentStop => "subagent_stop",
93 Self::PreCompact => "pre_compact",
94 Self::SessionStart => "session_start",
95 Self::SessionEnd => "session_end",
96 };
97 write!(f, "{}", s)
98 }
99}
100
101#[derive(Clone, Debug)]
102#[non_exhaustive]
103pub enum HookEventData {
104 PreToolUse {
105 tool_name: String,
106 tool_input: Value,
107 },
108 PostToolUse {
109 tool_name: String,
110 tool_result: ToolOutput,
111 },
112 PostToolUseFailure {
113 tool_name: String,
114 error: String,
115 },
116 UserPromptSubmit {
117 prompt: String,
118 },
119 Stop,
120 SubagentStart {
121 subagent_id: String,
122 subagent_type: String,
123 description: String,
124 },
125 SubagentStop {
126 subagent_id: String,
127 success: bool,
128 error: Option<String>,
129 },
130 PreCompact,
131 SessionStart,
132 SessionEnd,
133}
134
135impl HookEventData {
136 pub fn event_type(&self) -> HookEvent {
137 match self {
138 Self::PreToolUse { .. } => HookEvent::PreToolUse,
139 Self::PostToolUse { .. } => HookEvent::PostToolUse,
140 Self::PostToolUseFailure { .. } => HookEvent::PostToolUseFailure,
141 Self::UserPromptSubmit { .. } => HookEvent::UserPromptSubmit,
142 Self::Stop => HookEvent::Stop,
143 Self::SubagentStart { .. } => HookEvent::SubagentStart,
144 Self::SubagentStop { .. } => HookEvent::SubagentStop,
145 Self::PreCompact => HookEvent::PreCompact,
146 Self::SessionStart => HookEvent::SessionStart,
147 Self::SessionEnd => HookEvent::SessionEnd,
148 }
149 }
150
151 pub fn tool_name(&self) -> Option<&str> {
152 match self {
153 Self::PreToolUse { tool_name, .. }
154 | Self::PostToolUse { tool_name, .. }
155 | Self::PostToolUseFailure { tool_name, .. } => Some(tool_name),
156 _ => None,
157 }
158 }
159
160 pub fn tool_input(&self) -> Option<&Value> {
161 match self {
162 Self::PreToolUse { tool_input, .. } => Some(tool_input),
163 _ => None,
164 }
165 }
166
167 pub fn subagent_id(&self) -> Option<&str> {
168 match self {
169 Self::SubagentStart { subagent_id, .. } | Self::SubagentStop { subagent_id, .. } => {
170 Some(subagent_id)
171 }
172 _ => None,
173 }
174 }
175}
176
177#[derive(Clone, Debug)]
178pub struct HookInput {
179 pub session_id: String,
180 pub timestamp: DateTime<Utc>,
181 pub data: HookEventData,
182 pub metadata: Option<Value>,
183}
184
185impl HookInput {
186 pub fn new(session_id: impl Into<String>, data: HookEventData) -> Self {
187 Self {
188 session_id: session_id.into(),
189 timestamp: Utc::now(),
190 data,
191 metadata: None,
192 }
193 }
194
195 pub fn event_type(&self) -> HookEvent {
196 self.data.event_type()
197 }
198
199 pub fn tool_name(&self) -> Option<&str> {
200 self.data.tool_name()
201 }
202
203 pub fn subagent_id(&self) -> Option<&str> {
204 self.data.subagent_id()
205 }
206
207 pub fn pre_tool_use(
208 session_id: impl Into<String>,
209 tool_name: impl Into<String>,
210 tool_input: Value,
211 ) -> Self {
212 Self::new(
213 session_id,
214 HookEventData::PreToolUse {
215 tool_name: tool_name.into(),
216 tool_input,
217 },
218 )
219 }
220
221 pub fn post_tool_use(
222 session_id: impl Into<String>,
223 tool_name: impl Into<String>,
224 tool_result: ToolOutput,
225 ) -> Self {
226 Self::new(
227 session_id,
228 HookEventData::PostToolUse {
229 tool_name: tool_name.into(),
230 tool_result,
231 },
232 )
233 }
234
235 pub fn post_tool_use_failure(
236 session_id: impl Into<String>,
237 tool_name: impl Into<String>,
238 error: impl Into<String>,
239 ) -> Self {
240 Self::new(
241 session_id,
242 HookEventData::PostToolUseFailure {
243 tool_name: tool_name.into(),
244 error: error.into(),
245 },
246 )
247 }
248
249 pub fn user_prompt_submit(session_id: impl Into<String>, prompt: impl Into<String>) -> Self {
250 Self::new(
251 session_id,
252 HookEventData::UserPromptSubmit {
253 prompt: prompt.into(),
254 },
255 )
256 }
257
258 pub fn session_start(session_id: impl Into<String>) -> Self {
259 Self::new(session_id, HookEventData::SessionStart)
260 }
261
262 pub fn session_end(session_id: impl Into<String>) -> Self {
263 Self::new(session_id, HookEventData::SessionEnd)
264 }
265
266 pub fn stop(session_id: impl Into<String>) -> Self {
267 Self::new(session_id, HookEventData::Stop)
268 }
269
270 pub fn pre_compact(session_id: impl Into<String>) -> Self {
271 Self::new(session_id, HookEventData::PreCompact)
272 }
273
274 pub fn subagent_start(
275 session_id: impl Into<String>,
276 subagent_id: impl Into<String>,
277 subagent_type: impl Into<String>,
278 description: impl Into<String>,
279 ) -> Self {
280 Self::new(
281 session_id,
282 HookEventData::SubagentStart {
283 subagent_id: subagent_id.into(),
284 subagent_type: subagent_type.into(),
285 description: description.into(),
286 },
287 )
288 }
289
290 pub fn subagent_stop(
291 session_id: impl Into<String>,
292 subagent_id: impl Into<String>,
293 success: bool,
294 error: Option<String>,
295 ) -> Self {
296 Self::new(
297 session_id,
298 HookEventData::SubagentStop {
299 subagent_id: subagent_id.into(),
300 success,
301 error,
302 },
303 )
304 }
305}
306
307#[derive(Clone, Debug, Default)]
308pub struct HookOutput {
309 pub continue_execution: bool,
310 pub stop_reason: Option<String>,
311 pub suppress_logging: bool,
312 pub system_message: Option<String>,
313 pub updated_input: Option<Value>,
314 pub additional_context: Option<String>,
315}
316
317impl HookOutput {
318 pub fn allow() -> Self {
319 Self {
320 continue_execution: true,
321 ..Default::default()
322 }
323 }
324
325 pub fn block(reason: impl Into<String>) -> Self {
326 Self {
327 continue_execution: false,
328 stop_reason: Some(reason.into()),
329 ..Default::default()
330 }
331 }
332
333 pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
334 self.system_message = Some(message.into());
335 self
336 }
337
338 pub fn with_context(mut self, context: impl Into<String>) -> Self {
339 self.additional_context = Some(context.into());
340 self
341 }
342
343 pub fn with_updated_input(mut self, input: Value) -> Self {
344 self.updated_input = Some(input);
345 self
346 }
347
348 pub fn suppress_logging(mut self) -> Self {
349 self.suppress_logging = true;
350 self
351 }
352}
353
354#[derive(Clone, Debug)]
355pub struct HookContext {
356 pub session_id: String,
357 pub cancellation_token: CancellationToken,
358 pub cwd: Option<std::path::PathBuf>,
359 pub env: std::collections::HashMap<String, String>,
360}
361
362impl Default for HookContext {
363 fn default() -> Self {
364 Self {
365 session_id: String::new(),
366 cancellation_token: CancellationToken::new(),
367 cwd: None,
368 env: std::collections::HashMap::new(),
369 }
370 }
371}
372
373impl HookContext {
374 pub fn new(session_id: impl Into<String>) -> Self {
375 Self {
376 session_id: session_id.into(),
377 ..Default::default()
378 }
379 }
380
381 pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
382 self.cancellation_token = token;
383 self
384 }
385
386 pub fn with_cwd(mut self, cwd: impl Into<std::path::PathBuf>) -> Self {
387 self.cwd = Some(cwd.into());
388 self
389 }
390
391 pub fn with_env(mut self, env: std::collections::HashMap<String, String>) -> Self {
392 self.env = env;
393 self
394 }
395}
396
397#[derive(Clone, Debug)]
399pub struct HookMetadata {
400 pub name: String,
401 pub events: Vec<HookEvent>,
402 pub priority: i32,
403 pub timeout_secs: u64,
404 pub tool_matcher: Option<Regex>,
405}
406
407impl HookMetadata {
408 pub fn new(name: impl Into<String>, events: Vec<HookEvent>) -> Self {
409 Self {
410 name: name.into(),
411 events,
412 priority: 0,
413 timeout_secs: 60,
414 tool_matcher: None,
415 }
416 }
417
418 pub fn with_priority(mut self, priority: i32) -> Self {
419 self.priority = priority;
420 self
421 }
422
423 pub fn with_timeout(mut self, secs: u64) -> Self {
424 self.timeout_secs = secs;
425 self
426 }
427
428 pub fn with_tool_matcher(mut self, pattern: &str) -> Self {
429 if let Ok(regex) = Regex::new(pattern) {
430 self.tool_matcher = Some(regex);
431 }
432 self
433 }
434}
435
436#[async_trait]
437pub trait Hook: Send + Sync {
438 fn name(&self) -> &str;
439 fn events(&self) -> &[HookEvent];
440
441 #[inline]
442 fn tool_matcher(&self) -> Option<&Regex> {
443 None
444 }
445
446 #[inline]
447 fn timeout_secs(&self) -> u64 {
448 60
449 }
450
451 #[inline]
452 fn priority(&self) -> i32 {
453 0
454 }
455
456 async fn execute(
457 &self,
458 input: HookInput,
459 hook_context: &HookContext,
460 ) -> Result<HookOutput, crate::Error>;
461
462 fn metadata(&self) -> HookMetadata {
464 HookMetadata {
465 name: self.name().to_string(),
466 events: self.events().to_vec(),
467 priority: self.priority(),
468 timeout_secs: self.timeout_secs(),
469 tool_matcher: self.tool_matcher().cloned(),
470 }
471 }
472}
473
474pub struct FnHook<F> {
475 name: String,
476 events: Vec<HookEvent>,
477 handler: F,
478 priority: i32,
479 timeout_secs: u64,
480 tool_matcher: Option<Regex>,
481}
482
483impl<F> FnHook<F> {
484 pub fn builder(name: impl Into<String>, events: Vec<HookEvent>) -> FnHookBuilder {
485 FnHookBuilder {
486 name: name.into(),
487 events,
488 priority: 0,
489 timeout_secs: 60,
490 tool_matcher: None,
491 }
492 }
493}
494
495pub struct FnHookBuilder {
496 name: String,
497 events: Vec<HookEvent>,
498 priority: i32,
499 timeout_secs: u64,
500 tool_matcher: Option<Regex>,
501}
502
503impl FnHookBuilder {
504 pub fn priority(mut self, priority: i32) -> Self {
505 self.priority = priority;
506 self
507 }
508
509 pub fn timeout_secs(mut self, secs: u64) -> Self {
510 self.timeout_secs = secs;
511 self
512 }
513
514 pub fn tool_matcher(mut self, pattern: &str) -> Self {
515 if let Ok(regex) = Regex::new(pattern) {
516 self.tool_matcher = Some(regex);
517 }
518 self
519 }
520
521 pub fn handler<F, Fut>(self, handler: F) -> FnHook<F>
522 where
523 F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
524 Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
525 {
526 FnHook {
527 name: self.name,
528 events: self.events,
529 handler,
530 priority: self.priority,
531 timeout_secs: self.timeout_secs,
532 tool_matcher: self.tool_matcher,
533 }
534 }
535}
536
537#[async_trait]
538impl<F, Fut> Hook for FnHook<F>
539where
540 F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
541 Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
542{
543 fn name(&self) -> &str {
544 &self.name
545 }
546
547 fn events(&self) -> &[HookEvent] {
548 &self.events
549 }
550
551 fn priority(&self) -> i32 {
552 self.priority
553 }
554
555 fn timeout_secs(&self) -> u64 {
556 self.timeout_secs
557 }
558
559 fn tool_matcher(&self) -> Option<&Regex> {
560 self.tool_matcher.as_ref()
561 }
562
563 async fn execute(
564 &self,
565 input: HookInput,
566 hook_context: &HookContext,
567 ) -> Result<HookOutput, crate::Error> {
568 (self.handler)(input, hook_context.clone()).await
569 }
570}
571
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_event_display() {
        assert_eq!(HookEvent::PreToolUse.to_string(), "pre_tool_use");
        assert_eq!(HookEvent::PostToolUse.to_string(), "post_tool_use");
        assert_eq!(HookEvent::SessionStart.to_string(), "session_start");
    }

    #[test]
    fn test_hook_event_can_block() {
        assert!(HookEvent::PreToolUse.can_block());
        assert!(HookEvent::UserPromptSubmit.can_block());
        assert!(HookEvent::SessionStart.can_block());
        assert!(HookEvent::PreCompact.can_block());
        assert!(HookEvent::SubagentStart.can_block());

        assert!(!HookEvent::PostToolUse.can_block());
        assert!(!HookEvent::PostToolUseFailure.can_block());
        assert!(!HookEvent::SessionEnd.can_block());
        assert!(!HookEvent::SubagentStop.can_block());
        assert!(!HookEvent::Stop.can_block());
    }

    #[test]
    fn test_hook_event_all_is_complete_and_unique() {
        let all = HookEvent::all();
        assert_eq!(all.len(), 10);
        let unique: std::collections::HashSet<HookEvent> = all.iter().copied().collect();
        assert_eq!(unique.len(), all.len());
    }

    #[test]
    fn test_hook_event_from_pascal_case_round_trip() {
        // The derived Debug output of a unit variant is its PascalCase name.
        for &event in HookEvent::all() {
            let pascal = format!("{:?}", event);
            assert_eq!(HookEvent::from_pascal_case(&pascal), Some(event));
        }
        assert_eq!(HookEvent::from_pascal_case("pre_tool_use"), None);
        assert_eq!(HookEvent::from_pascal_case("pretooluse"), None);
        assert_eq!(HookEvent::from_pascal_case(""), None);
    }

    #[test]
    fn test_hook_event_display_matches_serde() {
        for &event in HookEvent::all() {
            let json = serde_json::to_value(event).unwrap();
            assert_eq!(json, serde_json::Value::String(event.to_string()));
            let back: HookEvent = serde_json::from_value(json).unwrap();
            assert_eq!(back, event);
        }
    }

    #[test]
    fn test_hook_input_builders() {
        let input =
            HookInput::pre_tool_use("session-1", "Read", serde_json::json!({"path": "/tmp"}));
        assert_eq!(input.event_type(), HookEvent::PreToolUse);
        assert_eq!(input.tool_name(), Some("Read"));
        assert_eq!(input.session_id, "session-1");

        let input = HookInput::session_start("session-2");
        assert_eq!(input.event_type(), HookEvent::SessionStart);
        assert_eq!(input.session_id, "session-2");
    }

    #[test]
    fn test_hook_input_subagent_accessors() {
        let input = HookInput::subagent_start("s", "sub-1", "explore", "look around");
        assert_eq!(input.event_type(), HookEvent::SubagentStart);
        assert_eq!(input.subagent_id(), Some("sub-1"));
        assert_eq!(input.tool_name(), None);

        let input = HookInput::subagent_stop("s", "sub-2", false, Some("boom".to_string()));
        assert_eq!(input.event_type(), HookEvent::SubagentStop);
        assert_eq!(input.subagent_id(), Some("sub-2"));
    }

    #[test]
    fn test_hook_output_builders() {
        let output = HookOutput::allow();
        assert!(output.continue_execution);
        assert!(output.stop_reason.is_none());

        let output = HookOutput::block("Dangerous operation");
        assert!(!output.continue_execution);
        assert_eq!(output.stop_reason, Some("Dangerous operation".to_string()));

        let output = HookOutput::allow()
            .with_system_message("Added context")
            .with_context("More info")
            .suppress_logging();
        assert!(output.continue_execution);
        assert!(output.suppress_logging);
        assert_eq!(output.system_message, Some("Added context".to_string()));
        assert_eq!(output.additional_context, Some("More info".to_string()));
    }

    #[test]
    fn test_hook_event_data_accessors() {
        let data = HookEventData::PreToolUse {
            tool_name: "Bash".to_string(),
            tool_input: serde_json::json!({"command": "ls"}),
        };
        assert_eq!(data.event_type(), HookEvent::PreToolUse);
        assert_eq!(data.tool_name(), Some("Bash"));
        assert!(data.tool_input().is_some());

        let data = HookEventData::SessionStart;
        assert_eq!(data.event_type(), HookEvent::SessionStart);
        assert_eq!(data.tool_name(), None);
        assert!(data.tool_input().is_none());
    }

    #[test]
    fn test_hook_metadata_tool_matcher_is_best_effort() {
        let meta = HookMetadata::new("m", vec![HookEvent::PreToolUse]).with_tool_matcher("(");
        assert!(meta.tool_matcher.is_none());

        // An invalid pattern leaves a previously valid matcher in place.
        let meta = HookMetadata::new("m", vec![HookEvent::PreToolUse])
            .with_tool_matcher("^Read$")
            .with_tool_matcher("(");
        let matcher = meta.tool_matcher.expect("valid matcher should be kept");
        assert!(matcher.is_match("Read"));
        assert!(!matcher.is_match("ReadFile"));
    }

    #[test]
    fn test_fn_hook_metadata() {
        let hook = FnHook::<()>::builder("guard", vec![HookEvent::PreToolUse])
            .priority(7)
            .timeout_secs(5)
            .tool_matcher("^Bash$")
            .handler(|_input: HookInput, _ctx: HookContext| async move {
                Ok::<HookOutput, crate::Error>(HookOutput::allow())
            });
        let meta = hook.metadata();
        assert_eq!(meta.name, "guard");
        assert_eq!(meta.events, vec![HookEvent::PreToolUse]);
        assert_eq!(meta.priority, 7);
        assert_eq!(meta.timeout_secs, 5);
        let matcher = meta.tool_matcher.expect("valid pattern should be kept");
        assert!(matcher.is_match("Bash"));
        assert!(!matcher.is_match("BashOutput"));
    }
}