1use crate::types::ToolOutput;
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tokio_util::sync::CancellationToken;
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum HookEvent {
14 PreToolUse,
15 PostToolUse,
16 PostToolUseFailure,
17 UserPromptSubmit,
18 Stop,
19 SubagentStart,
20 SubagentStop,
21 PreCompact,
22 SessionStart,
23 SessionEnd,
24}
25
26impl HookEvent {
27 pub fn can_block(&self) -> bool {
28 matches!(self, Self::PreToolUse | Self::UserPromptSubmit)
29 }
30
31 pub fn can_modify_input(&self) -> bool {
32 matches!(self, Self::PreToolUse | Self::UserPromptSubmit)
33 }
34
35 pub fn all() -> &'static [HookEvent] {
36 &[
37 Self::PreToolUse,
38 Self::PostToolUse,
39 Self::PostToolUseFailure,
40 Self::UserPromptSubmit,
41 Self::Stop,
42 Self::SubagentStart,
43 Self::SubagentStop,
44 Self::PreCompact,
45 Self::SessionStart,
46 Self::SessionEnd,
47 ]
48 }
49}
50
51impl std::fmt::Display for HookEvent {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 let s = match self {
54 Self::PreToolUse => "pre_tool_use",
55 Self::PostToolUse => "post_tool_use",
56 Self::PostToolUseFailure => "post_tool_use_failure",
57 Self::UserPromptSubmit => "user_prompt_submit",
58 Self::Stop => "stop",
59 Self::SubagentStart => "subagent_start",
60 Self::SubagentStop => "subagent_stop",
61 Self::PreCompact => "pre_compact",
62 Self::SessionStart => "session_start",
63 Self::SessionEnd => "session_end",
64 };
65 write!(f, "{}", s)
66 }
67}
68
69#[derive(Clone, Debug)]
70pub enum HookEventData {
71 PreToolUse {
72 tool_name: String,
73 tool_input: Value,
74 },
75 PostToolUse {
76 tool_name: String,
77 tool_result: ToolOutput,
78 },
79 PostToolUseFailure {
80 tool_name: String,
81 error: String,
82 },
83 UserPromptSubmit {
84 prompt: String,
85 },
86 Stop,
87 SubagentStart {
88 subagent_id: String,
89 subagent_type: String,
90 description: String,
91 },
92 SubagentStop {
93 subagent_id: String,
94 success: bool,
95 error: Option<String>,
96 },
97 PreCompact,
98 SessionStart,
99 SessionEnd,
100}
101
102impl HookEventData {
103 pub fn event_type(&self) -> HookEvent {
104 match self {
105 Self::PreToolUse { .. } => HookEvent::PreToolUse,
106 Self::PostToolUse { .. } => HookEvent::PostToolUse,
107 Self::PostToolUseFailure { .. } => HookEvent::PostToolUseFailure,
108 Self::UserPromptSubmit { .. } => HookEvent::UserPromptSubmit,
109 Self::Stop => HookEvent::Stop,
110 Self::SubagentStart { .. } => HookEvent::SubagentStart,
111 Self::SubagentStop { .. } => HookEvent::SubagentStop,
112 Self::PreCompact => HookEvent::PreCompact,
113 Self::SessionStart => HookEvent::SessionStart,
114 Self::SessionEnd => HookEvent::SessionEnd,
115 }
116 }
117
118 pub fn tool_name(&self) -> Option<&str> {
119 match self {
120 Self::PreToolUse { tool_name, .. }
121 | Self::PostToolUse { tool_name, .. }
122 | Self::PostToolUseFailure { tool_name, .. } => Some(tool_name),
123 _ => None,
124 }
125 }
126
127 pub fn tool_input(&self) -> Option<&Value> {
128 match self {
129 Self::PreToolUse { tool_input, .. } => Some(tool_input),
130 _ => None,
131 }
132 }
133
134 pub fn subagent_id(&self) -> Option<&str> {
135 match self {
136 Self::SubagentStart { subagent_id, .. } | Self::SubagentStop { subagent_id, .. } => {
137 Some(subagent_id)
138 }
139 _ => None,
140 }
141 }
142}
143
144#[derive(Clone, Debug)]
145pub struct HookInput {
146 pub session_id: String,
147 pub timestamp: DateTime<Utc>,
148 pub data: HookEventData,
149 pub metadata: Option<Value>,
150}
151
152impl HookInput {
153 pub fn new(session_id: impl Into<String>, data: HookEventData) -> Self {
154 Self {
155 session_id: session_id.into(),
156 timestamp: Utc::now(),
157 data,
158 metadata: None,
159 }
160 }
161
162 pub fn with_metadata(mut self, metadata: Value) -> Self {
163 self.metadata = Some(metadata);
164 self
165 }
166
167 pub fn event_type(&self) -> HookEvent {
168 self.data.event_type()
169 }
170
171 pub fn tool_name(&self) -> Option<&str> {
172 self.data.tool_name()
173 }
174
175 pub fn subagent_id(&self) -> Option<&str> {
176 self.data.subagent_id()
177 }
178
179 pub fn pre_tool_use(
180 session_id: impl Into<String>,
181 tool_name: impl Into<String>,
182 tool_input: Value,
183 ) -> Self {
184 Self::new(
185 session_id,
186 HookEventData::PreToolUse {
187 tool_name: tool_name.into(),
188 tool_input,
189 },
190 )
191 }
192
193 pub fn post_tool_use(
194 session_id: impl Into<String>,
195 tool_name: impl Into<String>,
196 tool_result: ToolOutput,
197 ) -> Self {
198 Self::new(
199 session_id,
200 HookEventData::PostToolUse {
201 tool_name: tool_name.into(),
202 tool_result,
203 },
204 )
205 }
206
207 pub fn post_tool_use_failure(
208 session_id: impl Into<String>,
209 tool_name: impl Into<String>,
210 error: impl Into<String>,
211 ) -> Self {
212 Self::new(
213 session_id,
214 HookEventData::PostToolUseFailure {
215 tool_name: tool_name.into(),
216 error: error.into(),
217 },
218 )
219 }
220
221 pub fn user_prompt_submit(session_id: impl Into<String>, prompt: impl Into<String>) -> Self {
222 Self::new(
223 session_id,
224 HookEventData::UserPromptSubmit {
225 prompt: prompt.into(),
226 },
227 )
228 }
229
230 pub fn session_start(session_id: impl Into<String>) -> Self {
231 Self::new(session_id, HookEventData::SessionStart)
232 }
233
234 pub fn session_end(session_id: impl Into<String>) -> Self {
235 Self::new(session_id, HookEventData::SessionEnd)
236 }
237
238 pub fn stop(session_id: impl Into<String>) -> Self {
239 Self::new(session_id, HookEventData::Stop)
240 }
241
242 pub fn pre_compact(session_id: impl Into<String>) -> Self {
243 Self::new(session_id, HookEventData::PreCompact)
244 }
245
246 pub fn subagent_start(
247 session_id: impl Into<String>,
248 subagent_id: impl Into<String>,
249 subagent_type: impl Into<String>,
250 description: impl Into<String>,
251 ) -> Self {
252 Self::new(
253 session_id,
254 HookEventData::SubagentStart {
255 subagent_id: subagent_id.into(),
256 subagent_type: subagent_type.into(),
257 description: description.into(),
258 },
259 )
260 }
261
262 pub fn subagent_stop(
263 session_id: impl Into<String>,
264 subagent_id: impl Into<String>,
265 success: bool,
266 error: Option<String>,
267 ) -> Self {
268 Self::new(
269 session_id,
270 HookEventData::SubagentStop {
271 subagent_id: subagent_id.into(),
272 success,
273 error,
274 },
275 )
276 }
277}
278
279#[derive(Clone, Debug, Default)]
280pub struct HookOutput {
281 pub continue_execution: bool,
282 pub stop_reason: Option<String>,
283 pub suppress_logging: bool,
284 pub system_message: Option<String>,
285 pub updated_input: Option<Value>,
286 pub additional_context: Option<String>,
287}
288
289impl HookOutput {
290 pub fn allow() -> Self {
291 Self {
292 continue_execution: true,
293 ..Default::default()
294 }
295 }
296
297 pub fn block(reason: impl Into<String>) -> Self {
298 Self {
299 continue_execution: false,
300 stop_reason: Some(reason.into()),
301 ..Default::default()
302 }
303 }
304
305 pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
306 self.system_message = Some(message.into());
307 self
308 }
309
310 pub fn with_context(mut self, context: impl Into<String>) -> Self {
311 self.additional_context = Some(context.into());
312 self
313 }
314
315 pub fn with_updated_input(mut self, input: Value) -> Self {
316 self.updated_input = Some(input);
317 self
318 }
319
320 pub fn suppress_logging(mut self) -> Self {
321 self.suppress_logging = true;
322 self
323 }
324}
325
326#[derive(Clone, Debug)]
327pub struct HookContext {
328 pub session_id: String,
329 pub cancellation_token: CancellationToken,
330 pub cwd: Option<std::path::PathBuf>,
331 pub env: std::collections::HashMap<String, String>,
332}
333
334impl Default for HookContext {
335 fn default() -> Self {
336 Self {
337 session_id: String::new(),
338 cancellation_token: CancellationToken::new(),
339 cwd: None,
340 env: std::collections::HashMap::new(),
341 }
342 }
343}
344
345impl HookContext {
346 pub fn new(session_id: impl Into<String>) -> Self {
347 Self {
348 session_id: session_id.into(),
349 ..Default::default()
350 }
351 }
352
353 pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
354 self.cancellation_token = token;
355 self
356 }
357
358 pub fn with_cwd(mut self, cwd: impl Into<std::path::PathBuf>) -> Self {
359 self.cwd = Some(cwd.into());
360 self
361 }
362
363 pub fn with_env(mut self, env: std::collections::HashMap<String, String>) -> Self {
364 self.env = env;
365 self
366 }
367}
368
369#[derive(Clone, Debug)]
371pub struct HookMetadata {
372 pub name: String,
373 pub events: Vec<HookEvent>,
374 pub priority: i32,
375 pub timeout_secs: u64,
376 pub tool_matcher: Option<Regex>,
377}
378
379impl HookMetadata {
380 pub fn new(name: impl Into<String>, events: Vec<HookEvent>) -> Self {
381 Self {
382 name: name.into(),
383 events,
384 priority: 0,
385 timeout_secs: 60,
386 tool_matcher: None,
387 }
388 }
389
390 pub fn with_priority(mut self, priority: i32) -> Self {
391 self.priority = priority;
392 self
393 }
394
395 pub fn with_timeout(mut self, secs: u64) -> Self {
396 self.timeout_secs = secs;
397 self
398 }
399
400 pub fn with_tool_matcher(mut self, pattern: &str) -> Self {
401 if let Ok(regex) = Regex::new(pattern) {
402 self.tool_matcher = Some(regex);
403 }
404 self
405 }
406}
407
408#[async_trait]
409pub trait Hook: Send + Sync {
410 fn name(&self) -> &str;
411 fn events(&self) -> &[HookEvent];
412
413 #[inline]
414 fn tool_matcher(&self) -> Option<&Regex> {
415 None
416 }
417
418 #[inline]
419 fn timeout_secs(&self) -> u64 {
420 60
421 }
422
423 #[inline]
424 fn priority(&self) -> i32 {
425 0
426 }
427
428 async fn execute(
429 &self,
430 input: HookInput,
431 hook_context: &HookContext,
432 ) -> Result<HookOutput, crate::Error>;
433
434 fn metadata(&self) -> HookMetadata {
436 HookMetadata {
437 name: self.name().to_string(),
438 events: self.events().to_vec(),
439 priority: self.priority(),
440 timeout_secs: self.timeout_secs(),
441 tool_matcher: self.tool_matcher().cloned(),
442 }
443 }
444}
445
446pub struct FnHook<F> {
447 name: String,
448 events: Vec<HookEvent>,
449 handler: F,
450 priority: i32,
451 timeout_secs: u64,
452 tool_matcher: Option<Regex>,
453}
454
455impl<F> FnHook<F> {
456 pub fn builder(name: impl Into<String>, events: Vec<HookEvent>) -> FnHookBuilder {
457 FnHookBuilder {
458 name: name.into(),
459 events,
460 priority: 0,
461 timeout_secs: 60,
462 tool_matcher: None,
463 }
464 }
465}
466
467pub struct FnHookBuilder {
468 name: String,
469 events: Vec<HookEvent>,
470 priority: i32,
471 timeout_secs: u64,
472 tool_matcher: Option<Regex>,
473}
474
475impl FnHookBuilder {
476 pub fn priority(mut self, priority: i32) -> Self {
477 self.priority = priority;
478 self
479 }
480
481 pub fn timeout_secs(mut self, secs: u64) -> Self {
482 self.timeout_secs = secs;
483 self
484 }
485
486 pub fn tool_matcher(mut self, pattern: &str) -> Self {
487 if let Ok(regex) = Regex::new(pattern) {
488 self.tool_matcher = Some(regex);
489 }
490 self
491 }
492
493 pub fn handler<F, Fut>(self, handler: F) -> FnHook<F>
494 where
495 F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
496 Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
497 {
498 FnHook {
499 name: self.name,
500 events: self.events,
501 handler,
502 priority: self.priority,
503 timeout_secs: self.timeout_secs,
504 tool_matcher: self.tool_matcher,
505 }
506 }
507}
508
509#[async_trait]
510impl<F, Fut> Hook for FnHook<F>
511where
512 F: Fn(HookInput, HookContext) -> Fut + Send + Sync,
513 Fut: std::future::Future<Output = Result<HookOutput, crate::Error>> + Send,
514{
515 fn name(&self) -> &str {
516 &self.name
517 }
518
519 fn events(&self) -> &[HookEvent] {
520 &self.events
521 }
522
523 fn priority(&self) -> i32 {
524 self.priority
525 }
526
527 fn timeout_secs(&self) -> u64 {
528 self.timeout_secs
529 }
530
531 fn tool_matcher(&self) -> Option<&Regex> {
532 self.tool_matcher.as_ref()
533 }
534
535 async fn execute(
536 &self,
537 input: HookInput,
538 hook_context: &HookContext,
539 ) -> Result<HookOutput, crate::Error> {
540 (self.handler)(input, hook_context.clone()).await
541 }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` yields the snake_case event name.
    #[test]
    fn test_hook_event_display() {
        let expected = [
            (HookEvent::PreToolUse, "pre_tool_use"),
            (HookEvent::PostToolUse, "post_tool_use"),
            (HookEvent::SessionStart, "session_start"),
        ];
        for (event, text) in expected.iter() {
            assert_eq!(event.to_string(), *text);
        }
    }

    /// Only `PreToolUse` and `UserPromptSubmit` may block.
    #[test]
    fn test_hook_event_can_block() {
        let blocking = [HookEvent::PreToolUse, HookEvent::UserPromptSubmit];
        for event in blocking.iter() {
            assert!(event.can_block());
        }

        let non_blocking = [
            HookEvent::PostToolUse,
            HookEvent::SessionEnd,
            HookEvent::SubagentStart,
            HookEvent::SubagentStop,
        ];
        for event in non_blocking.iter() {
            assert!(!event.can_block());
        }
    }

    /// The convenience constructors fill in event type, tool name and session id.
    #[test]
    fn test_hook_input_builders() {
        let tool_input = serde_json::json!({"path": "/tmp"});
        let pre = HookInput::pre_tool_use("session-1", "Read", tool_input);
        assert_eq!(pre.event_type(), HookEvent::PreToolUse);
        assert_eq!(pre.tool_name(), Some("Read"));
        assert_eq!(pre.session_id, "session-1");

        let start = HookInput::session_start("session-2");
        assert_eq!(start.event_type(), HookEvent::SessionStart);
        assert_eq!(start.session_id, "session-2");
    }

    /// `allow` / `block` and the chained setters populate the expected fields.
    #[test]
    fn test_hook_output_builders() {
        let allowed = HookOutput::allow();
        assert!(allowed.continue_execution);
        assert!(allowed.stop_reason.is_none());

        let blocked = HookOutput::block("Dangerous operation");
        assert!(!blocked.continue_execution);
        assert_eq!(blocked.stop_reason, Some("Dangerous operation".to_string()));

        let decorated = HookOutput::allow()
            .with_system_message("Added context")
            .with_context("More info")
            .suppress_logging();
        assert!(decorated.continue_execution);
        assert!(decorated.suppress_logging);
        assert_eq!(decorated.system_message, Some("Added context".to_string()));
        assert_eq!(decorated.additional_context, Some("More info".to_string()));
    }

    /// Accessors return data for matching variants and `None` otherwise.
    #[test]
    fn test_hook_event_data_accessors() {
        let with_tool = HookEventData::PreToolUse {
            tool_name: "Bash".to_string(),
            tool_input: serde_json::json!({"command": "ls"}),
        };
        assert_eq!(with_tool.event_type(), HookEvent::PreToolUse);
        assert_eq!(with_tool.tool_name(), Some("Bash"));
        assert!(with_tool.tool_input().is_some());

        let without_tool = HookEventData::SessionStart;
        assert_eq!(without_tool.event_type(), HookEvent::SessionStart);
        assert_eq!(without_tool.tool_name(), None);
        assert!(without_tool.tool_input().is_none());
    }
}