1use crate::types::ToolOutput;
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tokio_util::sync::CancellationToken;
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13#[non_exhaustive]
14pub enum HookEvent {
15 PreToolUse,
16 PostToolUse,
17 PostToolUseFailure,
18 UserPromptSubmit,
19 Stop,
20 SubagentStart,
21 SubagentStop,
22 PreCompact,
23 SessionStart,
24 SessionEnd,
25}
26
27impl HookEvent {
28 pub fn can_block(&self) -> bool {
40 matches!(
41 self,
42 Self::PreToolUse
43 | Self::UserPromptSubmit
44 | Self::SessionStart
45 | Self::PreCompact
46 | Self::SubagentStart
47 )
48 }
49
50 pub fn can_modify_input(&self) -> bool {
51 matches!(self, Self::PreToolUse | Self::UserPromptSubmit)
52 }
53
54 pub fn all() -> &'static [HookEvent] {
55 &[
56 Self::PreToolUse,
57 Self::PostToolUse,
58 Self::PostToolUseFailure,
59 Self::UserPromptSubmit,
60 Self::Stop,
61 Self::SubagentStart,
62 Self::SubagentStop,
63 Self::PreCompact,
64 Self::SessionStart,
65 Self::SessionEnd,
66 ]
67 }
68}
69
70impl std::fmt::Display for HookEvent {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 let s = match self {
73 Self::PreToolUse => "pre_tool_use",
74 Self::PostToolUse => "post_tool_use",
75 Self::PostToolUseFailure => "post_tool_use_failure",
76 Self::UserPromptSubmit => "user_prompt_submit",
77 Self::Stop => "stop",
78 Self::SubagentStart => "subagent_start",
79 Self::SubagentStop => "subagent_stop",
80 Self::PreCompact => "pre_compact",
81 Self::SessionStart => "session_start",
82 Self::SessionEnd => "session_end",
83 };
84 write!(f, "{}", s)
85 }
86}
87
88#[derive(Clone, Debug)]
89#[non_exhaustive]
90pub enum HookEventData {
91 PreToolUse {
92 tool_name: String,
93 tool_input: Value,
94 },
95 PostToolUse {
96 tool_name: String,
97 tool_result: ToolOutput,
98 },
99 PostToolUseFailure {
100 tool_name: String,
101 error: String,
102 },
103 UserPromptSubmit {
104 prompt: String,
105 },
106 Stop,
107 SubagentStart {
108 subagent_id: String,
109 subagent_type: String,
110 description: String,
111 },
112 SubagentStop {
113 subagent_id: String,
114 success: bool,
115 error: Option<String>,
116 },
117 PreCompact,
118 SessionStart,
119 SessionEnd,
120}
121
122impl HookEventData {
123 pub fn event_type(&self) -> HookEvent {
124 match self {
125 Self::PreToolUse { .. } => HookEvent::PreToolUse,
126 Self::PostToolUse { .. } => HookEvent::PostToolUse,
127 Self::PostToolUseFailure { .. } => HookEvent::PostToolUseFailure,
128 Self::UserPromptSubmit { .. } => HookEvent::UserPromptSubmit,
129 Self::Stop => HookEvent::Stop,
130 Self::SubagentStart { .. } => HookEvent::SubagentStart,
131 Self::SubagentStop { .. } => HookEvent::SubagentStop,
132 Self::PreCompact => HookEvent::PreCompact,
133 Self::SessionStart => HookEvent::SessionStart,
134 Self::SessionEnd => HookEvent::SessionEnd,
135 }
136 }
137
138 pub fn tool_name(&self) -> Option<&str> {
139 match self {
140 Self::PreToolUse { tool_name, .. }
141 | Self::PostToolUse { tool_name, .. }
142 | Self::PostToolUseFailure { tool_name, .. } => Some(tool_name),
143 _ => None,
144 }
145 }
146
147 pub fn tool_input(&self) -> Option<&Value> {
148 match self {
149 Self::PreToolUse { tool_input, .. } => Some(tool_input),
150 _ => None,
151 }
152 }
153
154 pub fn subagent_id(&self) -> Option<&str> {
155 match self {
156 Self::SubagentStart { subagent_id, .. } | Self::SubagentStop { subagent_id, .. } => {
157 Some(subagent_id)
158 }
159 _ => None,
160 }
161 }
162}
163
164#[derive(Clone, Debug)]
165pub struct HookInput {
166 pub session_id: String,
167 pub timestamp: DateTime<Utc>,
168 pub data: HookEventData,
169 pub metadata: Option<Value>,
170}
171
172impl HookInput {
173 pub fn new(session_id: impl Into<String>, data: HookEventData) -> Self {
174 Self {
175 session_id: session_id.into(),
176 timestamp: Utc::now(),
177 data,
178 metadata: None,
179 }
180 }
181
182 pub fn with_metadata(mut self, metadata: Value) -> Self {
183 self.metadata = Some(metadata);
184 self
185 }
186
187 pub fn event_type(&self) -> HookEvent {
188 self.data.event_type()
189 }
190
191 pub fn tool_name(&self) -> Option<&str> {
192 self.data.tool_name()
193 }
194
195 pub fn subagent_id(&self) -> Option<&str> {
196 self.data.subagent_id()
197 }
198
199 pub fn pre_tool_use(
200 session_id: impl Into<String>,
201 tool_name: impl Into<String>,
202 tool_input: Value,
203 ) -> Self {
204 Self::new(
205 session_id,
206 HookEventData::PreToolUse {
207 tool_name: tool_name.into(),
208 tool_input,
209 },
210 )
211 }
212
213 pub fn post_tool_use(
214 session_id: impl Into<String>,
215 tool_name: impl Into<String>,
216 tool_result: ToolOutput,
217 ) -> Self {
218 Self::new(
219 session_id,
220 HookEventData::PostToolUse {
221 tool_name: tool_name.into(),
222 tool_result,
223 },
224 )
225 }
226
227 pub fn post_tool_use_failure(
228 session_id: impl Into<String>,
229 tool_name: impl Into<String>,
230 error: impl Into<String>,
231 ) -> Self {
232 Self::new(
233 session_id,
234 HookEventData::PostToolUseFailure {
235 tool_name: tool_name.into(),
236 error: error.into(),
237 },
238 )
239 }
240
241 pub fn user_prompt_submit(session_id: impl Into<String>, prompt: impl Into<String>) -> Self {
242 Self::new(
243 session_id,
244 HookEventData::UserPromptSubmit {
245 prompt: prompt.into(),
246 },
247 )
248 }
249
250 pub fn session_start(session_id: impl Into<String>) -> Self {
251 Self::new(session_id, HookEventData::SessionStart)
252 }
253
254 pub fn session_end(session_id: impl Into<String>) -> Self {
255 Self::new(session_id, HookEventData::SessionEnd)
256 }
257
258 pub fn stop(session_id: impl Into<String>) -> Self {
259 Self::new(session_id, HookEventData::Stop)
260 }
261
262 pub fn pre_compact(session_id: impl Into<String>) -> Self {
263 Self::new(session_id, HookEventData::PreCompact)
264 }
265
266 pub fn subagent_start(
267 session_id: impl Into<String>,
268 subagent_id: impl Into<String>,
269 subagent_type: impl Into<String>,
270 description: impl Into<String>,
271 ) -> Self {
272 Self::new(
273 session_id,
274 HookEventData::SubagentStart {
275 subagent_id: subagent_id.into(),
276 subagent_type: subagent_type.into(),
277 description: description.into(),
278 },
279 )
280 }
281
282 pub fn subagent_stop(
283 session_id: impl Into<String>,
284 subagent_id: impl Into<String>,
285 success: bool,
286 error: Option<String>,
287 ) -> Self {
288 Self::new(
289 session_id,
290 HookEventData::SubagentStop {
291 subagent_id: subagent_id.into(),
292 success,
293 error,
294 },
295 )
296 }
297}
298
299#[derive(Clone, Debug, Default)]
300pub struct HookOutput {
301 pub continue_execution: bool,
302 pub stop_reason: Option<String>,
303 pub suppress_logging: bool,
304 pub system_message: Option<String>,
305 pub updated_input: Option<Value>,
306 pub additional_context: Option<String>,
307}
308
309impl HookOutput {
310 pub fn allow() -> Self {
311 Self {
312 continue_execution: true,
313 ..Default::default()
314 }
315 }
316
317 pub fn block(reason: impl Into<String>) -> Self {
318 Self {
319 continue_execution: false,
320 stop_reason: Some(reason.into()),
321 ..Default::default()
322 }
323 }
324
325 pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
326 self.system_message = Some(message.into());
327 self
328 }
329
330 pub fn with_context(mut self, context: impl Into<String>) -> Self {
331 self.additional_context = Some(context.into());
332 self
333 }
334
335 pub fn with_updated_input(mut self, input: Value) -> Self {
336 self.updated_input = Some(input);
337 self
338 }
339
340 pub fn suppress_logging(mut self) -> Self {
341 self.suppress_logging = true;
342 self
343 }
344}
345
346#[derive(Clone, Debug)]
347pub struct HookContext {
348 pub session_id: String,
349 pub cancellation_token: CancellationToken,
350 pub cwd: Option<std::path::PathBuf>,
351 pub env: std::collections::HashMap<String, String>,
352}
353
354impl Default for HookContext {
355 fn default() -> Self {
356 Self {
357 session_id: String::new(),
358 cancellation_token: CancellationToken::new(),
359 cwd: None,
360 env: std::collections::HashMap::new(),
361 }
362 }
363}
364
365impl HookContext {
366 pub fn new(session_id: impl Into<String>) -> Self {
367 Self {
368 session_id: session_id.into(),
369 ..Default::default()
370 }
371 }
372
373 pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
374 self.cancellation_token = token;
375 self
376 }
377
378 pub fn with_cwd(mut self, cwd: impl Into<std::path::PathBuf>) -> Self {
379 self.cwd = Some(cwd.into());
380 self
381 }
382
383 pub fn with_env(mut self, env: std::collections::HashMap<String, String>) -> Self {
384 self.env = env;
385 self
386 }
387}
388
389#[derive(Clone, Debug)]
391pub struct HookMetadata {
392 pub name: String,
393 pub events: Vec<HookEvent>,
394 pub priority: i32,
395 pub timeout_secs: u64,
396 pub tool_matcher: Option<Regex>,
397}
398
399impl HookMetadata {
400 pub fn new(name: impl Into<String>, events: Vec<HookEvent>) -> Self {
401 Self {
402 name: name.into(),
403 events,
404 priority: 0,
405 timeout_secs: 60,
406 tool_matcher: None,
407 }
408 }
409
410 pub fn with_priority(mut self, priority: i32) -> Self {
411 self.priority = priority;
412 self
413 }
414
415 pub fn with_timeout(mut self, secs: u64) -> Self {
416 self.timeout_secs = secs;
417 self
418 }
419
420 pub fn with_tool_matcher(mut self, pattern: &str) -> Self {
421 if let Ok(regex) = Regex::new(pattern) {
422 self.tool_matcher = Some(regex);
423 }
424 self
425 }
426}
427
428#[async_trait]
429pub trait Hook: Send + Sync {
430 fn name(&self) -> &str;
431 fn events(&self) -> &[HookEvent];
432
433 #[inline]
434 fn tool_matcher(&self) -> Option<&Regex> {
435 None
436 }
437
438 #[inline]
439 fn timeout_secs(&self) -> u64 {
440 60
441 }
442
443 #[inline]
444 fn priority(&self) -> i32 {
445 0
446 }
447
448 async fn execute(
449 &self,
450 input: HookInput,
451 hook_context: &HookContext,
452 ) -> Result<HookOutput, crate::Error>;
453
454 fn metadata(&self) -> HookMetadata {
456 HookMetadata {
457 name: self.name().to_string(),
458 events: self.events().to_vec(),
459 priority: self.priority(),
460 timeout_secs: self.timeout_secs(),
461 tool_matcher: self.tool_matcher().cloned(),
462 }
463 }
464}
465
466pub struct FnHook<F> {
467 name: String,
468 events: Vec<HookEvent>,
469 handler: F,
470 priority: i32,
471 timeout_secs: u64,
472 tool_matcher: Option<Regex>,
473}
474
475impl<F> FnHook<F> {
476 pub fn builder(name: impl Into<String>, events: Vec<HookEvent>) -> FnHookBuilder {
477 FnHookBuilder {
478 name: name.into(),
479 events,
480 priority: 0,
481 timeout_secs: 60,
482 tool_matcher: None,
483 }
484 }
485}
486
487pub struct FnHookBuilder {
488 name: String,
489 events: Vec<HookEvent>,
490 priority: i32,
491 timeout_secs: u64,
492 tool_matcher: Option<Regex>,
493}
494
495impl FnHookBuilder {
496 pub fn priority(mut self, priority: i32) -> Self {
497 self.priority = priority;
498 self
499 }
500
501 pub fn timeout_secs(mut self, secs: u64) -> Self {
502 self.timeout_secs = secs;
503 self
504 }
505
506 pub fn tool_matcher(mut self, pattern: &str) -> Self {
507 if let Ok(regex) = Regex::new(pattern) {
508 self.tool_matcher = Some(regex);
509 }
510 self
511 }
512
513 pub fn handler<F, Fut>(self, handler: F) -> FnHook<F>
514 where
515 F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
516 Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
517 {
518 FnHook {
519 name: self.name,
520 events: self.events,
521 handler,
522 priority: self.priority,
523 timeout_secs: self.timeout_secs,
524 tool_matcher: self.tool_matcher,
525 }
526 }
527}
528
529#[async_trait]
530impl<F, Fut> Hook for FnHook<F>
531where
532 F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
533 Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
534{
535 fn name(&self) -> &str {
536 &self.name
537 }
538
539 fn events(&self) -> &[HookEvent] {
540 &self.events
541 }
542
543 fn priority(&self) -> i32 {
544 self.priority
545 }
546
547 fn timeout_secs(&self) -> u64 {
548 self.timeout_secs
549 }
550
551 fn tool_matcher(&self) -> Option<&Regex> {
552 self.tool_matcher.as_ref()
553 }
554
555 async fn execute(
556 &self,
557 input: HookInput,
558 hook_context: &HookContext,
559 ) -> Result<HookOutput, crate::Error> {
560 (self.handler)(input, hook_context.clone()).await
561 }
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_hook_event_display() {
        assert_eq!(HookEvent::PreToolUse.to_string(), "pre_tool_use");
        assert_eq!(HookEvent::PostToolUse.to_string(), "post_tool_use");
        assert_eq!(HookEvent::SessionStart.to_string(), "session_start");
    }

    /// Display must agree with the serde wire names for every variant, and
    /// those names must round-trip through deserialization.
    #[test]
    fn test_display_matches_serde_for_every_event() {
        for event in HookEvent::all() {
            let name = event.to_string();
            assert_eq!(
                serde_json::to_value(event).unwrap(),
                serde_json::Value::String(name.clone()),
                "Display and serde disagree for {:?}",
                event
            );
            let parsed: HookEvent =
                serde_json::from_value(serde_json::Value::String(name)).unwrap();
            assert_eq!(parsed, *event);
        }
    }

    /// `all()` should list each variant exactly once; update the count when
    /// adding a variant.
    #[test]
    fn test_all_lists_every_event_once() {
        let all = HookEvent::all();
        let unique: HashSet<HookEvent> = all.iter().copied().collect();
        assert_eq!(all.len(), 10);
        assert_eq!(unique.len(), all.len());
    }

    #[test]
    fn test_hook_event_can_block() {
        assert!(HookEvent::PreToolUse.can_block());
        assert!(HookEvent::UserPromptSubmit.can_block());
        assert!(HookEvent::SessionStart.can_block());
        assert!(HookEvent::PreCompact.can_block());
        assert!(HookEvent::SubagentStart.can_block());

        assert!(!HookEvent::PostToolUse.can_block());
        assert!(!HookEvent::PostToolUseFailure.can_block());
        assert!(!HookEvent::SessionEnd.can_block());
        assert!(!HookEvent::SubagentStop.can_block());
        assert!(!HookEvent::Stop.can_block());
    }

    #[test]
    fn test_hook_event_can_modify_input() {
        assert!(HookEvent::PreToolUse.can_modify_input());
        assert!(HookEvent::UserPromptSubmit.can_modify_input());

        assert!(!HookEvent::PostToolUse.can_modify_input());
        assert!(!HookEvent::PostToolUseFailure.can_modify_input());
        assert!(!HookEvent::Stop.can_modify_input());
        assert!(!HookEvent::SubagentStart.can_modify_input());
        assert!(!HookEvent::SubagentStop.can_modify_input());
        assert!(!HookEvent::PreCompact.can_modify_input());
        assert!(!HookEvent::SessionStart.can_modify_input());
        assert!(!HookEvent::SessionEnd.can_modify_input());
    }

    #[test]
    fn test_hook_input_builders() {
        let input =
            HookInput::pre_tool_use("session-1", "Read", serde_json::json!({"path": "/tmp"}));
        assert_eq!(input.event_type(), HookEvent::PreToolUse);
        assert_eq!(input.tool_name(), Some("Read"));
        assert_eq!(input.session_id, "session-1");

        let input = HookInput::session_start("session-2");
        assert_eq!(input.event_type(), HookEvent::SessionStart);
        assert_eq!(input.session_id, "session-2");
    }

    #[test]
    fn test_hook_input_failure_and_subagent_accessors() {
        let input = HookInput::post_tool_use_failure("s", "Bash", "exit 1");
        assert_eq!(input.event_type(), HookEvent::PostToolUseFailure);
        assert_eq!(input.tool_name(), Some("Bash"));
        assert!(input.data.tool_input().is_none());
        assert_eq!(input.subagent_id(), None);

        let input = HookInput::subagent_start("s", "agent-1", "explore", "look around");
        assert_eq!(input.event_type(), HookEvent::SubagentStart);
        assert_eq!(input.subagent_id(), Some("agent-1"));
        assert_eq!(input.tool_name(), None);

        let input = HookInput::subagent_stop("s", "agent-1", false, Some("boom".to_string()));
        assert_eq!(input.event_type(), HookEvent::SubagentStop);
        assert_eq!(input.subagent_id(), Some("agent-1"));
    }

    #[test]
    fn test_hook_output_builders() {
        let output = HookOutput::allow();
        assert!(output.continue_execution);
        assert!(output.stop_reason.is_none());

        let output = HookOutput::block("Dangerous operation");
        assert!(!output.continue_execution);
        assert_eq!(output.stop_reason, Some("Dangerous operation".to_string()));

        let output = HookOutput::allow()
            .with_system_message("Added context")
            .with_context("More info")
            .suppress_logging();
        assert!(output.continue_execution);
        assert!(output.suppress_logging);
        assert_eq!(output.system_message, Some("Added context".to_string()));
        assert_eq!(output.additional_context, Some("More info".to_string()));
    }

    #[test]
    fn test_hook_event_data_accessors() {
        let data = HookEventData::PreToolUse {
            tool_name: "Bash".to_string(),
            tool_input: serde_json::json!({"command": "ls"}),
        };
        assert_eq!(data.event_type(), HookEvent::PreToolUse);
        assert_eq!(data.tool_name(), Some("Bash"));
        assert!(data.tool_input().is_some());

        let data = HookEventData::SessionStart;
        assert_eq!(data.event_type(), HookEvent::SessionStart);
        assert_eq!(data.tool_name(), None);
        assert!(data.tool_input().is_none());
    }

    /// Pins the documented best-effort behaviour: an invalid pattern leaves
    /// the existing matcher (or lack of one) untouched.
    #[test]
    fn test_metadata_defaults_and_invalid_tool_matcher() {
        let meta = HookMetadata::new("m", vec![HookEvent::PreToolUse]);
        assert_eq!(meta.priority, 0);
        assert_eq!(meta.timeout_secs, 60);
        assert!(meta.tool_matcher.is_none());

        let meta = meta.with_tool_matcher("(");
        assert!(meta.tool_matcher.is_none());

        let meta = meta.with_tool_matcher("^Bash$").with_tool_matcher("(");
        assert_eq!(meta.tool_matcher.as_ref().map(Regex::as_str), Some("^Bash$"));
    }

    #[test]
    fn test_fn_hook_metadata_reflects_builder() {
        let hook = FnHook::<()>::builder("guard", vec![HookEvent::PreToolUse])
            .priority(5)
            .timeout_secs(10)
            .tool_matcher("^Bash$")
            .handler(|_input, _ctx| async { Ok::<_, crate::Error>(HookOutput::allow()) });

        let meta = hook.metadata();
        assert_eq!(meta.name, "guard");
        assert_eq!(meta.events, vec![HookEvent::PreToolUse]);
        assert_eq!(meta.priority, 5);
        assert_eq!(meta.timeout_secs, 10);
        assert_eq!(meta.tool_matcher.as_ref().map(Regex::as_str), Some("^Bash$"));
    }
}