1use std::path::PathBuf;
9use std::time::Instant;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14
15use crate::types::SessionId;
16
/// Context passed to every hook alongside its typed input.
///
/// `dispatch_hook` builds one per invocation from the session id the hook
/// was dispatched under.
#[derive(Debug, Clone)]
pub struct HookContext {
    /// Identifier of the session the hook was dispatched for.
    pub session_id: SessionId,
}
23
24#[derive(Debug, Clone, Deserialize)]
26#[serde(rename_all = "camelCase")]
27pub struct PreToolUseInput {
28 pub session_id: String,
30 pub timestamp: i64,
32 pub cwd: PathBuf,
34 pub tool_name: String,
36 pub tool_args: Value,
38}
39
40#[derive(Debug, Clone, Default, Serialize)]
42#[serde(rename_all = "camelCase")]
43pub struct PreToolUseOutput {
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub permission_decision: Option<String>,
47 #[serde(skip_serializing_if = "Option::is_none")]
49 pub permission_decision_reason: Option<String>,
50 #[serde(skip_serializing_if = "Option::is_none")]
52 pub modified_args: Option<Value>,
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub additional_context: Option<String>,
56 #[serde(skip_serializing_if = "Option::is_none")]
58 pub suppress_output: Option<bool>,
59}
60
61#[derive(Debug, Clone, Deserialize)]
63#[serde(rename_all = "camelCase")]
64pub struct PostToolUseInput {
65 pub session_id: String,
67 pub timestamp: i64,
69 pub cwd: PathBuf,
71 pub tool_name: String,
73 pub tool_args: Value,
75 pub tool_result: Value,
77}
78
79#[derive(Debug, Clone, Default, Serialize)]
81#[serde(rename_all = "camelCase")]
82pub struct PostToolUseOutput {
83 #[serde(skip_serializing_if = "Option::is_none")]
85 pub modified_result: Option<Value>,
86 #[serde(skip_serializing_if = "Option::is_none")]
88 pub additional_context: Option<String>,
89 #[serde(skip_serializing_if = "Option::is_none")]
91 pub suppress_output: Option<bool>,
92}
93
94#[derive(Debug, Clone, Deserialize)]
96#[serde(rename_all = "camelCase")]
97pub struct UserPromptSubmittedInput {
98 pub session_id: String,
100 pub timestamp: i64,
102 pub cwd: PathBuf,
104 pub prompt: String,
106}
107
108#[derive(Debug, Clone, Default, Serialize)]
110#[serde(rename_all = "camelCase")]
111pub struct UserPromptSubmittedOutput {
112 #[serde(skip_serializing_if = "Option::is_none")]
114 pub modified_prompt: Option<String>,
115 #[serde(skip_serializing_if = "Option::is_none")]
117 pub additional_context: Option<String>,
118 #[serde(skip_serializing_if = "Option::is_none")]
120 pub suppress_output: Option<bool>,
121}
122
123#[derive(Debug, Clone, Deserialize)]
125#[serde(rename_all = "camelCase")]
126pub struct SessionStartInput {
127 pub session_id: String,
129 pub timestamp: i64,
131 pub cwd: PathBuf,
133 pub source: String,
135 #[serde(default)]
137 pub initial_prompt: Option<String>,
138}
139
140#[derive(Debug, Clone, Default, Serialize)]
142#[serde(rename_all = "camelCase")]
143pub struct SessionStartOutput {
144 #[serde(skip_serializing_if = "Option::is_none")]
146 pub additional_context: Option<String>,
147 #[serde(skip_serializing_if = "Option::is_none")]
149 pub modified_config: Option<Value>,
150}
151
152#[derive(Debug, Clone, Deserialize)]
154#[serde(rename_all = "camelCase")]
155pub struct SessionEndInput {
156 pub session_id: String,
158 pub timestamp: i64,
160 pub cwd: PathBuf,
162 pub reason: String,
164 #[serde(default)]
166 pub final_message: Option<String>,
167 #[serde(default)]
169 pub error: Option<String>,
170}
171
172#[derive(Debug, Clone, Default, Serialize)]
174#[serde(rename_all = "camelCase")]
175pub struct SessionEndOutput {
176 #[serde(skip_serializing_if = "Option::is_none")]
178 pub suppress_output: Option<bool>,
179 #[serde(skip_serializing_if = "Option::is_none")]
181 pub cleanup_actions: Option<Vec<String>>,
182 #[serde(skip_serializing_if = "Option::is_none")]
184 pub session_summary: Option<String>,
185}
186
187#[derive(Debug, Clone, Deserialize)]
189#[serde(rename_all = "camelCase")]
190pub struct ErrorOccurredInput {
191 pub session_id: String,
193 pub timestamp: i64,
195 pub cwd: PathBuf,
197 pub error: String,
199 pub error_context: String,
201 pub recoverable: bool,
203}
204
205#[derive(Debug, Clone, Default, Serialize)]
207#[serde(rename_all = "camelCase")]
208pub struct ErrorOccurredOutput {
209 #[serde(skip_serializing_if = "Option::is_none")]
211 pub suppress_output: Option<bool>,
212 #[serde(skip_serializing_if = "Option::is_none")]
214 pub error_handling: Option<String>,
215 #[serde(skip_serializing_if = "Option::is_none")]
217 pub retry_count: Option<u32>,
218 #[serde(skip_serializing_if = "Option::is_none")]
220 pub user_notification: Option<String>,
221}
222
/// A typed hook invocation delivered to [`SessionHooks::on_hook`].
///
/// `dispatch_hook` selects the variant from the wire hook-type string
/// (`"preToolUse"`, `"postToolUse"`, ...) and deserializes the payload into
/// the variant's `input`. Every variant also carries the session's
/// [`HookContext`].
#[non_exhaustive]
#[derive(Debug)]
pub enum HookEvent {
    /// Hook type `"preToolUse"`.
    PreToolUse {
        /// Deserialized `preToolUse` payload.
        input: PreToolUseInput,
        /// Context of the session the hook was dispatched for.
        ctx: HookContext,
    },
    /// Hook type `"postToolUse"`.
    PostToolUse {
        /// Deserialized `postToolUse` payload.
        input: PostToolUseInput,
        /// Context of the session the hook was dispatched for.
        ctx: HookContext,
    },
    /// Hook type `"userPromptSubmitted"`.
    UserPromptSubmitted {
        /// Deserialized `userPromptSubmitted` payload.
        input: UserPromptSubmittedInput,
        /// Context of the session the hook was dispatched for.
        ctx: HookContext,
    },
    /// Hook type `"sessionStart"`.
    SessionStart {
        /// Deserialized `sessionStart` payload.
        input: SessionStartInput,
        /// Context of the session the hook was dispatched for.
        ctx: HookContext,
    },
    /// Hook type `"sessionEnd"`.
    SessionEnd {
        /// Deserialized `sessionEnd` payload.
        input: SessionEndInput,
        /// Context of the session the hook was dispatched for.
        ctx: HookContext,
    },
    /// Hook type `"errorOccurred"`.
    ErrorOccurred {
        /// Deserialized `errorOccurred` payload.
        input: ErrorOccurredInput,
        /// Context of the session the hook was dispatched for.
        ctx: HookContext,
    },
}
274
275#[non_exhaustive]
280#[derive(Debug)]
281pub enum HookOutput {
282 None,
284 PreToolUse(PreToolUseOutput),
286 PostToolUse(PostToolUseOutput),
288 UserPromptSubmitted(UserPromptSubmittedOutput),
290 SessionStart(SessionStartOutput),
292 SessionEnd(SessionEndOutput),
294 ErrorOccurred(ErrorOccurredOutput),
296}
297
298impl HookOutput {
299 fn variant_name(&self) -> &'static str {
300 match self {
301 Self::None => "None",
302 Self::PreToolUse(_) => "PreToolUse",
303 Self::PostToolUse(_) => "PostToolUse",
304 Self::UserPromptSubmitted(_) => "UserPromptSubmitted",
305 Self::SessionStart(_) => "SessionStart",
306 Self::SessionEnd(_) => "SessionEnd",
307 Self::ErrorOccurred(_) => "ErrorOccurred",
308 }
309 }
310}
311
312#[async_trait]
330pub trait SessionHooks: Send + Sync + 'static {
331 async fn on_hook(&self, event: HookEvent) -> HookOutput {
335 match event {
336 HookEvent::PreToolUse { input, ctx } => self
337 .on_pre_tool_use(input, ctx)
338 .await
339 .map(HookOutput::PreToolUse)
340 .unwrap_or(HookOutput::None),
341 HookEvent::PostToolUse { input, ctx } => self
342 .on_post_tool_use(input, ctx)
343 .await
344 .map(HookOutput::PostToolUse)
345 .unwrap_or(HookOutput::None),
346 HookEvent::UserPromptSubmitted { input, ctx } => self
347 .on_user_prompt_submitted(input, ctx)
348 .await
349 .map(HookOutput::UserPromptSubmitted)
350 .unwrap_or(HookOutput::None),
351 HookEvent::SessionStart { input, ctx } => self
352 .on_session_start(input, ctx)
353 .await
354 .map(HookOutput::SessionStart)
355 .unwrap_or(HookOutput::None),
356 HookEvent::SessionEnd { input, ctx } => self
357 .on_session_end(input, ctx)
358 .await
359 .map(HookOutput::SessionEnd)
360 .unwrap_or(HookOutput::None),
361 HookEvent::ErrorOccurred { input, ctx } => self
362 .on_error_occurred(input, ctx)
363 .await
364 .map(HookOutput::ErrorOccurred)
365 .unwrap_or(HookOutput::None),
366 }
367 }
368
369 async fn on_pre_tool_use(
372 &self,
373 _input: PreToolUseInput,
374 _ctx: HookContext,
375 ) -> Option<PreToolUseOutput> {
376 None
377 }
378
379 async fn on_post_tool_use(
383 &self,
384 _input: PostToolUseInput,
385 _ctx: HookContext,
386 ) -> Option<PostToolUseOutput> {
387 None
388 }
389
390 async fn on_user_prompt_submitted(
394 &self,
395 _input: UserPromptSubmittedInput,
396 _ctx: HookContext,
397 ) -> Option<UserPromptSubmittedOutput> {
398 None
399 }
400
401 async fn on_session_start(
404 &self,
405 _input: SessionStartInput,
406 _ctx: HookContext,
407 ) -> Option<SessionStartOutput> {
408 None
409 }
410
411 async fn on_session_end(
414 &self,
415 _input: SessionEndInput,
416 _ctx: HookContext,
417 ) -> Option<SessionEndOutput> {
418 None
419 }
420
421 async fn on_error_occurred(
424 &self,
425 _input: ErrorOccurredInput,
426 _ctx: HookContext,
427 ) -> Option<ErrorOccurredOutput> {
428 None
429 }
430}
431
432pub(crate) async fn dispatch_hook(
438 hooks: &dyn SessionHooks,
439 session_id: &SessionId,
440 hook_type: &str,
441 raw_input: Value,
442) -> Result<Value, crate::Error> {
443 let ctx = HookContext {
444 session_id: session_id.clone(),
445 };
446
447 let event = match hook_type {
448 "preToolUse" => {
449 let input: PreToolUseInput = serde_json::from_value(raw_input)?;
450 HookEvent::PreToolUse { input, ctx }
451 }
452 "postToolUse" => {
453 let input: PostToolUseInput = serde_json::from_value(raw_input)?;
454 HookEvent::PostToolUse { input, ctx }
455 }
456 "userPromptSubmitted" => {
457 let input: UserPromptSubmittedInput = serde_json::from_value(raw_input)?;
458 HookEvent::UserPromptSubmitted { input, ctx }
459 }
460 "sessionStart" => {
461 let input: SessionStartInput = serde_json::from_value(raw_input)?;
462 HookEvent::SessionStart { input, ctx }
463 }
464 "sessionEnd" => {
465 let input: SessionEndInput = serde_json::from_value(raw_input)?;
466 HookEvent::SessionEnd { input, ctx }
467 }
468 "errorOccurred" => {
469 let input: ErrorOccurredInput = serde_json::from_value(raw_input)?;
470 HookEvent::ErrorOccurred { input, ctx }
471 }
472 _ => {
473 tracing::warn!(
474 hook_type = hook_type,
475 session_id = %session_id,
476 "unknown hook type"
477 );
478 return Ok(serde_json::json!({ "output": {} }));
479 }
480 };
481
482 let dispatch_start = Instant::now();
483 let output = hooks.on_hook(event).await;
484 tracing::debug!(
485 elapsed_ms = dispatch_start.elapsed().as_millis(),
486 session_id = %session_id,
487 hook_type = hook_type,
488 "SessionHooks::on_hook dispatch"
489 );
490
491 let output_value = match (hook_type, &output) {
496 (_, HookOutput::None) => None,
497 ("preToolUse", HookOutput::PreToolUse(o)) => Some(serde_json::to_value(o)?),
498 ("postToolUse", HookOutput::PostToolUse(o)) => Some(serde_json::to_value(o)?),
499 ("userPromptSubmitted", HookOutput::UserPromptSubmitted(o)) => {
500 Some(serde_json::to_value(o)?)
501 }
502 ("sessionStart", HookOutput::SessionStart(o)) => Some(serde_json::to_value(o)?),
503 ("sessionEnd", HookOutput::SessionEnd(o)) => Some(serde_json::to_value(o)?),
504 ("errorOccurred", HookOutput::ErrorOccurred(o)) => Some(serde_json::to_value(o)?),
505 _ => {
506 tracing::warn!(
507 hook_type = hook_type,
508 session_id = %session_id,
509 output_variant = output.variant_name(),
510 "hook returned mismatched output variant, treating as unregistered"
511 );
512 None
513 }
514 };
515
516 Ok(serde_json::json!({ "output": output_value.unwrap_or(Value::Object(Default::default())) }))
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    /// Overrides `on_hook` directly: denies `dangerous_tool`, prefixes
    /// submitted prompts, and returns no output for everything else.
    struct TestHooks;

    #[async_trait]
    impl SessionHooks for TestHooks {
        async fn on_hook(&self, event: HookEvent) -> HookOutput {
            match event {
                HookEvent::PreToolUse { input, .. } => {
                    if input.tool_name == "dangerous_tool" {
                        HookOutput::PreToolUse(PreToolUseOutput {
                            permission_decision: Some("deny".to_string()),
                            permission_decision_reason: Some("blocked by policy".to_string()),
                            ..Default::default()
                        })
                    } else {
                        HookOutput::None
                    }
                }
                HookEvent::UserPromptSubmitted { input, .. } => {
                    HookOutput::UserPromptSubmitted(UserPromptSubmittedOutput {
                        modified_prompt: Some(format!("[prefixed] {}", input.prompt)),
                        ..Default::default()
                    })
                }
                _ => HookOutput::None,
            }
        }
    }

    /// Implements only per-event methods, so dispatch exercises the trait's
    /// default `on_hook` routing rather than a hand-written match.
    struct MethodHooks;

    #[async_trait]
    impl SessionHooks for MethodHooks {
        async fn on_pre_tool_use(
            &self,
            input: PreToolUseInput,
            _ctx: HookContext,
        ) -> Option<PreToolUseOutput> {
            Some(PreToolUseOutput {
                modified_args: Some(serde_json::json!({ "tool": input.tool_name })),
                ..Default::default()
            })
        }

        async fn on_session_end(
            &self,
            input: SessionEndInput,
            _ctx: HookContext,
        ) -> Option<SessionEndOutput> {
            Some(SessionEndOutput {
                session_summary: Some(format!("ended: {}", input.reason)),
                ..Default::default()
            })
        }
    }

    #[tokio::test]
    async fn dispatch_pre_tool_use_deny() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "dangerous_tool",
            "toolArgs": {}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        let output = &result["output"];
        assert_eq!(output["permissionDecision"], "deny");
        assert_eq!(output["permissionDecisionReason"], "blocked by policy");
    }

    #[tokio::test]
    async fn dispatch_pre_tool_use_passthrough() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "safe_tool",
            "toolArgs": {"key": "value"}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_user_prompt_submitted() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "prompt": "hello world"
        });
        let result = dispatch_hook(
            &hooks,
            &SessionId::new("sess-1"),
            "userPromptSubmitted",
            input,
        )
        .await
        .unwrap();
        assert_eq!(result["output"]["modifiedPrompt"], "[prefixed] hello world");
    }

    #[tokio::test]
    async fn dispatch_unregistered_hook_returns_empty() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "reason": "complete"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionEnd", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_unknown_hook_type() {
        let hooks = TestHooks;
        let input = serde_json::json!({});
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "unknownHook", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_mismatched_output_returns_empty() {
        struct MismatchHooks;
        #[async_trait]
        impl SessionHooks for MismatchHooks {
            async fn on_hook(&self, _event: HookEvent) -> HookOutput {
                HookOutput::SessionEnd(SessionEndOutput {
                    session_summary: Some("oops".to_string()),
                    ..Default::default()
                })
            }
        }

        let hooks = MismatchHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_post_tool_use_default() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {},
            "toolResult": "success"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "postToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_session_start() {
        struct StartHooks;
        #[async_trait]
        impl SessionHooks for StartHooks {
            async fn on_hook(&self, event: HookEvent) -> HookOutput {
                match event {
                    HookEvent::SessionStart { .. } => {
                        HookOutput::SessionStart(SessionStartOutput {
                            additional_context: Some("extra context".to_string()),
                            ..Default::default()
                        })
                    }
                    _ => HookOutput::None,
                }
            }
        }

        let hooks = StartHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "source": "new"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionStart", input)
            .await
            .unwrap();
        assert_eq!(result["output"]["additionalContext"], "extra context");
    }

    #[tokio::test]
    async fn dispatch_error_occurred() {
        struct ErrorHooks;
        #[async_trait]
        impl SessionHooks for ErrorHooks {
            async fn on_hook(&self, event: HookEvent) -> HookOutput {
                match event {
                    HookEvent::ErrorOccurred { .. } => {
                        HookOutput::ErrorOccurred(ErrorOccurredOutput {
                            error_handling: Some("retry".to_string()),
                            retry_count: Some(3),
                            ..Default::default()
                        })
                    }
                    _ => HookOutput::None,
                }
            }
        }

        let hooks = ErrorHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "error": "model timeout",
            "errorContext": "model_call",
            "recoverable": true
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "errorOccurred", input)
            .await
            .unwrap();
        assert_eq!(result["output"]["errorHandling"], "retry");
        assert_eq!(result["output"]["retryCount"], 3);
    }

    /// The default `on_hook` must route each event to its `on_*` method and
    /// wrap the result in the matching `HookOutput` variant.
    #[tokio::test]
    async fn default_on_hook_routes_to_per_event_methods() {
        let hooks = MethodHooks;
        let session = SessionId::new("sess-1");

        let pre = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "grep",
            "toolArgs": {}
        });
        let result = dispatch_hook(&hooks, &session, "preToolUse", pre)
            .await
            .unwrap();
        assert_eq!(result["output"]["modifiedArgs"]["tool"], "grep");

        let end = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "reason": "complete"
        });
        let result = dispatch_hook(&hooks, &session, "sessionEnd", end)
            .await
            .unwrap();
        assert_eq!(result["output"]["sessionSummary"], "ended: complete");
    }

    /// Per-event methods that are not overridden return `None`, which the
    /// default `on_hook` turns into an empty output object.
    #[tokio::test]
    async fn default_on_hook_unimplemented_method_returns_empty() {
        let hooks = MethodHooks;
        let input = serde_json::json!({
            "sessionId": "sess-1",
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "prompt": "hello"
        });
        let result = dispatch_hook(
            &hooks,
            &SessionId::new("sess-1"),
            "userPromptSubmitted",
            input,
        )
        .await
        .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    /// A payload missing required fields for its hook type is an error, not
    /// a silently empty output.
    #[tokio::test]
    async fn dispatch_malformed_input_is_error() {
        let hooks = TestHooks;
        let input = serde_json::json!({ "sessionId": "sess-1" });
        let result =
            dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input).await;
        assert!(result.is_err());
    }
}