1use std::path::PathBuf;
9use std::time::Instant;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14
15use crate::types::SessionId;
16
/// Per-invocation context passed to every hook alongside its typed input.
///
/// `dispatch_hook` builds one of these for each dispatch from the session id it is given.
#[derive(Debug, Clone)]
pub struct HookContext {
    /// Identifier of the session the hook is being dispatched for.
    pub session_id: SessionId,
}
23
/// Input payload for the `preToolUse` hook.
///
/// Deserialized from JSON whose keys are camelCase (`sessionId`, `toolName`, ...), except
/// the working directory, which is read from `cwd`. All fields are required.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PreToolUseInput {
    /// Session id as carried in the payload (`sessionId`).
    pub session_id: String,
    /// Timestamp as sent by the caller (`timestamp`).
    /// NOTE(review): units (ms vs s) are not established here — confirm with the sender.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// Name of the tool this hook fires for (`toolName`).
    pub tool_name: String,
    /// Tool arguments as arbitrary JSON (`toolArgs`).
    pub tool_args: Value,
}
40
/// Reply a `preToolUse` hook may return.
///
/// Serialized with camelCase keys; every field is optional and is omitted from the JSON
/// entirely when `None`, so `PreToolUseOutput::default()` serializes to `{}`.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PreToolUseOutput {
    /// Permission decision (`permissionDecision`), a free-form string such as `"deny"`.
    /// NOTE(review): the set of accepted values is defined by the consumer, not here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub permission_decision: Option<String>,
    /// Human-readable reason accompanying the decision (`permissionDecisionReason`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub permission_decision_reason: Option<String>,
    /// Replacement tool arguments (`modifiedArgs`).
    /// NOTE(review): presumably used in place of the incoming `toolArgs` — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_args: Option<Value>,
    /// Extra context text to hand back (`additionalContext`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_context: Option<String>,
    /// Output-suppression flag (`suppressOutput`); its effect is decided by the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suppress_output: Option<bool>,
}
61
/// Input payload for the `preMcpToolCall` hook.
///
/// Keys are camelCase except `cwd` and `_meta`, which are renamed explicitly.
/// `toolCallId` and `_meta` may be absent, in which case the fields are `None`.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PreMcpToolCallInput {
    /// Session id as carried in the payload (`sessionId`).
    pub session_id: String,
    /// Timestamp as sent by the caller (`timestamp`).
    /// NOTE(review): units are not established here — confirm with the sender.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// MCP server name (`serverName`).
    pub server_name: String,
    /// Tool name (`toolName`).
    pub tool_name: String,
    /// Call arguments as arbitrary JSON (`arguments`).
    pub arguments: Value,
    /// Optional call identifier (`toolCallId`).
    #[serde(default)]
    pub tool_call_id: Option<String>,
    /// Optional metadata read from the literal `_meta` key; the explicit `rename`
    /// overrides the struct-level camelCase rule.
    #[serde(default, rename = "_meta")]
    pub meta: Option<Value>,
}
86
/// Reply a `preMcpToolCall` hook may return.
///
/// Serialized with camelCase keys; `metaToUse` is omitted when `None`, so the default
/// value serializes to `{}`.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PreMcpToolCallOutput {
    /// Metadata the hook wants used for the call (`metaToUse`).
    /// NOTE(review): presumably replaces the incoming `_meta` — confirm with the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub meta_to_use: Option<Value>,
}
100
/// Input payload for the `postToolUse` hook.
///
/// Keys are camelCase except the working directory, which is read from `cwd`.
/// All fields are required.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostToolUseInput {
    /// Session id as carried in the payload (`sessionId`).
    pub session_id: String,
    /// Timestamp as sent by the caller (`timestamp`).
    /// NOTE(review): units are not established here — confirm with the sender.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// Name of the tool (`toolName`).
    pub tool_name: String,
    /// Arguments the tool was called with, as arbitrary JSON (`toolArgs`).
    pub tool_args: Value,
    /// Tool result as arbitrary JSON (`toolResult`); any JSON value is accepted.
    pub tool_result: Value,
}
119
/// Reply a `postToolUse` hook may return.
///
/// Serialized with camelCase keys; each field is omitted when `None`.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PostToolUseOutput {
    /// Replacement tool result (`modifiedResult`).
    /// NOTE(review): presumably substituted for the incoming `toolResult` — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_result: Option<Value>,
    /// Extra context text to hand back (`additionalContext`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_context: Option<String>,
    /// Output-suppression flag (`suppressOutput`); its effect is decided by the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suppress_output: Option<bool>,
}
134
/// Input payload for the `postToolUseFailure` hook.
///
/// Keys are camelCase except the working directory, which is read from `cwd`.
/// All fields are required; in particular a payload without `error` fails to deserialize.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostToolUseFailureInput {
    /// Session id as carried in the payload (`sessionId`).
    pub session_id: String,
    /// Timestamp as sent by the caller (`timestamp`).
    /// NOTE(review): units are not established here — confirm with the sender.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// Name of the tool that failed (`toolName`).
    pub tool_name: String,
    /// Arguments the tool was called with, as arbitrary JSON (`toolArgs`).
    pub tool_args: Value,
    /// Failure message as a plain string (`error`).
    pub error: String,
}
159
/// Reply a `postToolUseFailure` hook may return.
///
/// Serialized with camelCase keys; `additionalContext` is omitted when `None`.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PostToolUseFailureOutput {
    /// Extra context text to hand back (`additionalContext`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_context: Option<String>,
}
171
/// Input payload for the `userPromptSubmitted` hook.
///
/// Keys are camelCase except the working directory, which is read from `cwd`.
/// All fields are required.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UserPromptSubmittedInput {
    /// Session id as carried in the payload (`sessionId`).
    pub session_id: String,
    /// Timestamp as sent by the caller (`timestamp`).
    /// NOTE(review): units are not established here — confirm with the sender.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// The submitted prompt text (`prompt`).
    pub prompt: String,
}
186
/// Reply a `userPromptSubmitted` hook may return.
///
/// Serialized with camelCase keys; each field is omitted when `None`.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct UserPromptSubmittedOutput {
    /// Replacement prompt text (`modifiedPrompt`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_prompt: Option<String>,
    /// Extra context text to hand back (`additionalContext`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_context: Option<String>,
    /// Output-suppression flag (`suppressOutput`); its effect is decided by the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suppress_output: Option<bool>,
}
201
/// Input payload for the `sessionStart` hook.
///
/// Keys are camelCase except the working directory, which is read from `cwd`.
/// `initialPrompt` may be absent.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionStartInput {
    /// Session id as carried in the payload (`sessionId`).
    pub session_id: String,
    /// Timestamp as sent by the caller (`timestamp`).
    /// NOTE(review): units are not established here — confirm with the sender.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// Start-source label (`source`) as a free-form string; accepted values are defined
    /// by the sender.
    pub source: String,
    /// Optional first prompt (`initialPrompt`); `None` when absent.
    #[serde(default)]
    pub initial_prompt: Option<String>,
}
219
/// Reply a `sessionStart` hook may return.
///
/// Serialized with camelCase keys; each field is omitted when `None`.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionStartOutput {
    /// Extra context text to hand back (`additionalContext`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub additional_context: Option<String>,
    /// Configuration override as arbitrary JSON (`modifiedConfig`).
    /// NOTE(review): the expected shape is defined by the consumer — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_config: Option<Value>,
}
231
/// Input payload for the `sessionEnd` hook.
///
/// Keys are camelCase except the working directory, which is read from `cwd`.
/// `finalMessage` and `error` may be absent.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionEndInput {
    /// Session id as carried in the payload (`sessionId`).
    pub session_id: String,
    /// Timestamp as sent by the caller (`timestamp`).
    /// NOTE(review): units are not established here — confirm with the sender.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// End-reason label (`reason`) as a free-form string.
    pub reason: String,
    /// Optional final message (`finalMessage`); `None` when absent.
    #[serde(default)]
    pub final_message: Option<String>,
    /// Optional error message (`error`); `None` when absent.
    #[serde(default)]
    pub error: Option<String>,
}
252
/// Reply a `sessionEnd` hook may return.
///
/// Serialized with camelCase keys; each field is omitted when `None`.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionEndOutput {
    /// Output-suppression flag (`suppressOutput`); its effect is decided by the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suppress_output: Option<bool>,
    /// List of cleanup action strings (`cleanupActions`), serialized as a JSON array.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cleanup_actions: Option<Vec<String>>,
    /// Summary text for the session (`sessionSummary`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_summary: Option<String>,
}
267
/// Input payload for the `errorOccurred` hook.
///
/// Keys are camelCase except the working directory, which is read from `cwd`.
/// All fields are required.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ErrorOccurredInput {
    /// Session id as carried in the payload (`sessionId`).
    pub session_id: String,
    /// Timestamp as sent by the caller (`timestamp`).
    /// NOTE(review): units are not established here — confirm with the sender.
    pub timestamp: i64,
    /// Working directory, read from the `cwd` key.
    #[serde(rename = "cwd")]
    pub working_directory: PathBuf,
    /// Error message (`error`).
    pub error: String,
    /// Free-form label for where the error arose (`errorContext`).
    pub error_context: String,
    /// Sender's flag for whether the error is recoverable (`recoverable`).
    pub recoverable: bool,
}
286
/// Reply an `errorOccurred` hook may return.
///
/// Serialized with camelCase keys; each field is omitted when `None`.
#[derive(Debug, Clone, Default, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ErrorOccurredOutput {
    /// Output-suppression flag (`suppressOutput`); its effect is decided by the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suppress_output: Option<bool>,
    /// Requested handling strategy as a free-form string (`errorHandling`), e.g. `"retry"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_handling: Option<String>,
    /// Retry count (`retryCount`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry_count: Option<u32>,
    /// Message intended for the user (`userNotification`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_notification: Option<String>,
}
304
/// One hook invocation: the typed input for a single hook type plus its [`HookContext`].
///
/// `dispatch_hook` builds exactly one variant per wire-level hook name (`preToolUse`,
/// `preMcpToolCall`, `postToolUse`, `postToolUseFailure`, `userPromptSubmitted`,
/// `sessionStart`, `sessionEnd`, `errorOccurred`).
///
/// Marked `#[non_exhaustive]`: code outside this crate that matches on it needs a
/// wildcard arm.
#[non_exhaustive]
#[derive(Debug)]
pub enum HookEvent {
    /// `preToolUse` hook.
    PreToolUse {
        /// Decoded `preToolUse` payload.
        input: PreToolUseInput,
        /// Dispatch context for this invocation.
        ctx: HookContext,
    },
    /// `preMcpToolCall` hook.
    PreMcpToolCall {
        /// Decoded `preMcpToolCall` payload.
        input: PreMcpToolCallInput,
        /// Dispatch context for this invocation.
        ctx: HookContext,
    },
    /// `postToolUse` hook.
    PostToolUse {
        /// Decoded `postToolUse` payload.
        input: PostToolUseInput,
        /// Dispatch context for this invocation.
        ctx: HookContext,
    },
    /// `postToolUseFailure` hook.
    PostToolUseFailure {
        /// Decoded `postToolUseFailure` payload.
        input: PostToolUseFailureInput,
        /// Dispatch context for this invocation.
        ctx: HookContext,
    },
    /// `userPromptSubmitted` hook.
    UserPromptSubmitted {
        /// Decoded `userPromptSubmitted` payload.
        input: UserPromptSubmittedInput,
        /// Dispatch context for this invocation.
        ctx: HookContext,
    },
    /// `sessionStart` hook.
    SessionStart {
        /// Decoded `sessionStart` payload.
        input: SessionStartInput,
        /// Dispatch context for this invocation.
        ctx: HookContext,
    },
    /// `sessionEnd` hook.
    SessionEnd {
        /// Decoded `sessionEnd` payload.
        input: SessionEndInput,
        /// Dispatch context for this invocation.
        ctx: HookContext,
    },
    /// `errorOccurred` hook.
    ErrorOccurred {
        /// Decoded `errorOccurred` payload.
        input: ErrorOccurredInput,
        /// Dispatch context for this invocation.
        ctx: HookContext,
    },
}
372
/// Reply produced by [`SessionHooks::on_hook`].
///
/// `None` means "no output". Every other variant carries the output struct for the hook of
/// the same name. `dispatch_hook` serializes a non-`None` variant only when it matches the
/// hook type being dispatched; any other variant is logged as mismatched and replaced by
/// an empty object.
///
/// Marked `#[non_exhaustive]`: code outside this crate that matches on it needs a
/// wildcard arm.
#[non_exhaustive]
#[derive(Debug)]
pub enum HookOutput {
    /// No output; serialized by `dispatch_hook` as `{}`.
    None,
    /// Output for `preToolUse`.
    PreToolUse(PreToolUseOutput),
    /// Output for `preMcpToolCall`.
    PreMcpToolCall(PreMcpToolCallOutput),
    /// Output for `postToolUse`.
    PostToolUse(PostToolUseOutput),
    /// Output for `postToolUseFailure`.
    PostToolUseFailure(PostToolUseFailureOutput),
    /// Output for `userPromptSubmitted`.
    UserPromptSubmitted(UserPromptSubmittedOutput),
    /// Output for `sessionStart`.
    SessionStart(SessionStartOutput),
    /// Output for `sessionEnd`.
    SessionEnd(SessionEndOutput),
    /// Output for `errorOccurred`.
    ErrorOccurred(ErrorOccurredOutput),
}
399
400impl HookOutput {
401 fn variant_name(&self) -> &'static str {
402 match self {
403 Self::None => "None",
404 Self::PreToolUse(_) => "PreToolUse",
405 Self::PreMcpToolCall(_) => "PreMcpToolCall",
406 Self::PostToolUse(_) => "PostToolUse",
407 Self::PostToolUseFailure(_) => "PostToolUseFailure",
408 Self::UserPromptSubmitted(_) => "UserPromptSubmitted",
409 Self::SessionStart(_) => "SessionStart",
410 Self::SessionEnd(_) => "SessionEnd",
411 Self::ErrorOccurred(_) => "ErrorOccurred",
412 }
413 }
414}
415
416#[async_trait]
434pub trait SessionHooks: Send + Sync + 'static {
435 async fn on_hook(&self, event: HookEvent) -> HookOutput {
439 match event {
440 HookEvent::PreToolUse { input, ctx } => self
441 .on_pre_tool_use(input, ctx)
442 .await
443 .map(HookOutput::PreToolUse)
444 .unwrap_or(HookOutput::None),
445 HookEvent::PreMcpToolCall { input, ctx } => self
446 .on_pre_mcp_tool_call(input, ctx)
447 .await
448 .map(HookOutput::PreMcpToolCall)
449 .unwrap_or(HookOutput::None),
450 HookEvent::PostToolUse { input, ctx } => self
451 .on_post_tool_use(input, ctx)
452 .await
453 .map(HookOutput::PostToolUse)
454 .unwrap_or(HookOutput::None),
455 HookEvent::PostToolUseFailure { input, ctx } => self
456 .on_post_tool_use_failure(input, ctx)
457 .await
458 .map(HookOutput::PostToolUseFailure)
459 .unwrap_or(HookOutput::None),
460 HookEvent::UserPromptSubmitted { input, ctx } => self
461 .on_user_prompt_submitted(input, ctx)
462 .await
463 .map(HookOutput::UserPromptSubmitted)
464 .unwrap_or(HookOutput::None),
465 HookEvent::SessionStart { input, ctx } => self
466 .on_session_start(input, ctx)
467 .await
468 .map(HookOutput::SessionStart)
469 .unwrap_or(HookOutput::None),
470 HookEvent::SessionEnd { input, ctx } => self
471 .on_session_end(input, ctx)
472 .await
473 .map(HookOutput::SessionEnd)
474 .unwrap_or(HookOutput::None),
475 HookEvent::ErrorOccurred { input, ctx } => self
476 .on_error_occurred(input, ctx)
477 .await
478 .map(HookOutput::ErrorOccurred)
479 .unwrap_or(HookOutput::None),
480 }
481 }
482
483 async fn on_pre_tool_use(
486 &self,
487 _input: PreToolUseInput,
488 _ctx: HookContext,
489 ) -> Option<PreToolUseOutput> {
490 None
491 }
492
493 async fn on_pre_mcp_tool_call(
496 &self,
497 _input: PreMcpToolCallInput,
498 _ctx: HookContext,
499 ) -> Option<PreMcpToolCallOutput> {
500 None
501 }
502
503 async fn on_post_tool_use(
507 &self,
508 _input: PostToolUseInput,
509 _ctx: HookContext,
510 ) -> Option<PostToolUseOutput> {
511 None
512 }
513
514 async fn on_post_tool_use_failure(
519 &self,
520 _input: PostToolUseFailureInput,
521 _ctx: HookContext,
522 ) -> Option<PostToolUseFailureOutput> {
523 None
524 }
525
526 async fn on_user_prompt_submitted(
530 &self,
531 _input: UserPromptSubmittedInput,
532 _ctx: HookContext,
533 ) -> Option<UserPromptSubmittedOutput> {
534 None
535 }
536
537 async fn on_session_start(
540 &self,
541 _input: SessionStartInput,
542 _ctx: HookContext,
543 ) -> Option<SessionStartOutput> {
544 None
545 }
546
547 async fn on_session_end(
550 &self,
551 _input: SessionEndInput,
552 _ctx: HookContext,
553 ) -> Option<SessionEndOutput> {
554 None
555 }
556
557 async fn on_error_occurred(
560 &self,
561 _input: ErrorOccurredInput,
562 _ctx: HookContext,
563 ) -> Option<ErrorOccurredOutput> {
564 None
565 }
566}
567
568pub(crate) async fn dispatch_hook(
574 hooks: &dyn SessionHooks,
575 session_id: &SessionId,
576 hook_type: &str,
577 raw_input: Value,
578) -> Result<Value, crate::Error> {
579 let ctx = HookContext {
580 session_id: session_id.clone(),
581 };
582
583 let event = match hook_type {
584 "preToolUse" => {
585 let input: PreToolUseInput = serde_json::from_value(raw_input)?;
586 HookEvent::PreToolUse { input, ctx }
587 }
588 "preMcpToolCall" => {
589 let input: PreMcpToolCallInput = serde_json::from_value(raw_input)?;
590 HookEvent::PreMcpToolCall { input, ctx }
591 }
592 "postToolUse" => {
593 let input: PostToolUseInput = serde_json::from_value(raw_input)?;
594 HookEvent::PostToolUse { input, ctx }
595 }
596 "postToolUseFailure" => {
597 let input: PostToolUseFailureInput = serde_json::from_value(raw_input)?;
598 HookEvent::PostToolUseFailure { input, ctx }
599 }
600 "userPromptSubmitted" => {
601 let input: UserPromptSubmittedInput = serde_json::from_value(raw_input)?;
602 HookEvent::UserPromptSubmitted { input, ctx }
603 }
604 "sessionStart" => {
605 let input: SessionStartInput = serde_json::from_value(raw_input)?;
606 HookEvent::SessionStart { input, ctx }
607 }
608 "sessionEnd" => {
609 let input: SessionEndInput = serde_json::from_value(raw_input)?;
610 HookEvent::SessionEnd { input, ctx }
611 }
612 "errorOccurred" => {
613 let input: ErrorOccurredInput = serde_json::from_value(raw_input)?;
614 HookEvent::ErrorOccurred { input, ctx }
615 }
616 _ => {
617 tracing::warn!(
618 hook_type = hook_type,
619 session_id = %session_id,
620 "unknown hook type"
621 );
622 return Ok(serde_json::json!({ "output": {} }));
623 }
624 };
625
626 let dispatch_start = Instant::now();
627 let output = hooks.on_hook(event).await;
628 tracing::debug!(
629 elapsed_ms = dispatch_start.elapsed().as_millis(),
630 session_id = %session_id,
631 hook_type = hook_type,
632 "SessionHooks::on_hook dispatch"
633 );
634
635 let output_value = match (hook_type, &output) {
640 (_, HookOutput::None) => None,
641 ("preToolUse", HookOutput::PreToolUse(o)) => Some(serde_json::to_value(o)?),
642 ("preMcpToolCall", HookOutput::PreMcpToolCall(o)) => Some(serde_json::to_value(o)?),
643 ("postToolUse", HookOutput::PostToolUse(o)) => Some(serde_json::to_value(o)?),
644 ("postToolUseFailure", HookOutput::PostToolUseFailure(o)) => Some(serde_json::to_value(o)?),
645 ("userPromptSubmitted", HookOutput::UserPromptSubmitted(o)) => {
646 Some(serde_json::to_value(o)?)
647 }
648 ("sessionStart", HookOutput::SessionStart(o)) => Some(serde_json::to_value(o)?),
649 ("sessionEnd", HookOutput::SessionEnd(o)) => Some(serde_json::to_value(o)?),
650 ("errorOccurred", HookOutput::ErrorOccurred(o)) => Some(serde_json::to_value(o)?),
651 _ => {
652 tracing::warn!(
653 hook_type = hook_type,
654 session_id = %session_id,
655 output_variant = output.variant_name(),
656 "hook returned mismatched output variant, treating as unregistered"
657 );
658 None
659 }
660 };
661
662 Ok(serde_json::json!({ "output": output_value.unwrap_or(Value::Object(Default::default())) }))
663}
664
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that overrides `on_hook` directly. It denies `dangerous_tool` on
    /// `preToolUse`, prefixes the prompt on `userPromptSubmitted`, and returns
    /// `HookOutput::None` for everything else.
    struct TestHooks;

    #[async_trait]
    impl SessionHooks for TestHooks {
        async fn on_hook(&self, event: HookEvent) -> HookOutput {
            match event {
                HookEvent::PreToolUse { input, .. } => {
                    if input.tool_name == "dangerous_tool" {
                        HookOutput::PreToolUse(PreToolUseOutput {
                            permission_decision: Some("deny".to_string()),
                            permission_decision_reason: Some("blocked by policy".to_string()),
                            ..Default::default()
                        })
                    } else {
                        HookOutput::None
                    }
                }
                HookEvent::UserPromptSubmitted { input, .. } => {
                    HookOutput::UserPromptSubmitted(UserPromptSubmittedOutput {
                        modified_prompt: Some(format!("[prefixed] {}", input.prompt)),
                        ..Default::default()
                    })
                }
                _ => HookOutput::None,
            }
        }
    }

    /// Test double that keeps the default `on_hook` router and echoes the incoming
    /// `_meta` back as `metaToUse` on `preMcpToolCall`.
    struct MetaEchoHooks;

    #[async_trait]
    impl SessionHooks for MetaEchoHooks {
        async fn on_pre_mcp_tool_call(
            &self,
            input: PreMcpToolCallInput,
            _ctx: HookContext,
        ) -> Option<PreMcpToolCallOutput> {
            Some(PreMcpToolCallOutput {
                meta_to_use: input.meta,
            })
        }
    }

    #[tokio::test]
    async fn dispatch_pre_tool_use_deny() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "dangerous_tool",
            "toolArgs": {}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        let output = &result["output"];
        assert_eq!(output["permissionDecision"], "deny");
        assert_eq!(output["permissionDecisionReason"], "blocked by policy");
    }

    #[tokio::test]
    async fn dispatch_pre_tool_use_passthrough() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "safe_tool",
            "toolArgs": {"key": "value"}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        // HookOutput::None must be rendered as an empty object, not null.
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_user_prompt_submitted() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "prompt": "hello world"
        });
        let result = dispatch_hook(
            &hooks,
            &SessionId::new("sess-1"),
            "userPromptSubmitted",
            input,
        )
        .await
        .unwrap();
        assert_eq!(result["output"]["modifiedPrompt"], "[prefixed] hello world");
    }

    #[tokio::test]
    async fn dispatch_unregistered_hook_returns_empty() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "reason": "complete"
        });
        // TestHooks has no sessionEnd handling, so the reply is empty.
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionEnd", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_unknown_hook_type() {
        let hooks = TestHooks;
        let input = serde_json::json!({});
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "unknownHook", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_mismatched_output_returns_empty() {
        struct MismatchHooks;
        #[async_trait]
        impl SessionHooks for MismatchHooks {
            async fn on_hook(&self, _event: HookEvent) -> HookOutput {
                // Always answers with a sessionEnd output, whatever the hook type.
                HookOutput::SessionEnd(SessionEndOutput {
                    session_summary: Some("oops".to_string()),
                    ..Default::default()
                })
            }
        }

        let hooks = MismatchHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {}
        });
        // A sessionEnd output for a preToolUse dispatch must be dropped.
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_post_tool_use_default() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {},
            "toolResult": "success"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "postToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_post_tool_use_failure_default() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {"key": "value"},
            "error": "boom"
        });
        let result = dispatch_hook(
            &hooks,
            &SessionId::new("sess-1"),
            "postToolUseFailure",
            input,
        )
        .await
        .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_post_tool_use_failure_returns_additional_context() {
        struct FailureHooks;
        #[async_trait]
        impl SessionHooks for FailureHooks {
            async fn on_post_tool_use_failure(
                &self,
                input: PostToolUseFailureInput,
                _ctx: HookContext,
            ) -> Option<PostToolUseFailureOutput> {
                assert_eq!(input.session_id, "sess-1");
                assert_eq!(input.tool_name, "some_tool");
                assert_eq!(input.error, "boom");
                assert_eq!(input.working_directory, PathBuf::from("/tmp"));
                Some(PostToolUseFailureOutput {
                    additional_context: Some(format!(
                        "tool {} failed: {}",
                        input.tool_name, input.error
                    )),
                })
            }
        }

        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {},
            "error": "boom"
        });
        let result = dispatch_hook(
            &FailureHooks,
            &SessionId::new("sess-1"),
            "postToolUseFailure",
            input,
        )
        .await
        .unwrap();
        assert_eq!(
            result["output"]["additionalContext"],
            "tool some_tool failed: boom"
        );
    }

    #[tokio::test]
    async fn dispatch_post_tool_use_failure_invalid_input_errors() {
        let hooks = TestHooks;
        // The required `error` field is missing, so deserialization must fail.
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {}
        });
        let err = dispatch_hook(
            &hooks,
            &SessionId::new("sess-1"),
            "postToolUseFailure",
            input,
        )
        .await
        .unwrap_err();
        let msg = err.to_string().to_ascii_lowercase();
        assert!(
            msg.contains("error") || msg.contains("missing field"),
            "unexpected error: {msg}"
        );
    }

    #[tokio::test]
    async fn dispatch_session_start() {
        struct StartHooks;
        #[async_trait]
        impl SessionHooks for StartHooks {
            async fn on_hook(&self, event: HookEvent) -> HookOutput {
                match event {
                    HookEvent::SessionStart { .. } => {
                        HookOutput::SessionStart(SessionStartOutput {
                            additional_context: Some("extra context".to_string()),
                            ..Default::default()
                        })
                    }
                    _ => HookOutput::None,
                }
            }
        }

        let hooks = StartHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "source": "new"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionStart", input)
            .await
            .unwrap();
        assert_eq!(result["output"]["additionalContext"], "extra context");
    }

    #[tokio::test]
    async fn dispatch_error_occurred() {
        struct ErrorHooks;
        #[async_trait]
        impl SessionHooks for ErrorHooks {
            async fn on_hook(&self, event: HookEvent) -> HookOutput {
                match event {
                    HookEvent::ErrorOccurred { .. } => {
                        HookOutput::ErrorOccurred(ErrorOccurredOutput {
                            error_handling: Some("retry".to_string()),
                            retry_count: Some(3),
                            ..Default::default()
                        })
                    }
                    _ => HookOutput::None,
                }
            }
        }

        let hooks = ErrorHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "error": "model timeout",
            "errorContext": "model_call",
            "recoverable": true
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "errorOccurred", input)
            .await
            .unwrap();
        assert_eq!(result["output"]["errorHandling"], "retry");
        assert_eq!(result["output"]["retryCount"], 3);
    }

    /// `_meta` is read from the literal underscore key (not camelCased), and
    /// `toolCallId` from its camelCase key.
    #[test]
    fn pre_mcp_tool_call_input_reads_underscore_meta() {
        let input: PreMcpToolCallInput = serde_json::from_value(serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "serverName": "search",
            "toolName": "lookup",
            "arguments": {"q": "rust"},
            "toolCallId": "call-7",
            "_meta": {"trace": "abc"}
        }))
        .unwrap();
        assert_eq!(input.server_name, "search");
        assert_eq!(input.tool_name, "lookup");
        assert_eq!(input.tool_call_id.as_deref(), Some("call-7"));
        assert_eq!(input.meta, Some(serde_json::json!({"trace": "abc"})));
        assert_eq!(input.working_directory, PathBuf::from("/tmp"));
    }

    /// The default `on_hook` routes `preMcpToolCall` to `on_pre_mcp_tool_call`, and the
    /// reply is serialized under the camelCase `metaToUse` key.
    #[tokio::test]
    async fn dispatch_pre_mcp_tool_call_echoes_meta() {
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "serverName": "search",
            "toolName": "lookup",
            "arguments": {"q": "rust"},
            "toolCallId": "call-7",
            "_meta": {"trace": "abc"}
        });
        let result = dispatch_hook(
            &MetaEchoHooks,
            &SessionId::new("sess-1"),
            "preMcpToolCall",
            input,
        )
        .await
        .unwrap();
        assert_eq!(
            result["output"],
            serde_json::json!({ "metaToUse": {"trace": "abc"} })
        );
    }

    /// `toolCallId` and `_meta` are optional. With `_meta` absent the echoed
    /// `metaToUse` is `None` and is skipped, leaving an empty object.
    #[tokio::test]
    async fn dispatch_pre_mcp_tool_call_without_optional_fields() {
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "serverName": "search",
            "toolName": "lookup",
            "arguments": {}
        });
        let result = dispatch_hook(
            &MetaEchoHooks,
            &SessionId::new("sess-1"),
            "preMcpToolCall",
            input,
        )
        .await
        .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    /// A per-event override (no custom `on_hook`) is reached through the default
    /// router. Only the populated output field is serialized.
    #[tokio::test]
    async fn dispatch_session_end_via_per_event_method() {
        struct EndHooks;
        #[async_trait]
        impl SessionHooks for EndHooks {
            async fn on_session_end(
                &self,
                input: SessionEndInput,
                _ctx: HookContext,
            ) -> Option<SessionEndOutput> {
                assert_eq!(input.final_message.as_deref(), Some("bye"));
                assert!(input.error.is_none());
                Some(SessionEndOutput {
                    cleanup_actions: Some(vec![format!("flush:{}", input.reason)]),
                    ..Default::default()
                })
            }
        }

        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "reason": "complete",
            "finalMessage": "bye"
        });
        let result = dispatch_hook(&EndHooks, &SessionId::new("sess-1"), "sessionEnd", input)
            .await
            .unwrap();
        assert_eq!(
            result["output"],
            serde_json::json!({ "cleanupActions": ["flush:complete"] })
        );
    }
}