1use std::sync::Arc;
33
34use async_trait::async_trait;
35use serde::{Deserialize, Serialize};
36
37use crate::tool::ToolOutput;
38use crate::types::{SessionId, ToolCallId};
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
59#[serde(rename_all = "snake_case")]
60pub enum HookEvent {
61 PreToolUse,
65 PostToolUse,
67 PostToolFailure,
69
70 UserPromptSubmit,
74
75 SessionStart,
79 SessionEnd,
81
82 Stop,
86 SubAgentStart,
88
89 PreCompact,
93 PostCompact,
95}
96
97pub const ALL_HOOK_EVENTS: &[HookEvent] = &[
99 HookEvent::PreToolUse,
100 HookEvent::PostToolUse,
101 HookEvent::PostToolFailure,
102 HookEvent::UserPromptSubmit,
103 HookEvent::SessionStart,
104 HookEvent::SessionEnd,
105 HookEvent::Stop,
106 HookEvent::SubAgentStart,
107 HookEvent::PreCompact,
108 HookEvent::PostCompact,
109];
110
111impl HookEvent {
112 pub fn is_tool_event(&self) -> bool {
114 matches!(
115 self,
116 Self::PreToolUse | Self::PostToolUse | Self::PostToolFailure
117 )
118 }
119}
120
121impl std::fmt::Display for HookEvent {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 let s = match self {
124 Self::PreToolUse => "pre_tool_use",
125 Self::PostToolUse => "post_tool_use",
126 Self::PostToolFailure => "post_tool_failure",
127 Self::UserPromptSubmit => "user_prompt_submit",
128 Self::SessionStart => "session_start",
129 Self::SessionEnd => "session_end",
130 Self::Stop => "stop",
131 Self::SubAgentStart => "sub_agent_start",
132 Self::PreCompact => "pre_compact",
133 Self::PostCompact => "post_compact",
134 };
135 f.write_str(s)
136 }
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
164#[serde(tag = "hook_event", rename_all = "snake_case")]
165pub enum HookInput {
166 PreToolUse {
168 tool_name: String,
169 tool_input: serde_json::Value,
170 call_id: ToolCallId,
171 },
172
173 PostToolUse {
175 tool_name: String,
176 tool_input: serde_json::Value,
177 tool_output: ToolOutput,
178 call_id: ToolCallId,
179 },
180
181 PostToolFailure {
183 tool_name: String,
184 tool_input: serde_json::Value,
185 error: String,
186 call_id: ToolCallId,
187 },
188
189 UserPromptSubmit {
191 prompt: String,
192 },
193
194 SessionStart {
196 session_id: SessionId,
197 },
198
199 SessionEnd {
201 session_id: SessionId,
202 reason: String,
203 },
204
205 Stop {
207 finish_reason: String,
208 },
209
210 SubAgentStart {
212 agent_name: String,
213 },
214
215 PreCompact {
217 trigger: String,
218 tokens_before: u64,
219 },
220
221 PostCompact {
223 trigger: String,
224 tokens_after: u64,
225 },
226}
227
228impl HookInput {
229 pub fn event(&self) -> HookEvent {
231 match self {
232 Self::PreToolUse { .. } => HookEvent::PreToolUse,
233 Self::PostToolUse { .. } => HookEvent::PostToolUse,
234 Self::PostToolFailure { .. } => HookEvent::PostToolFailure,
235 Self::UserPromptSubmit { .. } => HookEvent::UserPromptSubmit,
236 Self::SessionStart { .. } => HookEvent::SessionStart,
237 Self::SessionEnd { .. } => HookEvent::SessionEnd,
238 Self::Stop { .. } => HookEvent::Stop,
239 Self::SubAgentStart { .. } => HookEvent::SubAgentStart,
240 Self::PreCompact { .. } => HookEvent::PreCompact,
241 Self::PostCompact { .. } => HookEvent::PostCompact,
242 }
243 }
244
245 pub fn tool_name(&self) -> Option<&str> {
247 match self {
248 Self::PreToolUse { tool_name, .. }
249 | Self::PostToolUse { tool_name, .. }
250 | Self::PostToolFailure { tool_name, .. } => Some(tool_name.as_str()),
251 _ => None,
252 }
253 }
254
255 pub fn call_id(&self) -> Option<&ToolCallId> {
257 match self {
258 Self::PreToolUse { call_id, .. }
259 | Self::PostToolUse { call_id, .. }
260 | Self::PostToolFailure { call_id, .. } => Some(call_id),
261 _ => None,
262 }
263 }
264}
265
266#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
289#[serde(tag = "type", rename_all = "snake_case")]
290pub enum HookPermission {
291 Allow,
293
294 Deny {
296 #[serde(default, skip_serializing_if = "Option::is_none")]
297 reason: Option<String>,
298 },
299
300 Ask {
302 #[serde(default, skip_serializing_if = "Option::is_none")]
303 message: Option<String>,
304 },
305}
306
307impl HookPermission {
308 pub fn deny() -> Self {
310 Self::Deny { reason: None }
311 }
312
313 pub fn deny_with_reason(reason: impl Into<String>) -> Self {
315 Self::Deny {
316 reason: Some(reason.into()),
317 }
318 }
319
320 pub fn ask() -> Self {
322 Self::Ask { message: None }
323 }
324
325 pub fn ask_with_message(message: impl Into<String>) -> Self {
327 Self::Ask {
328 message: Some(message.into()),
329 }
330 }
331
332 pub fn is_allow(&self) -> bool {
333 matches!(self, Self::Allow)
334 }
335
336 pub fn is_deny(&self) -> bool {
337 matches!(self, Self::Deny { .. })
338 }
339
340 pub fn is_ask(&self) -> bool {
341 matches!(self, Self::Ask { .. })
342 }
343
344 fn strictness(&self) -> u8 {
348 match self {
349 Self::Allow => 0,
350 Self::Ask { .. } => 1,
351 Self::Deny { .. } => 2,
352 }
353 }
354}
355
356#[derive(Debug, Clone, Default, Serialize, Deserialize)]
382pub struct HookOutput {
383 #[serde(default, skip_serializing_if = "Option::is_none")]
385 pub permission: Option<HookPermission>,
386
387 #[serde(default, skip_serializing_if = "Option::is_none")]
389 pub updated_input: Option<serde_json::Value>,
390
391 #[serde(default, skip_serializing_if = "Option::is_none")]
393 pub updated_output: Option<ToolOutput>,
394
395 #[serde(default, skip_serializing_if = "Vec::is_empty")]
397 pub additional_context: Vec<String>,
398
399 #[serde(default)]
401 pub prevent_continuation: bool,
402
403 #[serde(default, skip_serializing_if = "Option::is_none")]
405 pub stop_reason: Option<String>,
406
407 #[serde(default, skip_serializing_if = "Option::is_none")]
409 pub blocking_error: Option<String>,
410
411 #[serde(default, skip_serializing_if = "Option::is_none")]
413 pub system_message: Option<String>,
414}
415
416impl HookOutput {
417 pub fn passthrough() -> Self {
419 Self::default()
420 }
421
422 pub fn allow() -> Self {
424 Self {
425 permission: Some(HookPermission::Allow),
426 ..Default::default()
427 }
428 }
429
430 pub fn deny(reason: impl Into<String>) -> Self {
432 Self {
433 permission: Some(HookPermission::deny_with_reason(reason)),
434 ..Default::default()
435 }
436 }
437
438 pub fn ask(message: impl Into<String>) -> Self {
440 Self {
441 permission: Some(HookPermission::ask_with_message(message)),
442 ..Default::default()
443 }
444 }
445
446 pub fn with_updated_input(mut self, input: serde_json::Value) -> Self {
448 self.updated_input = Some(input);
449 self
450 }
451
452 pub fn with_updated_output(mut self, output: ToolOutput) -> Self {
454 self.updated_output = Some(output);
455 self
456 }
457
458 pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
460 self.additional_context.push(ctx.into());
461 self
462 }
463
464 pub fn with_stop(mut self, reason: impl Into<String>) -> Self {
466 self.prevent_continuation = true;
467 self.stop_reason = Some(reason.into());
468 self
469 }
470
471 pub fn with_blocking_error(mut self, error: impl Into<String>) -> Self {
473 self.blocking_error = Some(error.into());
474 self
475 }
476
477 pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
479 self.system_message = Some(message.into());
480 self
481 }
482
483 pub fn has_decision(&self) -> bool {
485 self.permission.is_some()
486 || self.updated_input.is_some()
487 || self.updated_output.is_some()
488 || !self.additional_context.is_empty()
489 || self.prevent_continuation
490 || self.blocking_error.is_some()
491 }
492}
493
/// A user-implementable hook invoked by the agent for [`HookEvent`]s.
///
/// Implementations must be `Send + Sync`; the registry stores them as
/// `Arc<dyn Hook>`.
#[async_trait]
pub trait Hook: Send + Sync {
    /// Identifier of the hook; `HookRegistry::remove` matches on this and
    /// `AggregatedHookOutput::merge` uses it to prefix blocking errors.
    fn name(&self) -> &str;

    /// Events this hook subscribes to.
    ///
    /// The default empty slice is treated by `HookRegistry` as "all events".
    fn events(&self) -> &[HookEvent] {
        &[]
    }

    /// Optional tool-name pattern (see `matches_pattern`).
    ///
    /// When `Some`, `HookRegistry::matching` selects this hook only for inputs
    /// that carry a tool name matching the pattern, so a matcher never selects
    /// non-tool events. `None` (the default) applies no tool filter.
    fn matcher(&self) -> Option<&str> {
        None
    }

    /// Handle one event and return this hook's output.
    async fn on_event(&self, input: &HookInput) -> HookOutput;
}
570
/// Where a registered hook came from; stored alongside it in `RegisteredHook`.
///
/// Serialized with an internal `"type"` tag in `snake_case`. The exact meaning
/// of each origin is defined by the callers that register hooks.
/// NOTE(review): per-variant semantics below are inferred from names; confirm
/// against the registration sites.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum HookSource {
    /// Presumably hooks loaded from settings.
    Settings,
    /// Presumably hooks loaded from project configuration.
    Project,
    /// Hooks contributed by a named plugin.
    Plugin { name: String },
    /// Hooks registered directly from code.
    Programmatic,
    /// Presumably hooks scoped to a single session.
    Session,
}
599
600pub struct HookRegistry {
638 hooks: Vec<RegisteredHook>,
639}
640
641pub struct RegisteredHook {
643 pub hook: Arc<dyn Hook>,
645 pub source: HookSource,
647 pub priority: i32,
649}
650
651impl HookRegistry {
652 pub fn new() -> Self {
654 Self { hooks: Vec::new() }
655 }
656
657 pub fn register(
661 &mut self,
662 hook: Arc<dyn Hook>,
663 source: HookSource,
664 priority: i32,
665 ) {
666 self.hooks.push(RegisteredHook {
667 hook,
668 source,
669 priority,
670 });
671 self.hooks.sort_by_key(|h| h.priority);
672 }
673
674 pub fn remove(&mut self, name: &str) {
676 self.hooks.retain(|h| h.hook.name() != name);
677 }
678
679 pub fn len(&self) -> usize {
681 self.hooks.len()
682 }
683
684 pub fn is_empty(&self) -> bool {
686 self.hooks.is_empty()
687 }
688
689 pub fn matching(&self, input: &HookInput) -> Vec<&RegisteredHook> {
695 let event = input.event();
696 let tool_name = input.tool_name();
697
698 self.hooks
699 .iter()
700 .filter(|h| {
701 let events = h.hook.events();
702 let event_match = events.is_empty() || events.contains(&event);
703 if !event_match {
704 return false;
705 }
706
707 match (h.hook.matcher(), tool_name) {
708 (Some(pattern), Some(name)) => matches_pattern(name, pattern),
709 (Some(_), None) => false,
710 (None, _) => true,
711 }
712 })
713 .collect()
714 }
715
716 pub fn has_hooks_for(&self, event: HookEvent) -> bool {
720 self.hooks.iter().any(|h| {
721 let events = h.hook.events();
722 events.is_empty() || events.contains(&event)
723 })
724 }
725}
726
727impl Default for HookRegistry {
728 fn default() -> Self {
729 Self::new()
730 }
731}
732
733#[derive(Debug, Clone, Default)]
761pub struct AggregatedHookOutput {
762 pub permission: Option<HookPermission>,
764
765 pub updated_input: Option<serde_json::Value>,
767
768 pub updated_output: Option<ToolOutput>,
770
771 pub additional_context: Vec<String>,
773
774 pub prevent_continuation: bool,
776
777 pub stop_reason: Option<String>,
779
780 pub blocking_errors: Vec<String>,
782
783 pub system_messages: Vec<String>,
785}
786
787impl AggregatedHookOutput {
788 pub fn merge(&mut self, output: HookOutput, hook_name: &str) {
790 if let Some(ref new_perm) = output.permission {
792 match &self.permission {
793 Some(existing) if existing.strictness() >= new_perm.strictness() => {
794 }
796 _ => {
797 self.permission = output.permission.clone();
798 }
799 }
800 }
801
802 if output.updated_input.is_some() {
803 self.updated_input = output.updated_input;
804 }
805 if output.updated_output.is_some() {
806 self.updated_output = output.updated_output;
807 }
808 self.additional_context.extend(output.additional_context);
809
810 if output.prevent_continuation {
811 self.prevent_continuation = true;
812 if self.stop_reason.is_none() {
813 self.stop_reason = output.stop_reason;
814 }
815 }
816
817 if let Some(err) = output.blocking_error {
818 self.blocking_errors.push(format!("[{hook_name}] {err}"));
819 }
820 if let Some(msg) = output.system_message {
821 self.system_messages.push(msg);
822 }
823 }
824
825 pub fn has_decision(&self) -> bool {
827 self.permission.is_some()
828 || self.updated_input.is_some()
829 || self.updated_output.is_some()
830 || !self.additional_context.is_empty()
831 || self.prevent_continuation
832 || !self.blocking_errors.is_empty()
833 }
834
835 pub fn has_blocking_errors(&self) -> bool {
837 !self.blocking_errors.is_empty()
838 }
839
840 pub fn is_denied(&self) -> bool {
842 matches!(&self.permission, Some(p) if p.is_deny())
843 }
844}
845
/// Returns `true` if `value` (typically a tool name) matches `pattern`.
///
/// `pattern` is a `|`-separated list of alternatives. Each alternative is
/// trimmed of surrounding whitespace, whether or not a `|` is present. An
/// alternative is either an exact string or a glob in which `*` matches any
/// (possibly empty) run of characters.
///
/// Examples: `"bash"`, `"bash|write_file"`, `"read_*"`, `"*_file"`,
/// `"pre_*_use"`, `"mcp__*__*"`, `"*"`.
pub fn matches_pattern(value: &str, pattern: &str) -> bool {
    // A pattern without `|` yields a single alternative, so one path handles both.
    pattern
        .split('|')
        .any(|alternative| matches_single_pattern(value, alternative.trim()))
}

/// Matches `value` against one alternative (no `|`), with `*` as a wildcard.
///
/// Matching works in three steps:
/// 1. The text before the first `*` must be a prefix of `value`.
/// 2. The text after the last `*` must be a suffix of what remains, so the two
///    anchors cannot overlap.
/// 3. The segments between the stars must appear in order within the rest.
///    Taking the leftmost occurrence of each is optimal once both ends are fixed.
fn matches_single_pattern(value: &str, pattern: &str) -> bool {
    if !pattern.contains('*') {
        return value == pattern;
    }

    let segments: Vec<&str> = pattern.split('*').collect();
    let (first, middle, last) = match segments.as_slice() {
        [first, middle @ .., last] => (*first, middle, *last),
        // A pattern containing `*` always splits into at least two segments.
        _ => return false,
    };

    let mut remaining = match value
        .strip_prefix(first)
        .and_then(|rest| rest.strip_suffix(last))
    {
        Some(rest) => rest,
        None => return false,
    };

    for segment in middle {
        // An empty segment (from `**`) matches at offset 0 and consumes nothing.
        match remaining.split_once(*segment) {
            Some((_, after)) => remaining = after,
            None => return false,
        }
    }
    true
}
923
// Unit tests for the hook types, aggregation rules, pattern matching and the
// registry. Only compiled for `cargo test`.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---- HookEvent ----

    #[test]
    fn test_hook_event_is_tool_event() {
        assert!(HookEvent::PreToolUse.is_tool_event());
        assert!(HookEvent::PostToolUse.is_tool_event());
        assert!(HookEvent::PostToolFailure.is_tool_event());
        assert!(!HookEvent::SessionStart.is_tool_event());
        assert!(!HookEvent::Stop.is_tool_event());
    }

    #[test]
    fn test_hook_event_display() {
        assert_eq!(HookEvent::PreToolUse.to_string(), "pre_tool_use");
        assert_eq!(HookEvent::PostToolUse.to_string(), "post_tool_use");
        assert_eq!(HookEvent::SessionStart.to_string(), "session_start");
    }

    #[test]
    fn test_hook_event_serde_roundtrip() {
        for event in ALL_HOOK_EVENTS {
            let json_str = serde_json::to_string(event).unwrap();
            let restored: HookEvent = serde_json::from_str(&json_str).unwrap();
            assert_eq!(*event, restored);
        }
    }

    // Guards against adding a HookEvent variant without listing it in ALL_HOOK_EVENTS.
    #[test]
    fn test_all_hook_events_count() {
        assert_eq!(ALL_HOOK_EVENTS.len(), 10);
    }

    // ---- HookInput ----

    #[test]
    fn test_hook_input_event() {
        let input = HookInput::PreToolUse {
            tool_name: "bash".into(),
            tool_input: json!({}),
            call_id: ToolCallId::new("c1"),
        };
        assert_eq!(input.event(), HookEvent::PreToolUse);
    }

    #[test]
    fn test_hook_input_tool_name() {
        let tool_input = HookInput::PreToolUse {
            tool_name: "bash".into(),
            tool_input: json!({}),
            call_id: ToolCallId::new("c1"),
        };
        assert_eq!(tool_input.tool_name(), Some("bash"));

        let non_tool = HookInput::SessionStart {
            session_id: SessionId::new(),
        };
        assert_eq!(non_tool.tool_name(), None);
    }

    #[test]
    fn test_hook_input_call_id() {
        let input = HookInput::PostToolFailure {
            tool_name: "bash".into(),
            tool_input: json!({}),
            error: "exit code 1".into(),
            call_id: ToolCallId::new("c2"),
        };
        assert_eq!(input.call_id().unwrap().as_str(), "c2");

        let non_tool = HookInput::Stop {
            finish_reason: "completed".into(),
        };
        assert!(non_tool.call_id().is_none());
    }

    // The serialized form must carry the "hook_event" tag value.
    #[test]
    fn test_hook_input_serde_roundtrip() {
        let input = HookInput::PreToolUse {
            tool_name: "read_file".into(),
            tool_input: json!({"path": "/tmp/test.txt"}),
            call_id: ToolCallId::new("call_42"),
        };
        let json_str = serde_json::to_string(&input).unwrap();
        assert!(json_str.contains("pre_tool_use"));
        let restored: HookInput = serde_json::from_str(&json_str).unwrap();
        assert_eq!(restored.event(), HookEvent::PreToolUse);
        assert_eq!(restored.tool_name(), Some("read_file"));
    }

    // ---- HookPermission ----

    #[test]
    fn test_hook_permission_variants() {
        assert!(HookPermission::Allow.is_allow());
        assert!(HookPermission::deny().is_deny());
        assert!(HookPermission::ask().is_ask());
    }

    #[test]
    fn test_hook_permission_with_reason() {
        let deny = HookPermission::deny_with_reason("unsafe");
        match deny {
            HookPermission::Deny { reason } => assert_eq!(reason, Some("unsafe".into())),
            _ => panic!("expected Deny"),
        }
    }

    // Allow < Ask < Deny is what makes "strictest wins" work in merge().
    #[test]
    fn test_hook_permission_strictness() {
        assert!(HookPermission::Allow.strictness() < HookPermission::ask().strictness());
        assert!(HookPermission::ask().strictness() < HookPermission::deny().strictness());
    }

    #[test]
    fn test_hook_permission_serde_roundtrip() {
        for perm in [
            HookPermission::Allow,
            HookPermission::deny(),
            HookPermission::deny_with_reason("test"),
            HookPermission::ask(),
            HookPermission::ask_with_message("confirm?"),
        ] {
            let json_str = serde_json::to_string(&perm).unwrap();
            let restored: HookPermission = serde_json::from_str(&json_str).unwrap();
            assert_eq!(perm, restored);
        }
    }

    // ---- HookOutput ----

    #[test]
    fn test_hook_output_passthrough() {
        let out = HookOutput::passthrough();
        assert!(!out.has_decision());
        assert!(out.permission.is_none());
        assert!(out.additional_context.is_empty());
    }

    #[test]
    fn test_hook_output_allow() {
        let out = HookOutput::allow();
        assert!(out.has_decision());
        assert!(out.permission.as_ref().unwrap().is_allow());
    }

    #[test]
    fn test_hook_output_deny() {
        let out = HookOutput::deny("bad command");
        assert!(out.has_decision());
        assert!(out.permission.as_ref().unwrap().is_deny());
    }

    #[test]
    fn test_hook_output_ask() {
        let out = HookOutput::ask("are you sure?");
        assert!(out.has_decision());
        assert!(out.permission.as_ref().unwrap().is_ask());
    }

    #[test]
    fn test_hook_output_builder() {
        let out = HookOutput::allow()
            .with_updated_input(json!({"command": "ls"}))
            .with_context("working directory: /tmp")
            .with_system_message("Input sanitized");

        assert!(out.permission.as_ref().unwrap().is_allow());
        assert_eq!(out.updated_input.as_ref().unwrap()["command"], "ls");
        assert_eq!(out.additional_context.len(), 1);
        assert_eq!(out.system_message, Some("Input sanitized".into()));
    }

    #[test]
    fn test_hook_output_with_stop() {
        let out = HookOutput::passthrough().with_stop("loop detected");
        assert!(out.prevent_continuation);
        assert_eq!(out.stop_reason, Some("loop detected".into()));
        assert!(out.has_decision());
    }

    #[test]
    fn test_hook_output_with_blocking_error() {
        let out = HookOutput::passthrough().with_blocking_error("lint failed");
        assert!(out.has_decision());
        assert_eq!(out.blocking_error, Some("lint failed".into()));
    }

    #[test]
    fn test_hook_output_serde_roundtrip() {
        let out = HookOutput::deny("test")
            .with_context("ctx1")
            .with_system_message("msg1");
        let json_str = serde_json::to_string(&out).unwrap();
        let restored: HookOutput = serde_json::from_str(&json_str).unwrap();
        assert_eq!(restored.additional_context, vec!["ctx1"]);
        assert_eq!(restored.system_message, Some("msg1".into()));
    }

    // ---- AggregatedHookOutput::merge ----

    #[test]
    fn test_aggregated_merge_permission_deny_wins() {
        let mut agg = AggregatedHookOutput::default();

        agg.merge(HookOutput::allow(), "hook_a");
        assert!(agg.permission.as_ref().unwrap().is_allow());

        agg.merge(HookOutput::deny("nope"), "hook_b");
        assert!(agg.permission.as_ref().unwrap().is_deny());

        agg.merge(HookOutput::allow(), "hook_c");
        // A later, weaker verdict must not override an earlier Deny.
        assert!(agg.permission.as_ref().unwrap().is_deny());
    }

    #[test]
    fn test_aggregated_merge_permission_ask_beats_allow() {
        let mut agg = AggregatedHookOutput::default();

        agg.merge(HookOutput::allow(), "hook_a");
        agg.merge(HookOutput::ask("confirm?"), "hook_b");
        assert!(agg.permission.as_ref().unwrap().is_ask());

        agg.merge(HookOutput::allow(), "hook_c");
        // A later Allow must not downgrade Ask.
        assert!(agg.permission.as_ref().unwrap().is_ask());
    }

    #[test]
    fn test_aggregated_merge_context() {
        let mut agg = AggregatedHookOutput::default();

        agg.merge(
            HookOutput::passthrough().with_context("ctx1"),
            "hook_a",
        );
        agg.merge(
            HookOutput::passthrough().with_context("ctx2"),
            "hook_b",
        );
        assert_eq!(agg.additional_context, vec!["ctx1", "ctx2"]);
    }

    // Each blocking error is prefixed with the name of the hook that raised it.
    #[test]
    fn test_aggregated_merge_blocking_errors() {
        let mut agg = AggregatedHookOutput::default();

        agg.merge(
            HookOutput::passthrough().with_blocking_error("err1"),
            "linter",
        );
        agg.merge(
            HookOutput::passthrough().with_blocking_error("err2"),
            "validator",
        );
        assert_eq!(agg.blocking_errors.len(), 2);
        assert!(agg.blocking_errors[0].contains("[linter]"));
        assert!(agg.blocking_errors[1].contains("[validator]"));
        assert!(agg.has_blocking_errors());
    }

    #[test]
    fn test_aggregated_merge_stop() {
        let mut agg = AggregatedHookOutput::default();

        agg.merge(HookOutput::passthrough(), "hook_a");
        assert!(!agg.prevent_continuation);

        agg.merge(
            HookOutput::passthrough().with_stop("first reason"),
            "hook_b",
        );
        assert!(agg.prevent_continuation);
        assert_eq!(agg.stop_reason, Some("first reason".into()));

        agg.merge(
            // A second stop request keeps the first recorded reason.
            HookOutput::passthrough().with_stop("second reason"),
            "hook_c",
        );
        assert_eq!(agg.stop_reason, Some("first reason".into()));
    }

    #[test]
    fn test_aggregated_merge_updated_input_last_wins() {
        let mut agg = AggregatedHookOutput::default();

        agg.merge(
            HookOutput::allow().with_updated_input(json!({"a": 1})),
            "hook_a",
        );
        agg.merge(
            HookOutput::allow().with_updated_input(json!({"b": 2})),
            "hook_b",
        );
        assert_eq!(agg.updated_input, Some(json!({"b": 2})));
    }

    #[test]
    fn test_aggregated_has_decision() {
        let agg = AggregatedHookOutput::default();
        assert!(!agg.has_decision());

        let mut agg2 = AggregatedHookOutput::default();
        agg2.merge(HookOutput::allow(), "h");
        assert!(agg2.has_decision());
    }

    #[test]
    fn test_aggregated_is_denied() {
        let mut agg = AggregatedHookOutput::default();
        assert!(!agg.is_denied());

        agg.merge(HookOutput::deny("no"), "h");
        assert!(agg.is_denied());
    }

    // ---- matches_pattern ----

    #[test]
    fn test_matches_pattern_exact() {
        assert!(matches_pattern("bash", "bash"));
        assert!(!matches_pattern("bash", "write_file"));
    }

    #[test]
    fn test_matches_pattern_pipe_separated() {
        assert!(matches_pattern("bash", "bash|write_file"));
        assert!(matches_pattern("write_file", "bash|write_file"));
        assert!(!matches_pattern("read_file", "bash|write_file"));
    }

    #[test]
    fn test_matches_pattern_wildcard_star() {
        assert!(matches_pattern("read_file", "read_*"));
        assert!(matches_pattern("read_dir", "read_*"));
        assert!(!matches_pattern("write_file", "read_*"));
    }

    #[test]
    fn test_matches_pattern_wildcard_suffix() {
        assert!(matches_pattern("read_file", "*_file"));
        assert!(matches_pattern("write_file", "*_file"));
        assert!(!matches_pattern("read_dir", "*_file"));
    }

    #[test]
    fn test_matches_pattern_wildcard_middle() {
        assert!(matches_pattern("pre_tool_use", "pre_*_use"));
        assert!(matches_pattern("pre_compact_use", "pre_*_use"));
        assert!(!matches_pattern("pre_tool_fail", "pre_*_use"));
    }

    #[test]
    fn test_matches_pattern_star_matches_all() {
        assert!(matches_pattern("anything", "*"));
        assert!(matches_pattern("", "*"));
    }

    #[test]
    fn test_matches_pattern_pipe_with_wildcard() {
        assert!(matches_pattern("read_file", "bash|read_*"));
        assert!(matches_pattern("bash", "bash|read_*"));
        assert!(!matches_pattern("write_file", "bash|read_*"));
    }

    // ---- HookRegistry ----

    // Test double: never changes anything, but exposes configurable
    // name / events / matcher so registry filtering can be exercised.
    struct PassthroughHook {
        hook_name: String,
        hook_events: Vec<HookEvent>,
        hook_matcher: Option<String>,
    }

    impl PassthroughHook {
        fn new(name: &str) -> Self {
            Self {
                hook_name: name.into(),
                hook_events: vec![],
                hook_matcher: None,
            }
        }

        fn with_events(mut self, events: Vec<HookEvent>) -> Self {
            self.hook_events = events;
            self
        }

        fn with_matcher(mut self, matcher: &str) -> Self {
            self.hook_matcher = Some(matcher.into());
            self
        }
    }

    #[async_trait]
    impl Hook for PassthroughHook {
        fn name(&self) -> &str {
            &self.hook_name
        }

        fn events(&self) -> &[HookEvent] {
            &self.hook_events
        }

        fn matcher(&self) -> Option<&str> {
            self.hook_matcher.as_deref()
        }

        async fn on_event(&self, _input: &HookInput) -> HookOutput {
            HookOutput::passthrough()
        }
    }

    #[test]
    fn test_registry_new_empty() {
        let reg = HookRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn test_registry_register_and_len() {
        let mut reg = HookRegistry::new();
        reg.register(
            Arc::new(PassthroughHook::new("a")),
            HookSource::Programmatic,
            0,
        );
        reg.register(
            Arc::new(PassthroughHook::new("b")),
            HookSource::Programmatic,
            0,
        );
        assert_eq!(reg.len(), 2);
    }

    #[test]
    fn test_registry_remove() {
        let mut reg = HookRegistry::new();
        reg.register(
            Arc::new(PassthroughHook::new("a")),
            HookSource::Programmatic,
            0,
        );
        reg.register(
            Arc::new(PassthroughHook::new("b")),
            HookSource::Programmatic,
            0,
        );
        reg.remove("a");
        assert_eq!(reg.len(), 1);
        assert_eq!(reg.hooks[0].hook.name(), "b");
    }

    // An empty event list means "all events"; specific lists filter.
    #[test]
    fn test_registry_matching_by_event() {
        let mut reg = HookRegistry::new();
        reg.register(
            Arc::new(PassthroughHook::new("pre_only").with_events(vec![HookEvent::PreToolUse])),
            HookSource::Programmatic,
            0,
        );
        reg.register(
            Arc::new(PassthroughHook::new("post_only").with_events(vec![HookEvent::PostToolUse])),
            HookSource::Programmatic,
            0,
        );
        reg.register(
            Arc::new(PassthroughHook::new("all_events")),
            HookSource::Programmatic,
            0,
        );

        let input = HookInput::PreToolUse {
            tool_name: "bash".into(),
            tool_input: json!({}),
            call_id: ToolCallId::new("c1"),
        };
        let matched = reg.matching(&input);
        assert_eq!(matched.len(), 2);

        let names: Vec<&str> = matched.iter().map(|h| h.hook.name()).collect();
        assert!(names.contains(&"pre_only"));
        assert!(names.contains(&"all_events"));
        assert!(!names.contains(&"post_only"));
    }

    #[test]
    fn test_registry_matching_by_matcher() {
        let mut reg = HookRegistry::new();
        reg.register(
            Arc::new(
                PassthroughHook::new("bash_only")
                    .with_events(vec![HookEvent::PreToolUse])
                    .with_matcher("bash"),
            ),
            HookSource::Programmatic,
            0,
        );
        reg.register(
            Arc::new(
                PassthroughHook::new("write_family")
                    .with_events(vec![HookEvent::PreToolUse])
                    .with_matcher("write_*"),
            ),
            HookSource::Programmatic,
            0,
        );

        let input_bash = HookInput::PreToolUse {
            // Exact matcher "bash" only.
            tool_name: "bash".into(),
            tool_input: json!({}),
            call_id: ToolCallId::new("c1"),
        };
        let matched = reg.matching(&input_bash);
        assert_eq!(matched.len(), 1);
        assert_eq!(matched[0].hook.name(), "bash_only");

        let input_write = HookInput::PreToolUse {
            // Wildcard matcher "write_*" only.
            tool_name: "write_file".into(),
            tool_input: json!({}),
            call_id: ToolCallId::new("c2"),
        };
        let matched = reg.matching(&input_write);
        assert_eq!(matched.len(), 1);
        assert_eq!(matched[0].hook.name(), "write_family");

        let input_read = HookInput::PreToolUse {
            // Neither matcher applies.
            tool_name: "read_file".into(),
            tool_input: json!({}),
            call_id: ToolCallId::new("c3"),
        };
        let matched = reg.matching(&input_read);
        assert!(matched.is_empty());
    }

    #[test]
    fn test_registry_matching_non_tool_event_with_matcher() {
        let mut reg = HookRegistry::new();
        reg.register(
            // A matcher can never match an input that has no tool name.
            Arc::new(
                PassthroughHook::new("h")
                    .with_events(vec![HookEvent::SessionStart])
                    .with_matcher("bash"),
            ),
            HookSource::Programmatic,
            0,
        );

        let input = HookInput::SessionStart {
            session_id: SessionId::new(),
        };
        let matched = reg.matching(&input);
        assert!(matched.is_empty());
    }

    // Lower priority values come first regardless of registration order.
    #[test]
    fn test_registry_priority_order() {
        let mut reg = HookRegistry::new();
        reg.register(
            Arc::new(PassthroughHook::new("low")),
            HookSource::Programmatic,
            10,
        );
        reg.register(
            Arc::new(PassthroughHook::new("high")),
            HookSource::Programmatic,
            -10,
        );
        reg.register(
            Arc::new(PassthroughHook::new("mid")),
            HookSource::Programmatic,
            0,
        );

        let input = HookInput::SessionStart {
            session_id: SessionId::new(),
        };
        let matched = reg.matching(&input);
        assert_eq!(matched[0].hook.name(), "high");
        assert_eq!(matched[1].hook.name(), "mid");
        assert_eq!(matched[2].hook.name(), "low");
    }

    #[test]
    fn test_registry_has_hooks_for() {
        let mut reg = HookRegistry::new();
        reg.register(
            Arc::new(PassthroughHook::new("pre_only").with_events(vec![HookEvent::PreToolUse])),
            HookSource::Programmatic,
            0,
        );

        assert!(reg.has_hooks_for(HookEvent::PreToolUse));
        assert!(!reg.has_hooks_for(HookEvent::PostToolUse));
    }

    #[test]
    fn test_registry_has_hooks_for_all_events() {
        let mut reg = HookRegistry::new();
        reg.register(
            // No events listed: subscribed to every event.
            Arc::new(PassthroughHook::new("global")),
            HookSource::Programmatic,
            0,
        );

        for event in ALL_HOOK_EVENTS {
            assert!(reg.has_hooks_for(*event));
        }
    }

    // ---- HookSource ----

    #[test]
    fn test_hook_source_serde_roundtrip() {
        for source in [
            HookSource::Settings,
            HookSource::Project,
            HookSource::Plugin {
                name: "linter".into(),
            },
            HookSource::Programmatic,
            HookSource::Session,
        ] {
            let json_str = serde_json::to_string(&source).unwrap();
            let restored: HookSource = serde_json::from_str(&json_str).unwrap();
            assert_eq!(source, restored);
        }
    }

    // ---- Hook trait ----

    // End-to-end check of an async hook that inspects tool input and denies.
    #[tokio::test]
    async fn test_hook_trait_async_execution() {
        struct DenyBashHook;

        #[async_trait]
        impl Hook for DenyBashHook {
            fn name(&self) -> &str {
                "deny_bash"
            }

            fn events(&self) -> &[HookEvent] {
                &[HookEvent::PreToolUse]
            }

            fn matcher(&self) -> Option<&str> {
                Some("bash")
            }

            async fn on_event(&self, input: &HookInput) -> HookOutput {
                if let HookInput::PreToolUse { tool_input, .. } = input {
                    let cmd = tool_input["command"].as_str().unwrap_or("");
                    if cmd.contains("rm -rf") {
                        return HookOutput::deny("dangerous command");
                    }
                }
                HookOutput::passthrough()
            }
        }

        let hook: Arc<dyn Hook> = Arc::new(DenyBashHook);

        let safe_input = HookInput::PreToolUse {
            // Harmless command: no decision expected.
            tool_name: "bash".into(),
            tool_input: json!({"command": "ls -la"}),
            call_id: ToolCallId::new("c1"),
        };
        let output = hook.on_event(&safe_input).await;
        assert!(!output.has_decision());

        let dangerous_input = HookInput::PreToolUse {
            // Contains "rm -rf": must be denied.
            tool_name: "bash".into(),
            tool_input: json!({"command": "rm -rf /"}),
            call_id: ToolCallId::new("c2"),
        };
        let output = hook.on_event(&dangerous_input).await;
        assert!(output.permission.as_ref().unwrap().is_deny());
    }

    #[tokio::test]
    async fn test_hook_trait_dyn_dispatch() {
        let hook: Arc<dyn Hook> = Arc::new(PassthroughHook::new("test"));
        assert_eq!(hook.name(), "test");

        let input = HookInput::SessionStart {
            session_id: SessionId::new(),
        };
        let output = hook.on_event(&input).await;
        assert!(!output.has_decision());
    }
}