1use std::path::PathBuf;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13
14use crate::types::SessionId;
15
16#[derive(Debug, Clone)]
18pub struct HookContext {
19 pub session_id: SessionId,
21}
22
23#[derive(Debug, Clone, Deserialize)]
25#[serde(rename_all = "camelCase")]
26pub struct PreToolUseInput {
27 pub timestamp: i64,
29 pub cwd: PathBuf,
31 pub tool_name: String,
33 pub tool_args: Value,
35}
36
37#[derive(Debug, Clone, Default, Serialize)]
39#[serde(rename_all = "camelCase")]
40pub struct PreToolUseOutput {
41 #[serde(skip_serializing_if = "Option::is_none")]
43 pub permission_decision: Option<String>,
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub permission_decision_reason: Option<String>,
47 #[serde(skip_serializing_if = "Option::is_none")]
49 pub modified_args: Option<Value>,
50 #[serde(skip_serializing_if = "Option::is_none")]
52 pub additional_context: Option<String>,
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub suppress_output: Option<bool>,
56}
57
58#[derive(Debug, Clone, Deserialize)]
60#[serde(rename_all = "camelCase")]
61pub struct PostToolUseInput {
62 pub timestamp: i64,
64 pub cwd: PathBuf,
66 pub tool_name: String,
68 pub tool_args: Value,
70 pub tool_result: Value,
72}
73
74#[derive(Debug, Clone, Default, Serialize)]
76#[serde(rename_all = "camelCase")]
77pub struct PostToolUseOutput {
78 #[serde(skip_serializing_if = "Option::is_none")]
80 pub modified_result: Option<Value>,
81 #[serde(skip_serializing_if = "Option::is_none")]
83 pub additional_context: Option<String>,
84 #[serde(skip_serializing_if = "Option::is_none")]
86 pub suppress_output: Option<bool>,
87}
88
89#[derive(Debug, Clone, Deserialize)]
91#[serde(rename_all = "camelCase")]
92pub struct UserPromptSubmittedInput {
93 pub timestamp: i64,
95 pub cwd: PathBuf,
97 pub prompt: String,
99}
100
101#[derive(Debug, Clone, Default, Serialize)]
103#[serde(rename_all = "camelCase")]
104pub struct UserPromptSubmittedOutput {
105 #[serde(skip_serializing_if = "Option::is_none")]
107 pub modified_prompt: Option<String>,
108 #[serde(skip_serializing_if = "Option::is_none")]
110 pub additional_context: Option<String>,
111 #[serde(skip_serializing_if = "Option::is_none")]
113 pub suppress_output: Option<bool>,
114}
115
116#[derive(Debug, Clone, Deserialize)]
118#[serde(rename_all = "camelCase")]
119pub struct SessionStartInput {
120 pub timestamp: i64,
122 pub cwd: PathBuf,
124 pub source: String,
126 #[serde(default)]
128 pub initial_prompt: Option<String>,
129}
130
131#[derive(Debug, Clone, Default, Serialize)]
133#[serde(rename_all = "camelCase")]
134pub struct SessionStartOutput {
135 #[serde(skip_serializing_if = "Option::is_none")]
137 pub additional_context: Option<String>,
138 #[serde(skip_serializing_if = "Option::is_none")]
140 pub modified_config: Option<Value>,
141}
142
143#[derive(Debug, Clone, Deserialize)]
145#[serde(rename_all = "camelCase")]
146pub struct SessionEndInput {
147 pub timestamp: i64,
149 pub cwd: PathBuf,
151 pub reason: String,
153 #[serde(default)]
155 pub final_message: Option<String>,
156 #[serde(default)]
158 pub error: Option<String>,
159}
160
161#[derive(Debug, Clone, Default, Serialize)]
163#[serde(rename_all = "camelCase")]
164pub struct SessionEndOutput {
165 #[serde(skip_serializing_if = "Option::is_none")]
167 pub suppress_output: Option<bool>,
168 #[serde(skip_serializing_if = "Option::is_none")]
170 pub cleanup_actions: Option<Vec<String>>,
171 #[serde(skip_serializing_if = "Option::is_none")]
173 pub session_summary: Option<String>,
174}
175
176#[derive(Debug, Clone, Deserialize)]
178#[serde(rename_all = "camelCase")]
179pub struct ErrorOccurredInput {
180 pub timestamp: i64,
182 pub cwd: PathBuf,
184 pub error: String,
186 pub error_context: String,
188 pub recoverable: bool,
190}
191
192#[derive(Debug, Clone, Default, Serialize)]
194#[serde(rename_all = "camelCase")]
195pub struct ErrorOccurredOutput {
196 #[serde(skip_serializing_if = "Option::is_none")]
198 pub suppress_output: Option<bool>,
199 #[serde(skip_serializing_if = "Option::is_none")]
201 pub error_handling: Option<String>,
202 #[serde(skip_serializing_if = "Option::is_none")]
204 pub retry_count: Option<u32>,
205 #[serde(skip_serializing_if = "Option::is_none")]
207 pub user_notification: Option<String>,
208}
209
210#[non_exhaustive]
216#[derive(Debug)]
217pub enum HookEvent {
218 PreToolUse {
220 input: PreToolUseInput,
222 ctx: HookContext,
224 },
225 PostToolUse {
227 input: PostToolUseInput,
229 ctx: HookContext,
231 },
232 UserPromptSubmitted {
234 input: UserPromptSubmittedInput,
236 ctx: HookContext,
238 },
239 SessionStart {
241 input: SessionStartInput,
243 ctx: HookContext,
245 },
246 SessionEnd {
248 input: SessionEndInput,
250 ctx: HookContext,
252 },
253 ErrorOccurred {
255 input: ErrorOccurredInput,
257 ctx: HookContext,
259 },
260}
261
262#[non_exhaustive]
267#[derive(Debug)]
268pub enum HookOutput {
269 None,
271 PreToolUse(PreToolUseOutput),
273 PostToolUse(PostToolUseOutput),
275 UserPromptSubmitted(UserPromptSubmittedOutput),
277 SessionStart(SessionStartOutput),
279 SessionEnd(SessionEndOutput),
281 ErrorOccurred(ErrorOccurredOutput),
283}
284
285impl HookOutput {
286 fn variant_name(&self) -> &'static str {
287 match self {
288 Self::None => "None",
289 Self::PreToolUse(_) => "PreToolUse",
290 Self::PostToolUse(_) => "PostToolUse",
291 Self::UserPromptSubmitted(_) => "UserPromptSubmitted",
292 Self::SessionStart(_) => "SessionStart",
293 Self::SessionEnd(_) => "SessionEnd",
294 Self::ErrorOccurred(_) => "ErrorOccurred",
295 }
296 }
297}
298
299#[async_trait]
317pub trait SessionHooks: Send + Sync + 'static {
318 async fn on_hook(&self, event: HookEvent) -> HookOutput {
322 match event {
323 HookEvent::PreToolUse { input, ctx } => self
324 .on_pre_tool_use(input, ctx)
325 .await
326 .map(HookOutput::PreToolUse)
327 .unwrap_or(HookOutput::None),
328 HookEvent::PostToolUse { input, ctx } => self
329 .on_post_tool_use(input, ctx)
330 .await
331 .map(HookOutput::PostToolUse)
332 .unwrap_or(HookOutput::None),
333 HookEvent::UserPromptSubmitted { input, ctx } => self
334 .on_user_prompt_submitted(input, ctx)
335 .await
336 .map(HookOutput::UserPromptSubmitted)
337 .unwrap_or(HookOutput::None),
338 HookEvent::SessionStart { input, ctx } => self
339 .on_session_start(input, ctx)
340 .await
341 .map(HookOutput::SessionStart)
342 .unwrap_or(HookOutput::None),
343 HookEvent::SessionEnd { input, ctx } => self
344 .on_session_end(input, ctx)
345 .await
346 .map(HookOutput::SessionEnd)
347 .unwrap_or(HookOutput::None),
348 HookEvent::ErrorOccurred { input, ctx } => self
349 .on_error_occurred(input, ctx)
350 .await
351 .map(HookOutput::ErrorOccurred)
352 .unwrap_or(HookOutput::None),
353 }
354 }
355
356 async fn on_pre_tool_use(
359 &self,
360 _input: PreToolUseInput,
361 _ctx: HookContext,
362 ) -> Option<PreToolUseOutput> {
363 None
364 }
365
366 async fn on_post_tool_use(
370 &self,
371 _input: PostToolUseInput,
372 _ctx: HookContext,
373 ) -> Option<PostToolUseOutput> {
374 None
375 }
376
377 async fn on_user_prompt_submitted(
381 &self,
382 _input: UserPromptSubmittedInput,
383 _ctx: HookContext,
384 ) -> Option<UserPromptSubmittedOutput> {
385 None
386 }
387
388 async fn on_session_start(
391 &self,
392 _input: SessionStartInput,
393 _ctx: HookContext,
394 ) -> Option<SessionStartOutput> {
395 None
396 }
397
398 async fn on_session_end(
401 &self,
402 _input: SessionEndInput,
403 _ctx: HookContext,
404 ) -> Option<SessionEndOutput> {
405 None
406 }
407
408 async fn on_error_occurred(
411 &self,
412 _input: ErrorOccurredInput,
413 _ctx: HookContext,
414 ) -> Option<ErrorOccurredOutput> {
415 None
416 }
417}
418
419pub(crate) async fn dispatch_hook(
425 hooks: &dyn SessionHooks,
426 session_id: &SessionId,
427 hook_type: &str,
428 raw_input: Value,
429) -> Result<Value, crate::Error> {
430 let ctx = HookContext {
431 session_id: session_id.clone(),
432 };
433
434 let event = match hook_type {
435 "preToolUse" => {
436 let input: PreToolUseInput = serde_json::from_value(raw_input)?;
437 HookEvent::PreToolUse { input, ctx }
438 }
439 "postToolUse" => {
440 let input: PostToolUseInput = serde_json::from_value(raw_input)?;
441 HookEvent::PostToolUse { input, ctx }
442 }
443 "userPromptSubmitted" => {
444 let input: UserPromptSubmittedInput = serde_json::from_value(raw_input)?;
445 HookEvent::UserPromptSubmitted { input, ctx }
446 }
447 "sessionStart" => {
448 let input: SessionStartInput = serde_json::from_value(raw_input)?;
449 HookEvent::SessionStart { input, ctx }
450 }
451 "sessionEnd" => {
452 let input: SessionEndInput = serde_json::from_value(raw_input)?;
453 HookEvent::SessionEnd { input, ctx }
454 }
455 "errorOccurred" => {
456 let input: ErrorOccurredInput = serde_json::from_value(raw_input)?;
457 HookEvent::ErrorOccurred { input, ctx }
458 }
459 _ => {
460 tracing::warn!(
461 hook_type = hook_type,
462 session_id = %session_id,
463 "unknown hook type"
464 );
465 return Ok(serde_json::json!({ "output": {} }));
466 }
467 };
468
469 let output = hooks.on_hook(event).await;
470
471 let output_value = match (hook_type, &output) {
476 (_, HookOutput::None) => None,
477 ("preToolUse", HookOutput::PreToolUse(o)) => Some(serde_json::to_value(o)?),
478 ("postToolUse", HookOutput::PostToolUse(o)) => Some(serde_json::to_value(o)?),
479 ("userPromptSubmitted", HookOutput::UserPromptSubmitted(o)) => {
480 Some(serde_json::to_value(o)?)
481 }
482 ("sessionStart", HookOutput::SessionStart(o)) => Some(serde_json::to_value(o)?),
483 ("sessionEnd", HookOutput::SessionEnd(o)) => Some(serde_json::to_value(o)?),
484 ("errorOccurred", HookOutput::ErrorOccurred(o)) => Some(serde_json::to_value(o)?),
485 _ => {
486 tracing::warn!(
487 hook_type = hook_type,
488 session_id = %session_id,
489 output_variant = output.variant_name(),
490 "hook returned mismatched output variant, treating as unregistered"
491 );
492 None
493 }
494 };
495
496 Ok(serde_json::json!({ "output": output_value.unwrap_or(Value::Object(Default::default())) }))
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    /// Hooks that override `on_hook` directly. They deny `dangerous_tool`,
    /// prefix every submitted prompt, and return nothing for anything else.
    struct TestHooks;

    #[async_trait]
    impl SessionHooks for TestHooks {
        async fn on_hook(&self, event: HookEvent) -> HookOutput {
            match event {
                HookEvent::PreToolUse { input, .. } => {
                    if input.tool_name == "dangerous_tool" {
                        HookOutput::PreToolUse(PreToolUseOutput {
                            permission_decision: Some("deny".to_string()),
                            permission_decision_reason: Some("blocked by policy".to_string()),
                            ..Default::default()
                        })
                    } else {
                        HookOutput::None
                    }
                }
                HookEvent::UserPromptSubmitted { input, .. } => {
                    HookOutput::UserPromptSubmitted(UserPromptSubmittedOutput {
                        modified_prompt: Some(format!("[prefixed] {}", input.prompt)),
                        ..Default::default()
                    })
                }
                _ => HookOutput::None,
            }
        }
    }

    /// Hooks that implement only a typed callback. They exercise the default
    /// `SessionHooks::on_hook` routing, which no other test hook relies on.
    struct TypedHooks;

    #[async_trait]
    impl SessionHooks for TypedHooks {
        async fn on_pre_tool_use(
            &self,
            input: PreToolUseInput,
            _ctx: HookContext,
        ) -> Option<PreToolUseOutput> {
            Some(PreToolUseOutput {
                modified_args: Some(serde_json::json!({ "wrapped": input.tool_args })),
                ..Default::default()
            })
        }
    }

    #[tokio::test]
    async fn dispatch_pre_tool_use_deny() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "dangerous_tool",
            "toolArgs": {}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        let output = &result["output"];
        assert_eq!(output["permissionDecision"], "deny");
        assert_eq!(output["permissionDecisionReason"], "blocked by policy");
    }

    #[tokio::test]
    async fn dispatch_pre_tool_use_passthrough() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "safe_tool",
            "toolArgs": {"key": "value"}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_user_prompt_submitted() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "prompt": "hello world"
        });
        let result = dispatch_hook(
            &hooks,
            &SessionId::new("sess-1"),
            "userPromptSubmitted",
            input,
        )
        .await
        .unwrap();
        assert_eq!(result["output"]["modifiedPrompt"], "[prefixed] hello world");
    }

    #[tokio::test]
    async fn dispatch_unregistered_hook_returns_empty() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "reason": "complete"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionEnd", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_unknown_hook_type() {
        let hooks = TestHooks;
        let input = serde_json::json!({});
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "unknownHook", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_mismatched_output_returns_empty() {
        struct MismatchHooks;
        #[async_trait]
        impl SessionHooks for MismatchHooks {
            async fn on_hook(&self, _event: HookEvent) -> HookOutput {
                // Always answers with a `sessionEnd` reply, whatever the event.
                HookOutput::SessionEnd(SessionEndOutput {
                    session_summary: Some("oops".to_string()),
                    ..Default::default()
                })
            }
        }

        let hooks = MismatchHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_post_tool_use_default() {
        let hooks = TestHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {},
            "toolResult": "success"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "postToolUse", input)
            .await
            .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_session_start() {
        struct StartHooks;
        #[async_trait]
        impl SessionHooks for StartHooks {
            async fn on_hook(&self, event: HookEvent) -> HookOutput {
                match event {
                    HookEvent::SessionStart { .. } => {
                        HookOutput::SessionStart(SessionStartOutput {
                            additional_context: Some("extra context".to_string()),
                            ..Default::default()
                        })
                    }
                    _ => HookOutput::None,
                }
            }
        }

        let hooks = StartHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "source": "new"
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionStart", input)
            .await
            .unwrap();
        assert_eq!(result["output"]["additionalContext"], "extra context");
    }

    #[tokio::test]
    async fn dispatch_error_occurred() {
        struct ErrorHooks;
        #[async_trait]
        impl SessionHooks for ErrorHooks {
            async fn on_hook(&self, event: HookEvent) -> HookOutput {
                match event {
                    HookEvent::ErrorOccurred { .. } => {
                        HookOutput::ErrorOccurred(ErrorOccurredOutput {
                            error_handling: Some("retry".to_string()),
                            retry_count: Some(3),
                            ..Default::default()
                        })
                    }
                    _ => HookOutput::None,
                }
            }
        }

        let hooks = ErrorHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "error": "model timeout",
            "errorContext": "model_call",
            "recoverable": true
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "errorOccurred", input)
            .await
            .unwrap();
        assert_eq!(result["output"]["errorHandling"], "retry");
        assert_eq!(result["output"]["retryCount"], 3);
    }

    #[tokio::test]
    async fn default_on_hook_routes_to_typed_callback() {
        let hooks = TypedHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "toolName": "some_tool",
            "toolArgs": {"key": "value"}
        });
        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
            .await
            .unwrap();
        // Only `modifiedArgs` is set, so every other field must be omitted.
        assert_eq!(
            result["output"],
            serde_json::json!({ "modifiedArgs": { "wrapped": { "key": "value" } } })
        );
    }

    #[tokio::test]
    async fn default_typed_callback_returns_empty() {
        // `TypedHooks` does not override `on_user_prompt_submitted`, so the
        // trait's default `None` must surface as an empty output object.
        let hooks = TypedHooks;
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp",
            "prompt": "hello world"
        });
        let result = dispatch_hook(
            &hooks,
            &SessionId::new("sess-1"),
            "userPromptSubmitted",
            input,
        )
        .await
        .unwrap();
        assert_eq!(result["output"], serde_json::json!({}));
    }

    #[tokio::test]
    async fn dispatch_malformed_input_is_error() {
        let hooks = TestHooks;
        // `toolName` and `toolArgs` are required by `PreToolUseInput`.
        let input = serde_json::json!({
            "timestamp": 1234567890,
            "cwd": "/tmp"
        });
        let result =
            dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input).await;
        assert!(result.is_err());
    }
}