1use regex::Regex;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt;
7
/// A point in the session lifecycle at which configured hook commands can run.
///
/// The variant name is also the string key used in [`HookConfig::hooks`]; see
/// [`HookEvent::as_str`] and [`HookEvent::from_config_str`] for the mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum HookEvent {
    SessionStart,
    UserPromptSubmit,
    // The three tool events below are the ones `is_tool_event` reports as true.
    PreToolUse,
    PostToolUse,
    PostToolUseFailure,
    // The two subagent events below are the ones `is_subagent_event` reports as true.
    SubagentStart,
    SubagentStop,
    Stop,
    PreCompact,
    SessionEnd,
}
32
33impl HookEvent {
34 pub fn as_str(&self) -> &'static str {
36 match self {
37 Self::SessionStart => "SessionStart",
38 Self::UserPromptSubmit => "UserPromptSubmit",
39 Self::PreToolUse => "PreToolUse",
40 Self::PostToolUse => "PostToolUse",
41 Self::PostToolUseFailure => "PostToolUseFailure",
42 Self::SubagentStart => "SubagentStart",
43 Self::SubagentStop => "SubagentStop",
44 Self::Stop => "Stop",
45 Self::PreCompact => "PreCompact",
46 Self::SessionEnd => "SessionEnd",
47 }
48 }
49
50 pub fn from_config_str(s: &str) -> Option<Self> {
52 match s {
53 "SessionStart" => Some(Self::SessionStart),
54 "UserPromptSubmit" => Some(Self::UserPromptSubmit),
55 "PreToolUse" => Some(Self::PreToolUse),
56 "PostToolUse" => Some(Self::PostToolUse),
57 "PostToolUseFailure" => Some(Self::PostToolUseFailure),
58 "SubagentStart" => Some(Self::SubagentStart),
59 "SubagentStop" => Some(Self::SubagentStop),
60 "Stop" => Some(Self::Stop),
61 "PreCompact" => Some(Self::PreCompact),
62 "SessionEnd" => Some(Self::SessionEnd),
63 _ => None,
64 }
65 }
66
67 pub fn is_tool_event(&self) -> bool {
69 matches!(
70 self,
71 Self::PreToolUse | Self::PostToolUse | Self::PostToolUseFailure
72 )
73 }
74
75 pub fn is_subagent_event(&self) -> bool {
77 matches!(self, Self::SubagentStart | Self::SubagentStop)
78 }
79
80 pub const ALL: &'static [HookEvent] = &[
82 Self::SessionStart,
83 Self::UserPromptSubmit,
84 Self::PreToolUse,
85 Self::PostToolUse,
86 Self::PostToolUseFailure,
87 Self::SubagentStart,
88 Self::SubagentStop,
89 Self::Stop,
90 Self::PreCompact,
91 Self::SessionEnd,
92 ];
93}
94
95impl fmt::Display for HookEvent {
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 f.write_str(self.as_str())
98 }
99}
100
/// A single command to execute when a hook fires.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookCommand {
    /// Kind of hook; defaults to `"command"` when absent from the config.
    /// Spelled `r#type` because `type` is a Rust keyword.
    #[serde(default = "default_command_type")]
    pub r#type: String,

    /// The command text to run.
    pub command: String,

    /// Timeout as written in the config (default 60). It is stored unclamped
    /// when deserialized; read it through `effective_timeout`, which clamps it
    /// into 1..=600. NOTE(review): unit presumably seconds — confirm against
    /// the executor.
    #[serde(default = "default_timeout")]
    pub timeout: u32,
}
115
/// Serde default for `HookCommand::r#type` when the field is omitted.
fn default_command_type() -> String {
    String::from("command")
}
119
/// Serde default for `HookCommand::timeout` when the field is omitted.
fn default_timeout() -> u32 {
    const DEFAULT_TIMEOUT: u32 = 60;
    DEFAULT_TIMEOUT
}
123
124impl HookCommand {
125 pub fn new(command: impl Into<String>) -> Self {
127 Self {
128 r#type: "command".to_string(),
129 command: command.into(),
130 timeout: 60,
131 }
132 }
133
134 pub fn with_timeout(command: impl Into<String>, timeout: u32) -> Self {
136 Self {
137 r#type: "command".to_string(),
138 command: command.into(),
139 timeout: timeout.clamp(1, 600),
140 }
141 }
142
143 pub fn effective_timeout(&self) -> u32 {
145 self.timeout.clamp(1, 600)
146 }
147}
148
/// A group of hook commands guarded by an optional match pattern.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookMatcher {
    /// Pattern tested against the event's match value; `None` matches everything.
    /// Interpreted as a regex when it compiles, otherwise compared literally.
    pub matcher: Option<String>,

    /// Commands to run when this matcher matches.
    pub hooks: Vec<HookCommand>,

    /// Cached compiled form of `matcher`. Skipped by serde, so it is `None`
    /// after deserialization until `compile` is called.
    #[serde(skip)]
    compiled_regex: Option<CompiledRegex>,
}
166
/// Newtype around a compiled [`Regex`] whose `Debug` output shows only the
/// source pattern (`Regex(<pattern>)`).
#[derive(Clone)]
struct CompiledRegex(Regex);
170
171impl fmt::Debug for CompiledRegex {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 write!(f, "Regex({})", self.0.as_str())
174 }
175}
176
177impl HookMatcher {
178 pub fn catch_all(hooks: Vec<HookCommand>) -> Self {
180 Self {
181 matcher: None,
182 hooks,
183 compiled_regex: None,
184 }
185 }
186
187 pub fn with_pattern(pattern: impl Into<String>, hooks: Vec<HookCommand>) -> Self {
189 let pattern = pattern.into();
190 let compiled = Regex::new(&pattern).ok().map(CompiledRegex);
191 Self {
192 matcher: Some(pattern),
193 hooks,
194 compiled_regex: compiled,
195 }
196 }
197
198 pub fn compile(&mut self) {
202 if let Some(ref pattern) = self.matcher {
203 self.compiled_regex = Regex::new(pattern).ok().map(CompiledRegex);
204 }
205 }
206
207 pub fn matches(&self, value: Option<&str>) -> bool {
213 let pattern = match &self.matcher {
214 None => return true,
215 Some(p) => p,
216 };
217
218 let value = match value {
219 None => return true,
220 Some(v) => v,
221 };
222
223 match &self.compiled_regex {
224 Some(compiled) => compiled.0.is_match(value),
225 None => pattern == value,
226 }
227 }
228}
229
/// Top-level hook configuration: matchers grouped by event name.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HookConfig {
    /// Keys are event names as produced by `HookEvent::as_str`. Unrecognised
    /// keys are kept as-is until `strip_unknown_events` is called. Defaults
    /// to an empty map when absent.
    #[serde(default)]
    pub hooks: HashMap<String, Vec<HookMatcher>>,
}
240
241impl HookConfig {
242 pub fn empty() -> Self {
244 Self {
245 hooks: HashMap::new(),
246 }
247 }
248
249 pub fn compile_all(&mut self) {
253 for matchers in self.hooks.values_mut() {
254 for matcher in matchers.iter_mut() {
255 matcher.compile();
256 }
257 }
258 }
259
260 pub fn get_matchers(&self, event: HookEvent) -> &[HookMatcher] {
262 self.hooks
263 .get(event.as_str())
264 .map(|v| v.as_slice())
265 .unwrap_or(&[])
266 }
267
268 pub fn has_hooks_for(&self, event: HookEvent) -> bool {
270 self.hooks
271 .get(event.as_str())
272 .map(|v| !v.is_empty())
273 .unwrap_or(false)
274 }
275
276 pub fn strip_unknown_events(&mut self) {
279 let valid: std::collections::HashSet<&str> =
280 HookEvent::ALL.iter().map(|e| e.as_str()).collect();
281 self.hooks.retain(|key, _| valid.contains(key.as_str()));
282 }
283
284 pub fn add_matcher(&mut self, event: HookEvent, matcher: HookMatcher) {
286 self.hooks
287 .entry(event.as_str().to_string())
288 .or_default()
289 .push(matcher);
290 }
291}
292
// Unit tests for these models are kept out of line, loaded from
// `models_tests.rs` via the `#[path]` attribute, and compiled only for tests.
#[cfg(test)]
#[path = "models_tests.rs"]
mod tests;