1use std::path::PathBuf;
9use std::time::Instant;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14
15use crate::types::SessionId;
16
/// Per-invocation context handed to a hook alongside its typed input.
#[derive(Debug, Clone)]
pub struct HookContext {
    /// Identifier of the session the hook is being dispatched for.
    pub session_id: SessionId,
}
23
/// Input payload for the `preToolUse` hook.
///
/// Deserialized from JSON with camelCase keys (`sessionId`, `toolName`,
/// `toolArgs`, ...); the working directory is read from the `cwd` key.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PreToolUseInput {
    /// Session identifier as it appears in the payload.
    pub session_id: String,
    /// Timestamp supplied by the sender.
    /// NOTE(review): unit and epoch are not established in this file — confirm.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// Name of the tool the hook is being consulted about.
    pub tool_name: String,
    /// Tool arguments as arbitrary JSON.
    pub tool_args: Value,
}
40
/// Output a `preToolUse` hook may return.
///
/// Serialized with camelCase keys; every field is optional and is omitted
/// from the JSON entirely when it is `None`.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PreToolUseOutput {
    /// Permission decision, emitted as `permissionDecision`.
    /// NOTE(review): the accepted values are not defined in this file (the
    /// tests use `"deny"`) — confirm against the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub permission_decision: Option<String>,
    /// Free-form reason accompanying the decision.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub permission_decision_reason: Option<String>,
    /// Replacement tool arguments as arbitrary JSON, emitted as `modifiedArgs`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_args: Option<Value>,
    /// Extra context text, emitted as `additionalContext`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_context: Option<String>,
    /// Flag emitted as `suppressOutput`; its effect is decided by the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suppress_output: Option<bool>,
}
61
/// Input payload for the `preMcpToolCall` hook.
///
/// Deserialized from JSON with camelCase keys, except for the two explicit
/// renames noted on the fields (`cwd` and `_meta`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PreMcpToolCallInput {
    /// Session identifier as it appears in the payload.
    pub session_id: String,
    /// Timestamp supplied by the sender.
    /// NOTE(review): unit and epoch are not established in this file — confirm.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// Name of the server the call targets, read from `serverName`.
    pub server_name: String,
    /// Name of the tool being called.
    pub tool_name: String,
    /// Call arguments as arbitrary JSON.
    pub arguments: Value,
    /// Optional call identifier (`toolCallId`); `None` when the key is absent.
    #[serde(default)]
    pub tool_call_id: Option<String>,
    /// Optional metadata read from the `_meta` key; `None` when absent.
    #[serde(default, rename = "_meta")]
    pub meta: Option<Value>,
}
86
/// Output a `preMcpToolCall` hook may return.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PreMcpToolCallOutput {
    /// Metadata value emitted as `metaToUse`; omitted from the JSON when `None`.
    /// NOTE(review): presumably replaces the incoming `_meta` — confirm with the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub meta_to_use: Option<Value>,
}
100
/// Input payload for the `postToolUse` hook.
///
/// Deserialized from JSON with camelCase keys; the working directory is read
/// from the `cwd` key.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostToolUseInput {
    /// Session identifier as it appears in the payload.
    pub session_id: String,
    /// Timestamp supplied by the sender.
    /// NOTE(review): unit and epoch are not established in this file — confirm.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// Name of the tool the payload refers to.
    pub tool_name: String,
    /// Tool arguments as arbitrary JSON.
    pub tool_args: Value,
    /// Tool result as arbitrary JSON, read from `toolResult`.
    pub tool_result: Value,
}
119
/// Output a `postToolUse` hook may return.
///
/// Serialized with camelCase keys; `None` fields are omitted from the JSON.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PostToolUseOutput {
    /// Replacement tool result as arbitrary JSON, emitted as `modifiedResult`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_result: Option<Value>,
    /// Extra context text, emitted as `additionalContext`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_context: Option<String>,
    /// Flag emitted as `suppressOutput`; its effect is decided by the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suppress_output: Option<bool>,
}
134
/// Input payload for the `userPromptSubmitted` hook.
///
/// Deserialized from JSON with camelCase keys; the working directory is read
/// from the `cwd` key.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UserPromptSubmittedInput {
    /// Session identifier as it appears in the payload.
    pub session_id: String,
    /// Timestamp supplied by the sender.
    /// NOTE(review): unit and epoch are not established in this file — confirm.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// The prompt text carried by the payload.
    pub prompt: String,
}
149
/// Output a `userPromptSubmitted` hook may return.
///
/// Serialized with camelCase keys; `None` fields are omitted from the JSON.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct UserPromptSubmittedOutput {
    /// Replacement prompt text, emitted as `modifiedPrompt`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_prompt: Option<String>,
    /// Extra context text, emitted as `additionalContext`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_context: Option<String>,
    /// Flag emitted as `suppressOutput`; its effect is decided by the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suppress_output: Option<bool>,
}
164
/// Input payload for the `sessionStart` hook.
///
/// Deserialized from JSON with camelCase keys; the working directory is read
/// from the `cwd` key.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionStartInput {
    /// Session identifier as it appears in the payload.
    pub session_id: String,
    /// Timestamp supplied by the sender.
    /// NOTE(review): unit and epoch are not established in this file — confirm.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// Free-form source string.
    /// NOTE(review): the set of values is not defined here (the tests use
    /// `"new"`) — confirm.
    pub source: String,
    /// Optional prompt read from `initialPrompt`; `None` when absent.
    #[serde(default)]
    pub initial_prompt: Option<String>,
}
182
/// Output a `sessionStart` hook may return.
///
/// Serialized with camelCase keys; `None` fields are omitted from the JSON.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionStartOutput {
    /// Extra context text, emitted as `additionalContext`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_context: Option<String>,
    /// Arbitrary JSON emitted as `modifiedConfig`; its schema is not defined in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_config: Option<Value>,
}
194
/// Input payload for the `sessionEnd` hook.
///
/// Deserialized from JSON with camelCase keys; the working directory is read
/// from the `cwd` key.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionEndInput {
    /// Session identifier as it appears in the payload.
    pub session_id: String,
    /// Timestamp supplied by the sender.
    /// NOTE(review): unit and epoch are not established in this file — confirm.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// Free-form reason string.
    /// NOTE(review): the set of values is not defined here (the tests use
    /// `"complete"`) — confirm.
    pub reason: String,
    /// Optional message read from `finalMessage`; `None` when absent.
    #[serde(default)]
    pub final_message: Option<String>,
    /// Optional error text; `None` when the key is absent.
    #[serde(default)]
    pub error: Option<String>,
}
215
/// Output a `sessionEnd` hook may return.
///
/// Serialized with camelCase keys; `None` fields are omitted from the JSON.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionEndOutput {
    /// Flag emitted as `suppressOutput`; its effect is decided by the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suppress_output: Option<bool>,
    /// List of strings emitted as `cleanupActions`; their meaning is not defined in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cleanup_actions: Option<Vec<String>>,
    /// Summary text emitted as `sessionSummary`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_summary: Option<String>,
}
230
/// Input payload for the `errorOccurred` hook.
///
/// Deserialized from JSON with camelCase keys; the working directory is read
/// from the `cwd` key.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ErrorOccurredInput {
    /// Session identifier as it appears in the payload.
    pub session_id: String,
    /// Timestamp supplied by the sender.
    /// NOTE(review): unit and epoch are not established in this file — confirm.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// Error text carried by the payload.
    pub error: String,
    /// Free-form context string read from `errorContext` (the tests use `"model_call"`).
    pub error_context: String,
    /// Boolean flag from the payload.
    /// NOTE(review): presumably marks whether the error can be recovered from — confirm.
    pub recoverable: bool,
}
249
/// Output an `errorOccurred` hook may return.
///
/// Serialized with camelCase keys; `None` fields are omitted from the JSON.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ErrorOccurredOutput {
    /// Flag emitted as `suppressOutput`; its effect is decided by the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suppress_output: Option<bool>,
    /// Free-form handling directive emitted as `errorHandling`.
    /// NOTE(review): the accepted values are not defined here (the tests use
    /// `"retry"`) — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_handling: Option<String>,
    /// Number emitted as `retryCount`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry_count: Option<u32>,
    /// Text emitted as `userNotification`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_notification: Option<String>,
}
267
/// A decoded hook invocation passed to [`SessionHooks::on_hook`].
///
/// There is one variant per hook type recognized by `dispatch_hook`; each
/// carries the typed input together with the [`HookContext`]. The enum is
/// `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Debug)]
pub enum HookEvent {
    /// Built for the `preToolUse` hook type.
    PreToolUse {
        /// Decoded hook payload.
        input: PreToolUseInput,
        /// Invocation context.
        ctx: HookContext,
    },
    /// Built for the `preMcpToolCall` hook type.
    PreMcpToolCall {
        /// Decoded hook payload.
        input: PreMcpToolCallInput,
        /// Invocation context.
        ctx: HookContext,
    },
    /// Built for the `postToolUse` hook type.
    PostToolUse {
        /// Decoded hook payload.
        input: PostToolUseInput,
        /// Invocation context.
        ctx: HookContext,
    },
    /// Built for the `userPromptSubmitted` hook type.
    UserPromptSubmitted {
        /// Decoded hook payload.
        input: UserPromptSubmittedInput,
        /// Invocation context.
        ctx: HookContext,
    },
    /// Built for the `sessionStart` hook type.
    SessionStart {
        /// Decoded hook payload.
        input: SessionStartInput,
        /// Invocation context.
        ctx: HookContext,
    },
    /// Built for the `sessionEnd` hook type.
    SessionEnd {
        /// Decoded hook payload.
        input: SessionEndInput,
        /// Invocation context.
        ctx: HookContext,
    },
    /// Built for the `errorOccurred` hook type.
    ErrorOccurred {
        /// Decoded hook payload.
        input: ErrorOccurredInput,
        /// Invocation context.
        ctx: HookContext,
    },
}
326
/// Result of [`SessionHooks::on_hook`].
///
/// `dispatch_hook` serializes the payload only when the variant matches the
/// hook type that was dispatched; `None` and mismatched variants both produce
/// an empty JSON object. The enum is `#[non_exhaustive]`, so matches outside
/// this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Debug)]
pub enum HookOutput {
    /// The hook produced no output.
    None,
    /// Output for a `preToolUse` hook.
    PreToolUse(PreToolUseOutput),
    /// Output for a `preMcpToolCall` hook.
    PreMcpToolCall(PreMcpToolCallOutput),
    /// Output for a `postToolUse` hook.
    PostToolUse(PostToolUseOutput),
    /// Output for a `userPromptSubmitted` hook.
    UserPromptSubmitted(UserPromptSubmittedOutput),
    /// Output for a `sessionStart` hook.
    SessionStart(SessionStartOutput),
    /// Output for a `sessionEnd` hook.
    SessionEnd(SessionEndOutput),
    /// Output for an `errorOccurred` hook.
    ErrorOccurred(ErrorOccurredOutput),
}
351
352impl HookOutput {
353 fn variant_name(&self) -> &'static str {
354 match self {
355 Self::None => "None",
356 Self::PreToolUse(_) => "PreToolUse",
357 Self::PreMcpToolCall(_) => "PreMcpToolCall",
358 Self::PostToolUse(_) => "PostToolUse",
359 Self::UserPromptSubmitted(_) => "UserPromptSubmitted",
360 Self::SessionStart(_) => "SessionStart",
361 Self::SessionEnd(_) => "SessionEnd",
362 Self::ErrorOccurred(_) => "ErrorOccurred",
363 }
364 }
365}
366
367#[async_trait]
385pub trait SessionHooks: Send + Sync + 'static {
386 async fn on_hook(&self, event: HookEvent) -> HookOutput {
390 match event {
391 HookEvent::PreToolUse { input, ctx } => self
392 .on_pre_tool_use(input, ctx)
393 .await
394 .map(HookOutput::PreToolUse)
395 .unwrap_or(HookOutput::None),
396 HookEvent::PreMcpToolCall { input, ctx } => self
397 .on_pre_mcp_tool_call(input, ctx)
398 .await
399 .map(HookOutput::PreMcpToolCall)
400 .unwrap_or(HookOutput::None),
401 HookEvent::PostToolUse { input, ctx } => self
402 .on_post_tool_use(input, ctx)
403 .await
404 .map(HookOutput::PostToolUse)
405 .unwrap_or(HookOutput::None),
406 HookEvent::UserPromptSubmitted { input, ctx } => self
407 .on_user_prompt_submitted(input, ctx)
408 .await
409 .map(HookOutput::UserPromptSubmitted)
410 .unwrap_or(HookOutput::None),
411 HookEvent::SessionStart { input, ctx } => self
412 .on_session_start(input, ctx)
413 .await
414 .map(HookOutput::SessionStart)
415 .unwrap_or(HookOutput::None),
416 HookEvent::SessionEnd { input, ctx } => self
417 .on_session_end(input, ctx)
418 .await
419 .map(HookOutput::SessionEnd)
420 .unwrap_or(HookOutput::None),
421 HookEvent::ErrorOccurred { input, ctx } => self
422 .on_error_occurred(input, ctx)
423 .await
424 .map(HookOutput::ErrorOccurred)
425 .unwrap_or(HookOutput::None),
426 }
427 }
428
429 async fn on_pre_tool_use(
432 &self,
433 _input: PreToolUseInput,
434 _ctx: HookContext,
435 ) -> Option<PreToolUseOutput> {
436 None
437 }
438
439 async fn on_pre_mcp_tool_call(
442 &self,
443 _input: PreMcpToolCallInput,
444 _ctx: HookContext,
445 ) -> Option<PreMcpToolCallOutput> {
446 None
447 }
448
449 async fn on_post_tool_use(
453 &self,
454 _input: PostToolUseInput,
455 _ctx: HookContext,
456 ) -> Option<PostToolUseOutput> {
457 None
458 }
459
460 async fn on_user_prompt_submitted(
464 &self,
465 _input: UserPromptSubmittedInput,
466 _ctx: HookContext,
467 ) -> Option<UserPromptSubmittedOutput> {
468 None
469 }
470
471 async fn on_session_start(
474 &self,
475 _input: SessionStartInput,
476 _ctx: HookContext,
477 ) -> Option<SessionStartOutput> {
478 None
479 }
480
481 async fn on_session_end(
484 &self,
485 _input: SessionEndInput,
486 _ctx: HookContext,
487 ) -> Option<SessionEndOutput> {
488 None
489 }
490
491 async fn on_error_occurred(
494 &self,
495 _input: ErrorOccurredInput,
496 _ctx: HookContext,
497 ) -> Option<ErrorOccurredOutput> {
498 None
499 }
500}
501
502pub(crate) async fn dispatch_hook(
508 hooks: &dyn SessionHooks,
509 session_id: &SessionId,
510 hook_type: &str,
511 raw_input: Value,
512) -> Result<Value, crate::Error> {
513 let ctx = HookContext {
514 session_id: session_id.clone(),
515 };
516
517 let event = match hook_type {
518 "preToolUse" => {
519 let input: PreToolUseInput = serde_json::from_value(raw_input)?;
520 HookEvent::PreToolUse { input, ctx }
521 }
522 "preMcpToolCall" => {
523 let input: PreMcpToolCallInput = serde_json::from_value(raw_input)?;
524 HookEvent::PreMcpToolCall { input, ctx }
525 }
526 "postToolUse" => {
527 let input: PostToolUseInput = serde_json::from_value(raw_input)?;
528 HookEvent::PostToolUse { input, ctx }
529 }
530 "userPromptSubmitted" => {
531 let input: UserPromptSubmittedInput = serde_json::from_value(raw_input)?;
532 HookEvent::UserPromptSubmitted { input, ctx }
533 }
534 "sessionStart" => {
535 let input: SessionStartInput = serde_json::from_value(raw_input)?;
536 HookEvent::SessionStart { input, ctx }
537 }
538 "sessionEnd" => {
539 let input: SessionEndInput = serde_json::from_value(raw_input)?;
540 HookEvent::SessionEnd { input, ctx }
541 }
542 "errorOccurred" => {
543 let input: ErrorOccurredInput = serde_json::from_value(raw_input)?;
544 HookEvent::ErrorOccurred { input, ctx }
545 }
546 _ => {
547 tracing::warn!(
548 hook_type = hook_type,
549 session_id = %session_id,
550 "unknown hook type"
551 );
552 return Ok(serde_json::json!({ "output": {} }));
553 }
554 };
555
556 let dispatch_start = Instant::now();
557 let output = hooks.on_hook(event).await;
558 tracing::debug!(
559 elapsed_ms = dispatch_start.elapsed().as_millis(),
560 session_id = %session_id,
561 hook_type = hook_type,
562 "SessionHooks::on_hook dispatch"
563 );
564
565 let output_value = match (hook_type, &output) {
570 (_, HookOutput::None) => None,
571 ("preToolUse", HookOutput::PreToolUse(o)) => Some(serde_json::to_value(o)?),
572 ("preMcpToolCall", HookOutput::PreMcpToolCall(o)) => Some(serde_json::to_value(o)?),
573 ("postToolUse", HookOutput::PostToolUse(o)) => Some(serde_json::to_value(o)?),
574 ("userPromptSubmitted", HookOutput::UserPromptSubmitted(o)) => {
575 Some(serde_json::to_value(o)?)
576 }
577 ("sessionStart", HookOutput::SessionStart(o)) => Some(serde_json::to_value(o)?),
578 ("sessionEnd", HookOutput::SessionEnd(o)) => Some(serde_json::to_value(o)?),
579 ("errorOccurred", HookOutput::ErrorOccurred(o)) => Some(serde_json::to_value(o)?),
580 _ => {
581 tracing::warn!(
582 hook_type = hook_type,
583 session_id = %session_id,
584 output_variant = output.variant_name(),
585 "hook returned mismatched output variant, treating as unregistered"
586 );
587 None
588 }
589 };
590
591 Ok(serde_json::json!({ "output": output_value.unwrap_or(Value::Object(Default::default())) }))
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    /// Hooks that override `on_hook` directly: deny `dangerous_tool`, prefix
    /// every submitted prompt, and return no output for everything else.
    struct TestHooks;

    #[async_trait]
    impl SessionHooks for TestHooks {
        async fn on_hook(&self, event: HookEvent) -> HookOutput {
            match event {
                HookEvent::PreToolUse { input, .. } => {
                    if input.tool_name == "dangerous_tool" {
                        HookOutput::PreToolUse(PreToolUseOutput {
                            permission_decision: Some("deny".to_string()),
                            permission_decision_reason: Some("blocked by policy".to_string()),
                            ..Default::default()
                        })
                    } else {
                        HookOutput::None
                    }
                }
                HookEvent::UserPromptSubmitted { input, .. } => {
                    HookOutput::UserPromptSubmitted(UserPromptSubmittedOutput {
                        modified_prompt: Some(format!("[prefixed] {}", input.prompt)),
                        ..Default::default()
                    })
                }
                _ => HookOutput::None,
            }
        }
    }

    #[tokio::test]
    async fn dispatch_pre_tool_use_deny() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "dangerous_tool",
            "toolArgs": {}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        let output = &result["output"];
        assert_eq!(output["permissionDecision"], "deny");
        assert_eq!(output["permissionDecisionReason"], "blocked by policy");
    }

    #[tokio::test]
    async fn dispatch_pre_tool_use_passthrough() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "safe_tool",
            "toolArgs": {"key": "value"}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_user_prompt_submitted() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "prompt": "hello world"
        });
        let result = dispatch_hook(
            &hooks,
            &SessionId::new("sess-1"),
            "userPromptSubmitted",
            input,
        )
        .await
        .unwrap();
        assert_eq!(result["output"]["modifiedPrompt"], "[prefixed] hello world");
    }

    #[tokio::test]
    async fn dispatch_unregistered_hook_returns_empty() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "reason": "complete"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionEnd", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_unknown_hook_type() {
        let hooks = TestHooks;
        let input = serde_json::json!({});
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "unknownHook", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_mismatched_output_returns_empty() {
        // Always answers with a SessionEnd output, whatever the event was.
        struct MismatchHooks;
        #[async_trait]
        impl SessionHooks for MismatchHooks {
            async fn on_hook(&self, _event: HookEvent) -> HookOutput {
                HookOutput::SessionEnd(SessionEndOutput {
                    session_summary: Some("oops".to_string()),
                    ..Default::default()
                })
            }
        }

        let hooks = MismatchHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_post_tool_use_default() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {},
            "toolResult": "success"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "postToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_session_start() {
        struct StartHooks;
        #[async_trait]
        impl SessionHooks for StartHooks {
            async fn on_hook(&self, event: HookEvent) -> HookOutput {
                match event {
                    HookEvent::SessionStart { .. } => {
                        HookOutput::SessionStart(SessionStartOutput {
                            additional_context: Some("extra context".to_string()),
                            ..Default::default()
                        })
                    }
                    _ => HookOutput::None,
                }
            }
        }

        let hooks = StartHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "source": "new"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionStart", input)
            .await
            .unwrap();
        assert_eq!(result["output"]["additionalContext"], "extra context");
    }

    #[tokio::test]
    async fn dispatch_error_occurred() {
        struct ErrorHooks;
        #[async_trait]
        impl SessionHooks for ErrorHooks {
            async fn on_hook(&self, event: HookEvent) -> HookOutput {
                match event {
                    HookEvent::ErrorOccurred { .. } => {
                        HookOutput::ErrorOccurred(ErrorOccurredOutput {
                            error_handling: Some("retry".to_string()),
                            retry_count: Some(3),
                            ..Default::default()
                        })
                    }
                    _ => HookOutput::None,
                }
            }
        }

        let hooks = ErrorHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "error": "model timeout",
            "errorContext": "model_call",
            "recoverable": true
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "errorOccurred", input)
            .await
            .unwrap();
        assert_eq!(result["output"]["errorHandling"], "retry");
        assert_eq!(result["output"]["retryCount"], 3);
    }

    // The tests above all override `on_hook`, so they never run the trait's
    // default router. The next test leaves `on_hook` alone and overrides a
    // single per-event method instead.
    #[tokio::test]
    async fn dispatch_default_on_hook_routes_to_per_event_method() {
        struct PerEventHooks;
        #[async_trait]
        impl SessionHooks for PerEventHooks {
            async fn on_post_tool_use(
                &self,
                input: PostToolUseInput,
                _ctx: HookContext,
            ) -> Option<PostToolUseOutput> {
                Some(PostToolUseOutput {
                    additional_context: Some(format!("ran {}", input.tool_name)),
                    ..Default::default()
                })
            }
        }

        let hooks = PerEventHooks;

        // The overridden per-event method is reached through the default `on_hook`.
        let post_input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {},
            "toolResult": "success"
        });
        let result = dispatch_hook(
            &hooks,
            &SessionId::new("sess-1"),
            "postToolUse",
            post_input,
        )
        .await
        .unwrap();
        assert_eq!(
            result["output"],
            serde_json::json!({ "additionalContext": "ran some_tool" })
        );

        // A per-event method left at its default yields an empty output.
        let end_input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "reason": "complete"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionEnd", end_input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_pre_mcp_tool_call_echoes_meta() {
        // Echoes the incoming `_meta` back as `metaToUse`.
        struct McpHooks;
        #[async_trait]
        impl SessionHooks for McpHooks {
            async fn on_pre_mcp_tool_call(
                &self,
                input: PreMcpToolCallInput,
                _ctx: HookContext,
            ) -> Option<PreMcpToolCallOutput> {
                Some(PreMcpToolCallOutput {
                    meta_to_use: input.meta,
                })
            }
        }

        let hooks = McpHooks;
        // `toolCallId` is deliberately omitted to exercise its serde default.
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "serverName": "srv",
            "toolName": "some_tool",
            "arguments": {},
            "_meta": {"trace": "abc"}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preMcpToolCall", input)
            .await
            .unwrap();
        assert_eq!(
            result["output"]["metaToUse"],
            serde_json::json!({"trace": "abc"})
        );
    }

    #[tokio::test]
    async fn dispatch_invalid_payload_is_error() {
        let hooks = TestHooks;
        // Required fields are missing, so decoding the input must fail.
        let input = serde_json::json!({});
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input).await;
        assert!(result.is_err());
    }
}