1use std::path::PathBuf;
9use std::time::Instant;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14
15use crate::types::SessionId;
16
17#[derive(Debug, Clone)]
19pub struct HookContext {
20 pub session_id: SessionId,
22}
23
24#[derive(Debug, Clone, Deserialize)]
26#[serde(rename_all = "camelCase")]
27pub struct PreToolUseInput {
28 pub timestamp: i64,
30 pub cwd: PathBuf,
32 pub tool_name: String,
34 pub tool_args: Value,
36}
37
38#[derive(Debug, Clone, Default, Serialize)]
40#[serde(rename_all = "camelCase")]
41pub struct PreToolUseOutput {
42 #[serde(skip_serializing_if = "Option::is_none")]
44 pub permission_decision: Option<String>,
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub permission_decision_reason: Option<String>,
48 #[serde(skip_serializing_if = "Option::is_none")]
50 pub modified_args: Option<Value>,
51 #[serde(skip_serializing_if = "Option::is_none")]
53 pub additional_context: Option<String>,
54 #[serde(skip_serializing_if = "Option::is_none")]
56 pub suppress_output: Option<bool>,
57}
58
59#[derive(Debug, Clone, Deserialize)]
61#[serde(rename_all = "camelCase")]
62pub struct PostToolUseInput {
63 pub timestamp: i64,
65 pub cwd: PathBuf,
67 pub tool_name: String,
69 pub tool_args: Value,
71 pub tool_result: Value,
73}
74
75#[derive(Debug, Clone, Default, Serialize)]
77#[serde(rename_all = "camelCase")]
78pub struct PostToolUseOutput {
79 #[serde(skip_serializing_if = "Option::is_none")]
81 pub modified_result: Option<Value>,
82 #[serde(skip_serializing_if = "Option::is_none")]
84 pub additional_context: Option<String>,
85 #[serde(skip_serializing_if = "Option::is_none")]
87 pub suppress_output: Option<bool>,
88}
89
90#[derive(Debug, Clone, Deserialize)]
92#[serde(rename_all = "camelCase")]
93pub struct UserPromptSubmittedInput {
94 pub timestamp: i64,
96 pub cwd: PathBuf,
98 pub prompt: String,
100}
101
102#[derive(Debug, Clone, Default, Serialize)]
104#[serde(rename_all = "camelCase")]
105pub struct UserPromptSubmittedOutput {
106 #[serde(skip_serializing_if = "Option::is_none")]
108 pub modified_prompt: Option<String>,
109 #[serde(skip_serializing_if = "Option::is_none")]
111 pub additional_context: Option<String>,
112 #[serde(skip_serializing_if = "Option::is_none")]
114 pub suppress_output: Option<bool>,
115}
116
117#[derive(Debug, Clone, Deserialize)]
119#[serde(rename_all = "camelCase")]
120pub struct SessionStartInput {
121 pub timestamp: i64,
123 pub cwd: PathBuf,
125 pub source: String,
127 #[serde(default)]
129 pub initial_prompt: Option<String>,
130}
131
132#[derive(Debug, Clone, Default, Serialize)]
134#[serde(rename_all = "camelCase")]
135pub struct SessionStartOutput {
136 #[serde(skip_serializing_if = "Option::is_none")]
138 pub additional_context: Option<String>,
139 #[serde(skip_serializing_if = "Option::is_none")]
141 pub modified_config: Option<Value>,
142}
143
144#[derive(Debug, Clone, Deserialize)]
146#[serde(rename_all = "camelCase")]
147pub struct SessionEndInput {
148 pub timestamp: i64,
150 pub cwd: PathBuf,
152 pub reason: String,
154 #[serde(default)]
156 pub final_message: Option<String>,
157 #[serde(default)]
159 pub error: Option<String>,
160}
161
162#[derive(Debug, Clone, Default, Serialize)]
164#[serde(rename_all = "camelCase")]
165pub struct SessionEndOutput {
166 #[serde(skip_serializing_if = "Option::is_none")]
168 pub suppress_output: Option<bool>,
169 #[serde(skip_serializing_if = "Option::is_none")]
171 pub cleanup_actions: Option<Vec<String>>,
172 #[serde(skip_serializing_if = "Option::is_none")]
174 pub session_summary: Option<String>,
175}
176
177#[derive(Debug, Clone, Deserialize)]
179#[serde(rename_all = "camelCase")]
180pub struct ErrorOccurredInput {
181 pub timestamp: i64,
183 pub cwd: PathBuf,
185 pub error: String,
187 pub error_context: String,
189 pub recoverable: bool,
191}
192
193#[derive(Debug, Clone, Default, Serialize)]
195#[serde(rename_all = "camelCase")]
196pub struct ErrorOccurredOutput {
197 #[serde(skip_serializing_if = "Option::is_none")]
199 pub suppress_output: Option<bool>,
200 #[serde(skip_serializing_if = "Option::is_none")]
202 pub error_handling: Option<String>,
203 #[serde(skip_serializing_if = "Option::is_none")]
205 pub retry_count: Option<u32>,
206 #[serde(skip_serializing_if = "Option::is_none")]
208 pub user_notification: Option<String>,
209}
210
211#[non_exhaustive]
217#[derive(Debug)]
218pub enum HookEvent {
219 PreToolUse {
221 input: PreToolUseInput,
223 ctx: HookContext,
225 },
226 PostToolUse {
228 input: PostToolUseInput,
230 ctx: HookContext,
232 },
233 UserPromptSubmitted {
235 input: UserPromptSubmittedInput,
237 ctx: HookContext,
239 },
240 SessionStart {
242 input: SessionStartInput,
244 ctx: HookContext,
246 },
247 SessionEnd {
249 input: SessionEndInput,
251 ctx: HookContext,
253 },
254 ErrorOccurred {
256 input: ErrorOccurredInput,
258 ctx: HookContext,
260 },
261}
262
263#[non_exhaustive]
268#[derive(Debug)]
269pub enum HookOutput {
270 None,
272 PreToolUse(PreToolUseOutput),
274 PostToolUse(PostToolUseOutput),
276 UserPromptSubmitted(UserPromptSubmittedOutput),
278 SessionStart(SessionStartOutput),
280 SessionEnd(SessionEndOutput),
282 ErrorOccurred(ErrorOccurredOutput),
284}
285
286impl HookOutput {
287 fn variant_name(&self) -> &'static str {
288 match self {
289 Self::None => "None",
290 Self::PreToolUse(_) => "PreToolUse",
291 Self::PostToolUse(_) => "PostToolUse",
292 Self::UserPromptSubmitted(_) => "UserPromptSubmitted",
293 Self::SessionStart(_) => "SessionStart",
294 Self::SessionEnd(_) => "SessionEnd",
295 Self::ErrorOccurred(_) => "ErrorOccurred",
296 }
297 }
298}
299
300#[async_trait]
318pub trait SessionHooks: Send + Sync + 'static {
319 async fn on_hook(&self, event: HookEvent) -> HookOutput {
323 match event {
324 HookEvent::PreToolUse { input, ctx } => self
325 .on_pre_tool_use(input, ctx)
326 .await
327 .map(HookOutput::PreToolUse)
328 .unwrap_or(HookOutput::None),
329 HookEvent::PostToolUse { input, ctx } => self
330 .on_post_tool_use(input, ctx)
331 .await
332 .map(HookOutput::PostToolUse)
333 .unwrap_or(HookOutput::None),
334 HookEvent::UserPromptSubmitted { input, ctx } => self
335 .on_user_prompt_submitted(input, ctx)
336 .await
337 .map(HookOutput::UserPromptSubmitted)
338 .unwrap_or(HookOutput::None),
339 HookEvent::SessionStart { input, ctx } => self
340 .on_session_start(input, ctx)
341 .await
342 .map(HookOutput::SessionStart)
343 .unwrap_or(HookOutput::None),
344 HookEvent::SessionEnd { input, ctx } => self
345 .on_session_end(input, ctx)
346 .await
347 .map(HookOutput::SessionEnd)
348 .unwrap_or(HookOutput::None),
349 HookEvent::ErrorOccurred { input, ctx } => self
350 .on_error_occurred(input, ctx)
351 .await
352 .map(HookOutput::ErrorOccurred)
353 .unwrap_or(HookOutput::None),
354 }
355 }
356
357 async fn on_pre_tool_use(
360 &self,
361 _input: PreToolUseInput,
362 _ctx: HookContext,
363 ) -> Option<PreToolUseOutput> {
364 None
365 }
366
367 async fn on_post_tool_use(
371 &self,
372 _input: PostToolUseInput,
373 _ctx: HookContext,
374 ) -> Option<PostToolUseOutput> {
375 None
376 }
377
378 async fn on_user_prompt_submitted(
382 &self,
383 _input: UserPromptSubmittedInput,
384 _ctx: HookContext,
385 ) -> Option<UserPromptSubmittedOutput> {
386 None
387 }
388
389 async fn on_session_start(
392 &self,
393 _input: SessionStartInput,
394 _ctx: HookContext,
395 ) -> Option<SessionStartOutput> {
396 None
397 }
398
399 async fn on_session_end(
402 &self,
403 _input: SessionEndInput,
404 _ctx: HookContext,
405 ) -> Option<SessionEndOutput> {
406 None
407 }
408
409 async fn on_error_occurred(
412 &self,
413 _input: ErrorOccurredInput,
414 _ctx: HookContext,
415 ) -> Option<ErrorOccurredOutput> {
416 None
417 }
418}
419
420pub(crate) async fn dispatch_hook(
426 hooks: &dyn SessionHooks,
427 session_id: &SessionId,
428 hook_type: &str,
429 raw_input: Value,
430) -> Result<Value, crate::Error> {
431 let ctx = HookContext {
432 session_id: session_id.clone(),
433 };
434
435 let event = match hook_type {
436 "preToolUse" => {
437 let input: PreToolUseInput = serde_json::from_value(raw_input)?;
438 HookEvent::PreToolUse { input, ctx }
439 }
440 "postToolUse" => {
441 let input: PostToolUseInput = serde_json::from_value(raw_input)?;
442 HookEvent::PostToolUse { input, ctx }
443 }
444 "userPromptSubmitted" => {
445 let input: UserPromptSubmittedInput = serde_json::from_value(raw_input)?;
446 HookEvent::UserPromptSubmitted { input, ctx }
447 }
448 "sessionStart" => {
449 let input: SessionStartInput = serde_json::from_value(raw_input)?;
450 HookEvent::SessionStart { input, ctx }
451 }
452 "sessionEnd" => {
453 let input: SessionEndInput = serde_json::from_value(raw_input)?;
454 HookEvent::SessionEnd { input, ctx }
455 }
456 "errorOccurred" => {
457 let input: ErrorOccurredInput = serde_json::from_value(raw_input)?;
458 HookEvent::ErrorOccurred { input, ctx }
459 }
460 _ => {
461 tracing::warn!(
462 hook_type = hook_type,
463 session_id = %session_id,
464 "unknown hook type"
465 );
466 return Ok(serde_json::json!({ "output": {} }));
467 }
468 };
469
470 let dispatch_start = Instant::now();
471 let output = hooks.on_hook(event).await;
472 tracing::debug!(
473 elapsed_ms = dispatch_start.elapsed().as_millis(),
474 session_id = %session_id,
475 hook_type = hook_type,
476 "SessionHooks::on_hook dispatch"
477 );
478
479 let output_value = match (hook_type, &output) {
484 (_, HookOutput::None) => None,
485 ("preToolUse", HookOutput::PreToolUse(o)) => Some(serde_json::to_value(o)?),
486 ("postToolUse", HookOutput::PostToolUse(o)) => Some(serde_json::to_value(o)?),
487 ("userPromptSubmitted", HookOutput::UserPromptSubmitted(o)) => {
488 Some(serde_json::to_value(o)?)
489 }
490 ("sessionStart", HookOutput::SessionStart(o)) => Some(serde_json::to_value(o)?),
491 ("sessionEnd", HookOutput::SessionEnd(o)) => Some(serde_json::to_value(o)?),
492 ("errorOccurred", HookOutput::ErrorOccurred(o)) => Some(serde_json::to_value(o)?),
493 _ => {
494 tracing::warn!(
495 hook_type = hook_type,
496 session_id = %session_id,
497 output_variant = output.variant_name(),
498 "hook returned mismatched output variant, treating as unregistered"
499 );
500 None
501 }
502 };
503
504 Ok(serde_json::json!({ "output": output_value.unwrap_or(Value::Object(Default::default())) }))
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    /// Overrides `on_hook` directly: denies `dangerous_tool`, prefixes submitted
    /// prompts, and returns `HookOutput::None` for everything else.
    struct TestHooks;

    #[async_trait]
    impl SessionHooks for TestHooks {
        async fn on_hook(&self, event: HookEvent) -> HookOutput {
            match event {
                HookEvent::PreToolUse { input, .. } => {
                    if input.tool_name == "dangerous_tool" {
                        HookOutput::PreToolUse(PreToolUseOutput {
                            permission_decision: Some("deny".to_string()),
                            permission_decision_reason: Some("blocked by policy".to_string()),
                            ..Default::default()
                        })
                    } else {
                        HookOutput::None
                    }
                }
                HookEvent::UserPromptSubmitted { input, .. } => {
                    HookOutput::UserPromptSubmitted(UserPromptSubmittedOutput {
                        modified_prompt: Some(format!("[prefixed] {}", input.prompt)),
                        ..Default::default()
                    })
                }
                _ => HookOutput::None,
            }
        }
    }

    #[tokio::test]
    async fn dispatch_pre_tool_use_deny() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "dangerous_tool",
            "toolArgs": {}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        let output = &result["output"];
        assert_eq!(output["permissionDecision"], "deny");
        assert_eq!(output["permissionDecisionReason"], "blocked by policy");
    }

    #[tokio::test]
    async fn dispatch_pre_tool_use_passthrough() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "safe_tool",
            "toolArgs": {"key": "value"}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_user_prompt_submitted() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "prompt": "hello world"
        });
        let result = dispatch_hook(
            &hooks,
            &SessionId::new("sess-1"),
            "userPromptSubmitted",
            input,
        )
        .await
        .unwrap();
        assert_eq!(result["output"]["modifiedPrompt"], "[prefixed] hello world");
    }

    #[tokio::test]
    async fn dispatch_unregistered_hook_returns_empty() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "reason": "complete"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionEnd", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_unknown_hook_type() {
        let hooks = TestHooks;
        let input = serde_json::json!({});
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "unknownHook", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_mismatched_output_returns_empty() {
        struct MismatchHooks;
        #[async_trait]
        impl SessionHooks for MismatchHooks {
            async fn on_hook(&self, _event: HookEvent) -> HookOutput {
                HookOutput::SessionEnd(SessionEndOutput {
                    session_summary: Some("oops".to_string()),
                    ..Default::default()
                })
            }
        }

        let hooks = MismatchHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_post_tool_use_default() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {},
            "toolResult": "success"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "postToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_session_start() {
        struct StartHooks;
        #[async_trait]
        impl SessionHooks for StartHooks {
            async fn on_hook(&self, event: HookEvent) -> HookOutput {
                match event {
                    HookEvent::SessionStart { .. } => {
                        HookOutput::SessionStart(SessionStartOutput {
                            additional_context: Some("extra context".to_string()),
                            ..Default::default()
                        })
                    }
                    _ => HookOutput::None,
                }
            }
        }

        let hooks = StartHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "source": "new"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionStart", input)
            .await
            .unwrap();
        assert_eq!(result["output"]["additionalContext"], "extra context");
    }

    #[tokio::test]
    async fn dispatch_error_occurred() {
        struct ErrorHooks;
        #[async_trait]
        impl SessionHooks for ErrorHooks {
            async fn on_hook(&self, event: HookEvent) -> HookOutput {
                match event {
                    HookEvent::ErrorOccurred { .. } => {
                        HookOutput::ErrorOccurred(ErrorOccurredOutput {
                            error_handling: Some("retry".to_string()),
                            retry_count: Some(3),
                            ..Default::default()
                        })
                    }
                    _ => HookOutput::None,
                }
            }
        }

        let hooks = ErrorHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "error": "model timeout",
            "errorContext": "model_call",
            "recoverable": true
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "errorOccurred", input)
            .await
            .unwrap();
        assert_eq!(result["output"]["errorHandling"], "retry");
        assert_eq!(result["output"]["retryCount"], 3);
    }

    /// Every other test overrides `on_hook`. This one overrides only a per-event
    /// method, so dispatch goes through the trait's default routing.
    #[tokio::test]
    async fn dispatch_default_on_hook_routes_to_typed_methods() {
        struct TypedHooks;
        #[async_trait]
        impl SessionHooks for TypedHooks {
            async fn on_post_tool_use(
                &self,
                input: PostToolUseInput,
                _ctx: HookContext,
            ) -> Option<PostToolUseOutput> {
                Some(PostToolUseOutput {
                    additional_context: Some(format!("saw {}", input.tool_name)),
                    ..Default::default()
                })
            }
        }

        let hooks = TypedHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {},
            "toolResult": "success"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "postToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"]["additionalContext"], "saw some_tool");

        // A per-event method that is not overridden falls back to `None`,
        // which serializes as an empty object.
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "prompt": "hello"
        });
        let result = dispatch_hook(
            &hooks,
            &SessionId::new("sess-1"),
            "userPromptSubmitted",
            input,
        )
        .await
        .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_malformed_input_is_error() {
        let hooks = TestHooks;
        // Missing the required `toolName` and `toolArgs` fields.
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp"
        });
        let result =
            dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input).await;
        assert!(result.is_err());
    }
}