1use crate::types::ToolOutput;
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tokio_util::sync::CancellationToken;
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13#[non_exhaustive]
14pub enum HookEvent {
15 PreToolUse,
16 PostToolUse,
17 PostToolUseFailure,
18 UserPromptSubmit,
19 Stop,
20 SubagentStart,
21 SubagentStop,
22 PreCompact,
23 SessionStart,
24 SessionEnd,
25}
26
27impl HookEvent {
28 pub fn can_block(&self) -> bool {
42 matches!(
43 self,
44 Self::PreToolUse | Self::UserPromptSubmit | Self::SessionStart | Self::SubagentStart
45 )
46 }
47
48 pub fn from_pascal_case(s: &str) -> Option<Self> {
50 match s {
51 "PreToolUse" => Some(Self::PreToolUse),
52 "PostToolUse" => Some(Self::PostToolUse),
53 "PostToolUseFailure" => Some(Self::PostToolUseFailure),
54 "UserPromptSubmit" => Some(Self::UserPromptSubmit),
55 "Stop" => Some(Self::Stop),
56 "SubagentStart" => Some(Self::SubagentStart),
57 "SubagentStop" => Some(Self::SubagentStop),
58 "PreCompact" => Some(Self::PreCompact),
59 "SessionStart" => Some(Self::SessionStart),
60 "SessionEnd" => Some(Self::SessionEnd),
61 _ => None,
62 }
63 }
64
65 pub fn all() -> &'static [HookEvent] {
66 &[
67 Self::PreToolUse,
68 Self::PostToolUse,
69 Self::PostToolUseFailure,
70 Self::UserPromptSubmit,
71 Self::Stop,
72 Self::SubagentStart,
73 Self::SubagentStop,
74 Self::PreCompact,
75 Self::SessionStart,
76 Self::SessionEnd,
77 ]
78 }
79}
80
81impl std::fmt::Display for HookEvent {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 let s = match self {
84 Self::PreToolUse => "pre_tool_use",
85 Self::PostToolUse => "post_tool_use",
86 Self::PostToolUseFailure => "post_tool_use_failure",
87 Self::UserPromptSubmit => "user_prompt_submit",
88 Self::Stop => "stop",
89 Self::SubagentStart => "subagent_start",
90 Self::SubagentStop => "subagent_stop",
91 Self::PreCompact => "pre_compact",
92 Self::SessionStart => "session_start",
93 Self::SessionEnd => "session_end",
94 };
95 write!(f, "{}", s)
96 }
97}
98
99#[derive(Clone, Debug)]
100#[non_exhaustive]
101pub enum HookEventData {
102 PreToolUse {
103 tool_name: String,
104 tool_input: Value,
105 },
106 PostToolUse {
107 tool_name: String,
108 tool_result: ToolOutput,
109 },
110 PostToolUseFailure {
111 tool_name: String,
112 error: String,
113 },
114 UserPromptSubmit {
115 prompt: String,
116 },
117 Stop,
118 SubagentStart {
119 subagent_id: String,
120 subagent_type: String,
121 description: String,
122 },
123 SubagentStop {
124 subagent_id: String,
125 success: bool,
126 error: Option<String>,
127 },
128 PreCompact,
129 SessionStart,
130 SessionEnd,
131}
132
133impl HookEventData {
134 pub fn event_type(&self) -> HookEvent {
135 match self {
136 Self::PreToolUse { .. } => HookEvent::PreToolUse,
137 Self::PostToolUse { .. } => HookEvent::PostToolUse,
138 Self::PostToolUseFailure { .. } => HookEvent::PostToolUseFailure,
139 Self::UserPromptSubmit { .. } => HookEvent::UserPromptSubmit,
140 Self::Stop => HookEvent::Stop,
141 Self::SubagentStart { .. } => HookEvent::SubagentStart,
142 Self::SubagentStop { .. } => HookEvent::SubagentStop,
143 Self::PreCompact => HookEvent::PreCompact,
144 Self::SessionStart => HookEvent::SessionStart,
145 Self::SessionEnd => HookEvent::SessionEnd,
146 }
147 }
148
149 pub fn tool_name(&self) -> Option<&str> {
150 match self {
151 Self::PreToolUse { tool_name, .. }
152 | Self::PostToolUse { tool_name, .. }
153 | Self::PostToolUseFailure { tool_name, .. } => Some(tool_name),
154 _ => None,
155 }
156 }
157
158 pub fn tool_input(&self) -> Option<&Value> {
159 match self {
160 Self::PreToolUse { tool_input, .. } => Some(tool_input),
161 _ => None,
162 }
163 }
164
165 pub fn subagent_id(&self) -> Option<&str> {
166 match self {
167 Self::SubagentStart { subagent_id, .. } | Self::SubagentStop { subagent_id, .. } => {
168 Some(subagent_id)
169 }
170 _ => None,
171 }
172 }
173}
174
175#[derive(Clone, Debug)]
176pub struct HookInput {
177 pub session_id: String,
178 pub timestamp: DateTime<Utc>,
179 pub data: HookEventData,
180 pub metadata: Option<Value>,
181}
182
183impl HookInput {
184 pub fn new(session_id: impl Into<String>, data: HookEventData) -> Self {
185 Self {
186 session_id: session_id.into(),
187 timestamp: Utc::now(),
188 data,
189 metadata: None,
190 }
191 }
192
193 pub fn event_type(&self) -> HookEvent {
194 self.data.event_type()
195 }
196
197 pub fn tool_name(&self) -> Option<&str> {
198 self.data.tool_name()
199 }
200
201 pub fn subagent_id(&self) -> Option<&str> {
202 self.data.subagent_id()
203 }
204
205 pub fn pre_tool_use(
206 session_id: impl Into<String>,
207 tool_name: impl Into<String>,
208 tool_input: Value,
209 ) -> Self {
210 Self::new(
211 session_id,
212 HookEventData::PreToolUse {
213 tool_name: tool_name.into(),
214 tool_input,
215 },
216 )
217 }
218
219 pub fn post_tool_use(
220 session_id: impl Into<String>,
221 tool_name: impl Into<String>,
222 tool_result: ToolOutput,
223 ) -> Self {
224 Self::new(
225 session_id,
226 HookEventData::PostToolUse {
227 tool_name: tool_name.into(),
228 tool_result,
229 },
230 )
231 }
232
233 pub fn post_tool_use_failure(
234 session_id: impl Into<String>,
235 tool_name: impl Into<String>,
236 error: impl Into<String>,
237 ) -> Self {
238 Self::new(
239 session_id,
240 HookEventData::PostToolUseFailure {
241 tool_name: tool_name.into(),
242 error: error.into(),
243 },
244 )
245 }
246
247 pub fn user_prompt_submit(session_id: impl Into<String>, prompt: impl Into<String>) -> Self {
248 Self::new(
249 session_id,
250 HookEventData::UserPromptSubmit {
251 prompt: prompt.into(),
252 },
253 )
254 }
255
256 pub fn session_start(session_id: impl Into<String>) -> Self {
257 Self::new(session_id, HookEventData::SessionStart)
258 }
259
260 pub fn session_end(session_id: impl Into<String>) -> Self {
261 Self::new(session_id, HookEventData::SessionEnd)
262 }
263
264 pub fn stop(session_id: impl Into<String>) -> Self {
265 Self::new(session_id, HookEventData::Stop)
266 }
267
268 pub fn pre_compact(session_id: impl Into<String>) -> Self {
269 Self::new(session_id, HookEventData::PreCompact)
270 }
271
272 pub fn subagent_start(
273 session_id: impl Into<String>,
274 subagent_id: impl Into<String>,
275 subagent_type: impl Into<String>,
276 description: impl Into<String>,
277 ) -> Self {
278 Self::new(
279 session_id,
280 HookEventData::SubagentStart {
281 subagent_id: subagent_id.into(),
282 subagent_type: subagent_type.into(),
283 description: description.into(),
284 },
285 )
286 }
287
288 pub fn subagent_stop(
289 session_id: impl Into<String>,
290 subagent_id: impl Into<String>,
291 success: bool,
292 error: Option<String>,
293 ) -> Self {
294 Self::new(
295 session_id,
296 HookEventData::SubagentStop {
297 subagent_id: subagent_id.into(),
298 success,
299 error,
300 },
301 )
302 }
303}
304
305#[derive(Clone, Debug, Default)]
306pub struct HookOutput {
307 pub continue_execution: bool,
308 pub stop_reason: Option<String>,
309 pub suppress_logging: bool,
310 pub system_message: Option<String>,
311 pub updated_input: Option<Value>,
312 pub additional_context: Option<String>,
313}
314
315impl HookOutput {
316 pub fn allow() -> Self {
317 Self {
318 continue_execution: true,
319 ..Default::default()
320 }
321 }
322
323 pub fn block(reason: impl Into<String>) -> Self {
324 Self {
325 continue_execution: false,
326 stop_reason: Some(reason.into()),
327 ..Default::default()
328 }
329 }
330
331 pub fn system_message(mut self, message: impl Into<String>) -> Self {
332 self.system_message = Some(message.into());
333 self
334 }
335
336 pub fn context(mut self, context: impl Into<String>) -> Self {
337 self.additional_context = Some(context.into());
338 self
339 }
340
341 pub fn updated_input(mut self, input: Value) -> Self {
342 self.updated_input = Some(input);
343 self
344 }
345
346 pub fn suppress_logging(mut self) -> Self {
347 self.suppress_logging = true;
348 self
349 }
350}
351
352#[derive(Clone, Debug)]
353pub struct HookContext {
354 pub session_id: String,
355 pub cancellation_token: CancellationToken,
356 pub cwd: Option<std::path::PathBuf>,
357 pub env: std::collections::HashMap<String, String>,
358}
359
360impl Default for HookContext {
361 fn default() -> Self {
362 Self {
363 session_id: String::new(),
364 cancellation_token: CancellationToken::new(),
365 cwd: None,
366 env: std::collections::HashMap::new(),
367 }
368 }
369}
370
371impl HookContext {
372 pub fn new(session_id: impl Into<String>) -> Self {
373 Self {
374 session_id: session_id.into(),
375 ..Default::default()
376 }
377 }
378
379 pub fn cancellation_token(mut self, token: CancellationToken) -> Self {
380 self.cancellation_token = token;
381 self
382 }
383
384 pub fn cwd(mut self, cwd: impl Into<std::path::PathBuf>) -> Self {
385 self.cwd = Some(cwd.into());
386 self
387 }
388
389 pub fn env(mut self, env: std::collections::HashMap<String, String>) -> Self {
390 self.env = env;
391 self
392 }
393}
394
395#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
397pub enum HookSource {
398 #[default]
399 Builtin,
400 User,
401 Project,
402}
403
404#[derive(Clone, Debug)]
406pub struct HookMetadata {
407 pub name: String,
408 pub events: Vec<HookEvent>,
409 pub priority: i32,
410 pub timeout_secs: u64,
411 pub tool_matcher: Option<Regex>,
412 pub source: HookSource,
413}
414
415impl HookMetadata {
416 pub fn new(name: impl Into<String>, events: Vec<HookEvent>) -> Self {
417 Self {
418 name: name.into(),
419 events,
420 priority: 0,
421 timeout_secs: 60,
422 tool_matcher: None,
423 source: HookSource::default(),
424 }
425 }
426
427 pub fn priority(mut self, priority: i32) -> Self {
428 self.priority = priority;
429 self
430 }
431
432 pub fn timeout(mut self, secs: u64) -> Self {
433 self.timeout_secs = secs;
434 self
435 }
436
437 pub fn tool_matcher(mut self, pattern: &str) -> Self {
438 if let Ok(regex) = Regex::new(pattern) {
439 self.tool_matcher = Some(regex);
440 }
441 self
442 }
443
444 pub fn source(mut self, source: HookSource) -> Self {
445 self.source = source;
446 self
447 }
448}
449
450#[async_trait]
451pub trait Hook: Send + Sync {
452 fn name(&self) -> &str;
453 fn events(&self) -> &[HookEvent];
454
455 #[inline]
456 fn tool_matcher(&self) -> Option<&Regex> {
457 None
458 }
459
460 #[inline]
461 fn timeout_secs(&self) -> u64 {
462 60
463 }
464
465 #[inline]
466 fn priority(&self) -> i32 {
467 0
468 }
469
470 async fn execute(
471 &self,
472 input: HookInput,
473 hook_context: &HookContext,
474 ) -> Result<HookOutput, crate::Error>;
475
476 #[inline]
477 fn source(&self) -> HookSource {
478 HookSource::Builtin
479 }
480
481 fn metadata(&self) -> HookMetadata {
482 HookMetadata {
483 name: self.name().to_string(),
484 events: self.events().to_vec(),
485 priority: self.priority(),
486 timeout_secs: self.timeout_secs(),
487 tool_matcher: self.tool_matcher().cloned(),
488 source: self.source(),
489 }
490 }
491}
492
493pub struct FnHook<F> {
494 name: String,
495 events: Vec<HookEvent>,
496 handler: F,
497 priority: i32,
498 timeout_secs: u64,
499 tool_matcher: Option<Regex>,
500}
501
502impl<F> FnHook<F> {
503 pub fn builder(name: impl Into<String>, events: Vec<HookEvent>) -> FnHookBuilder {
504 FnHookBuilder {
505 name: name.into(),
506 events,
507 priority: 0,
508 timeout_secs: 60,
509 tool_matcher: None,
510 }
511 }
512}
513
514pub struct FnHookBuilder {
515 name: String,
516 events: Vec<HookEvent>,
517 priority: i32,
518 timeout_secs: u64,
519 tool_matcher: Option<Regex>,
520}
521
522impl FnHookBuilder {
523 pub fn priority(mut self, priority: i32) -> Self {
524 self.priority = priority;
525 self
526 }
527
528 pub fn timeout_secs(mut self, secs: u64) -> Self {
529 self.timeout_secs = secs;
530 self
531 }
532
533 pub fn tool_matcher(mut self, pattern: &str) -> Self {
534 if let Ok(regex) = Regex::new(pattern) {
535 self.tool_matcher = Some(regex);
536 }
537 self
538 }
539
540 pub fn handler<F, Fut>(self, handler: F) -> FnHook<F>
541 where
542 F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
543 Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
544 {
545 FnHook {
546 name: self.name,
547 events: self.events,
548 handler,
549 priority: self.priority,
550 timeout_secs: self.timeout_secs,
551 tool_matcher: self.tool_matcher,
552 }
553 }
554}
555
556#[async_trait]
557impl<F, Fut> Hook for FnHook<F>
558where
559 F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
560 Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
561{
562 fn name(&self) -> &str {
563 &self.name
564 }
565
566 fn events(&self) -> &[HookEvent] {
567 &self.events
568 }
569
570 fn priority(&self) -> i32 {
571 self.priority
572 }
573
574 fn timeout_secs(&self) -> u64 {
575 self.timeout_secs
576 }
577
578 fn tool_matcher(&self) -> Option<&Regex> {
579 self.tool_matcher.as_ref()
580 }
581
582 async fn execute(
583 &self,
584 input: HookInput,
585 hook_context: &HookContext,
586 ) -> Result<HookOutput, crate::Error> {
587 (self.handler)(input, hook_context.clone()).await
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_event_display() {
        assert_eq!(HookEvent::PreToolUse.to_string(), "pre_tool_use");
        assert_eq!(HookEvent::PostToolUse.to_string(), "post_tool_use");
        assert_eq!(HookEvent::SessionStart.to_string(), "session_start");
    }

    #[test]
    fn test_hook_event_display_matches_serde() {
        // Display is hand-written; it must agree with `rename_all = "snake_case"`.
        for &event in HookEvent::all() {
            let json = serde_json::to_string(&event).expect("HookEvent serializes");
            assert_eq!(json, format!("\"{}\"", event));
            let back: HookEvent = serde_json::from_str(&json).expect("HookEvent deserializes");
            assert_eq!(back, event);
        }
    }

    #[test]
    fn test_hook_event_from_pascal_case() {
        for &event in HookEvent::all() {
            // Derived Debug prints the PascalCase variant name.
            let pascal = format!("{:?}", event);
            assert_eq!(HookEvent::from_pascal_case(&pascal), Some(event));
        }
        assert_eq!(HookEvent::from_pascal_case("pre_tool_use"), None);
        assert_eq!(HookEvent::from_pascal_case("pretooluse"), None);
        assert_eq!(HookEvent::from_pascal_case(""), None);
    }

    #[test]
    fn test_hook_event_all_is_unique() {
        let unique: std::collections::HashSet<HookEvent> =
            HookEvent::all().iter().copied().collect();
        assert_eq!(unique.len(), HookEvent::all().len());
        assert_eq!(unique.len(), 10);
    }

    #[test]
    fn test_hook_event_can_block() {
        assert!(HookEvent::PreToolUse.can_block());
        assert!(HookEvent::UserPromptSubmit.can_block());
        assert!(HookEvent::SessionStart.can_block());
        assert!(HookEvent::SubagentStart.can_block());

        assert!(!HookEvent::PostToolUse.can_block());
        assert!(!HookEvent::PostToolUseFailure.can_block());
        assert!(!HookEvent::PreCompact.can_block());
        assert!(!HookEvent::SessionEnd.can_block());
        assert!(!HookEvent::SubagentStop.can_block());
        assert!(!HookEvent::Stop.can_block());
    }

    #[test]
    fn test_hook_input_builders() {
        let input =
            HookInput::pre_tool_use("session-1", "Read", serde_json::json!({"path": "/tmp"}));
        assert_eq!(input.event_type(), HookEvent::PreToolUse);
        assert_eq!(input.tool_name(), Some("Read"));
        assert_eq!(input.session_id, "session-1");

        let input = HookInput::session_start("session-2");
        assert_eq!(input.event_type(), HookEvent::SessionStart);
        assert_eq!(input.session_id, "session-2");
    }

    #[test]
    fn test_subagent_and_failure_accessors() {
        let input = HookInput::subagent_start("s", "agent-1", "explorer", "look around");
        assert_eq!(input.event_type(), HookEvent::SubagentStart);
        assert_eq!(input.subagent_id(), Some("agent-1"));
        assert_eq!(input.tool_name(), None);

        let input = HookInput::subagent_stop("s", "agent-1", false, Some("boom".to_string()));
        assert_eq!(input.event_type(), HookEvent::SubagentStop);
        assert_eq!(input.subagent_id(), Some("agent-1"));

        let input = HookInput::post_tool_use_failure("s", "Bash", "exit 1");
        assert_eq!(input.event_type(), HookEvent::PostToolUseFailure);
        assert_eq!(input.tool_name(), Some("Bash"));
        assert_eq!(input.subagent_id(), None);
    }

    #[test]
    fn test_hook_output_builders() {
        let output = HookOutput::allow();
        assert!(output.continue_execution);
        assert!(output.stop_reason.is_none());

        let output = HookOutput::block("Dangerous operation");
        assert!(!output.continue_execution);
        assert_eq!(output.stop_reason, Some("Dangerous operation".to_string()));

        let output = HookOutput::allow()
            .system_message("Added context")
            .context("More info")
            .suppress_logging();
        assert!(output.continue_execution);
        assert!(output.suppress_logging);
        assert_eq!(output.system_message, Some("Added context".to_string()));
        assert_eq!(output.additional_context, Some("More info".to_string()));
    }

    #[test]
    fn test_hook_event_data_accessors() {
        let data = HookEventData::PreToolUse {
            tool_name: "Bash".to_string(),
            tool_input: serde_json::json!({"command": "ls"}),
        };
        assert_eq!(data.event_type(), HookEvent::PreToolUse);
        assert_eq!(data.tool_name(), Some("Bash"));
        assert!(data.tool_input().is_some());

        let data = HookEventData::SessionStart;
        assert_eq!(data.event_type(), HookEvent::SessionStart);
        assert_eq!(data.tool_name(), None);
        assert!(data.tool_input().is_none());
    }

    #[test]
    fn test_hook_metadata_defaults_and_matcher() {
        let meta = HookMetadata::new("guard", vec![HookEvent::PreToolUse]);
        assert_eq!(meta.priority, 0);
        assert_eq!(meta.timeout_secs, 60);
        assert_eq!(meta.source, HookSource::Builtin);
        assert!(meta.tool_matcher.is_none());

        let meta = meta.tool_matcher("^(Read|Write)$");
        let matcher = meta.tool_matcher.as_ref().expect("valid pattern is stored");
        assert!(matcher.is_match("Write"));
        assert!(!matcher.is_match("Bash"));

        // An invalid pattern is ignored: the previous matcher survives...
        let meta = meta.tool_matcher("(");
        assert!(meta.tool_matcher.is_some());

        // ...and on fresh metadata nothing is set.
        let fresh = HookMetadata::new("x", vec![]).tool_matcher("(");
        assert!(fresh.tool_matcher.is_none());
    }

    #[test]
    fn test_hook_context_builders() {
        let ctx = HookContext::new("session-3").cwd("/tmp");
        assert_eq!(ctx.session_id, "session-3");
        assert_eq!(ctx.cwd, Some(std::path::PathBuf::from("/tmp")));
        assert!(ctx.env.is_empty());
        assert!(!ctx.cancellation_token.is_cancelled());
    }
}