1use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::process::Command;
5use tokio::time::{timeout, Duration};
6
/// Every event name a hook can be registered for.
///
/// `HookRegistry::register` and `HookRegistry::register_from_config` only accept names
/// from this list. The entries use the same spelling as `HookEvent::as_str`, so keep the
/// two in sync when adding an event.
pub const HOOK_EVENTS: &[&str] = &[
    "PreToolUse",
    "PostToolUse",
    "PostToolUseFailure",
    "Notification",
    "UserPromptSubmit",
    "SessionStart",
    "SessionEnd",
    "Stop",
    "StopFailure",
    "SubagentStart",
    "SubagentStop",
    "PreCompact",
    "PostCompact",
    "PermissionRequest",
    "PermissionDenied",
    "Setup",
    "TeammateIdle",
    "TaskCreated",
    "TaskCompleted",
    "Elicitation",
    "ElicitationResult",
    "ConfigChange",
    "WorktreeCreate",
    "WorktreeRemove",
    "InstructionsLoaded",
    "CwdChanged",
    "FileChanged",
];
37
/// Known session exit reasons.
///
/// NOTE(review): not referenced in this module. Presumably these are the values reported
/// alongside `SessionEnd`; confirm against the callers that build that payload.
pub const EXIT_REASONS: &[&str] = &[
    "clear",
    "resume",
    "logout",
    "prompt_input_exit",
    "other",
    "bypass_permissions_disabled",
];
47
/// Known reasons for an instructions file being loaded.
///
/// NOTE(review): not referenced in this module. Presumably these accompany the
/// `InstructionsLoaded` event; confirm against its producer.
pub const INSTRUCTIONS_LOAD_REASONS: &[&str] = &[
    "session_start",
    "nested_traversal",
    "path_glob_match",
    "include",
    "compact",
];
56
/// Known scopes ("memory types") of an instructions file.
///
/// NOTE(review): not referenced in this module. Presumably reported with `InstructionsLoaded`;
/// confirm against its producer.
pub const INSTRUCTIONS_MEMORY_TYPES: &[&str] = &["User", "Project", "Local", "Managed"];
59
/// Known origins of a configuration change.
///
/// NOTE(review): not referenced in this module. Presumably these accompany the
/// `ConfigChange` event; confirm against its producer.
pub const CONFIG_CHANGE_SOURCES: &[&str] = &[
    "user_settings",
    "project_settings",
    "local_settings",
    "policy_settings",
    "skills",
];
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
70#[serde(rename_all = "camelCase")]
71pub enum HookEvent {
72 PreToolUse,
73 PostToolUse,
74 PostToolUseFailure,
75 Notification,
76 UserPromptSubmit,
77 SessionStart,
78 SessionEnd,
79 Stop,
80 StopFailure,
81 SubagentStart,
82 SubagentStop,
83 PreCompact,
84 PostCompact,
85 PermissionRequest,
86 PermissionDenied,
87 Setup,
88 TeammateIdle,
89 TaskCreated,
90 TaskCompleted,
91 Elicitation,
92 ElicitationResult,
93 ConfigChange,
94 WorktreeCreate,
95 WorktreeRemove,
96 InstructionsLoaded,
97 CwdChanged,
98 FileChanged,
99}
100
101impl HookEvent {
102 pub fn as_str(&self) -> &'static str {
103 match self {
104 HookEvent::PreToolUse => "PreToolUse",
105 HookEvent::PostToolUse => "PostToolUse",
106 HookEvent::PostToolUseFailure => "PostToolUseFailure",
107 HookEvent::Notification => "Notification",
108 HookEvent::UserPromptSubmit => "UserPromptSubmit",
109 HookEvent::SessionStart => "SessionStart",
110 HookEvent::SessionEnd => "SessionEnd",
111 HookEvent::Stop => "Stop",
112 HookEvent::StopFailure => "StopFailure",
113 HookEvent::SubagentStart => "SubagentStart",
114 HookEvent::SubagentStop => "SubagentStop",
115 HookEvent::PreCompact => "PreCompact",
116 HookEvent::PostCompact => "PostCompact",
117 HookEvent::PermissionRequest => "PermissionRequest",
118 HookEvent::PermissionDenied => "PermissionDenied",
119 HookEvent::Setup => "Setup",
120 HookEvent::TeammateIdle => "TeammateIdle",
121 HookEvent::TaskCreated => "TaskCreated",
122 HookEvent::TaskCompleted => "TaskCompleted",
123 HookEvent::Elicitation => "Elicitation",
124 HookEvent::ElicitationResult => "ElicitationResult",
125 HookEvent::ConfigChange => "ConfigChange",
126 HookEvent::WorktreeCreate => "WorktreeCreate",
127 HookEvent::WorktreeRemove => "WorktreeRemove",
128 HookEvent::InstructionsLoaded => "InstructionsLoaded",
129 HookEvent::CwdChanged => "CwdChanged",
130 HookEvent::FileChanged => "FileChanged",
131 }
132 }
133
134 pub fn from_str(s: &str) -> Option<Self> {
135 match s {
136 "PreToolUse" => Some(HookEvent::PreToolUse),
137 "PostToolUse" => Some(HookEvent::PostToolUse),
138 "PostToolUseFailure" => Some(HookEvent::PostToolUseFailure),
139 "Notification" => Some(HookEvent::Notification),
140 "UserPromptSubmit" => Some(HookEvent::UserPromptSubmit),
141 "SessionStart" => Some(HookEvent::SessionStart),
142 "SessionEnd" => Some(HookEvent::SessionEnd),
143 "Stop" => Some(HookEvent::Stop),
144 "StopFailure" => Some(HookEvent::StopFailure),
145 "SubagentStart" => Some(HookEvent::SubagentStart),
146 "SubagentStop" => Some(HookEvent::SubagentStop),
147 "PreCompact" => Some(HookEvent::PreCompact),
148 "PostCompact" => Some(HookEvent::PostCompact),
149 "PermissionRequest" => Some(HookEvent::PermissionRequest),
150 "PermissionDenied" => Some(HookEvent::PermissionDenied),
151 "Setup" => Some(HookEvent::Setup),
152 "TeammateIdle" => Some(HookEvent::TeammateIdle),
153 "TaskCreated" => Some(HookEvent::TaskCreated),
154 "TaskCompleted" => Some(HookEvent::TaskCompleted),
155 "Elicitation" => Some(HookEvent::Elicitation),
156 "ElicitationResult" => Some(HookEvent::ElicitationResult),
157 "ConfigChange" => Some(HookEvent::ConfigChange),
158 "WorktreeCreate" => Some(HookEvent::WorktreeCreate),
159 "WorktreeRemove" => Some(HookEvent::WorktreeRemove),
160 "InstructionsLoaded" => Some(HookEvent::InstructionsLoaded),
161 "CwdChanged" => Some(HookEvent::CwdChanged),
162 "FileChanged" => Some(HookEvent::FileChanged),
163 _ => None,
164 }
165 }
166}
167
/// One configured hook: a shell command plus optional execution settings.
///
/// When read from configuration, the custom `Deserialize` impl fills in a default timeout.
/// A value built directly in code keeps whatever `timeout` it is given.
#[derive(Clone, Debug)]
pub struct HookDefinition {
    /// Shell command to run (passed to `bash -c`). Hooks without one are skipped.
    pub command: Option<String>,
    /// Maximum run time in milliseconds.
    pub timeout: Option<u64>,
    /// Regex tested against the tool name to decide whether the hook applies.
    pub matcher: Option<String>,
}
178
179impl<'de> Deserialize<'de> for HookDefinition {
180 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
181 where
182 D: serde::Deserializer<'de>,
183 {
184 #[derive(Deserialize)]
185 #[serde(rename_all = "camelCase")]
186 struct HookDef {
187 command: Option<String>,
188 timeout: Option<u64>,
189 matcher: Option<String>,
190 }
191
192 let def = HookDef::deserialize(deserializer)?;
193 Ok(HookDefinition {
194 command: def.command,
195 timeout: def.timeout.or(Some(30000)),
196 matcher: def.matcher,
197 })
198 }
199}
200
/// Payload handed to a hook command.
///
/// Serialized as JSON with camelCase keys (e.g. `toolName`) and written to the hook's stdin.
/// Fields that are `None` are omitted from the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HookInput {
    /// Event name. `HookRegistry::execute` overwrites it with the event being dispatched.
    /// Also exported to the hook process as `HOOK_EVENT`.
    pub event: String,
    /// Name of the tool involved, if any. Tested against a hook's `matcher` and exported as
    /// `HOOK_TOOL_NAME` (empty string when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// Tool arguments as arbitrary JSON (presumably the tool call's input; confirm with callers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_input: Option<serde_json::Value>,
    /// Tool result as arbitrary JSON (presumably set only after the tool ran; confirm with callers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_output: Option<serde_json::Value>,
    /// Identifier of the tool invocation. Not interpreted by this module.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_use_id: Option<String>,
    /// Session identifier, exported as `HOOK_SESSION_ID` (empty string when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Working directory, exported as `HOOK_CWD` (empty string when `None`).
    /// NOTE(review): the hook process is not started in this directory; it inherits the
    /// agent's current directory.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cwd: Option<String>,
    /// Error description (presumably for failure events such as `PostToolUseFailure`; confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
221
222impl HookInput {
223 pub fn new(event: &str) -> Self {
224 Self {
225 event: event.to_string(),
226 tool_name: None,
227 tool_input: None,
228 tool_output: None,
229 tool_use_id: None,
230 session_id: None,
231 cwd: None,
232 error: None,
233 }
234 }
235}
236
/// Response a hook may print to stdout.
///
/// Stdout that parses as this type (a JSON object with camelCase keys) is used as-is. Any
/// other non-empty stdout is wrapped as `message`. This module only collects these values;
/// acting on them is up to the caller of `HookRegistry::execute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HookOutput {
    /// Free-form text from the hook (also used for non-JSON stdout).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
    /// Permission change requested by the hook.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub permission_update: Option<PermissionUpdate>,
    /// Hook's request to block (presumably the triggering action; the consumer decides).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub block: Option<bool>,
    /// Notification the hook asks to be shown.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub notification: Option<Notification>,
}
250
/// Permission change a hook requests through `HookOutput::permission_update`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PermissionUpdate {
    /// Tool the update applies to.
    pub tool: String,
    /// Requested behavior.
    /// NOTE(review): values are not validated here; presumably e.g. "allow"/"deny". Confirm
    /// with the consumer.
    pub behavior: String,
}
257
/// Notification a hook requests through `HookOutput::notification`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Notification {
    /// Notification title.
    pub title: String,
    /// Notification text.
    pub body: String,
    /// Optional severity label (free-form string; accepted values are not checked here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub level: Option<String>,
}
266
/// Hook configuration: event name → hooks to run for that event.
///
/// Accepted by `HookRegistry::register_from_config`. Keys that are not in `HOOK_EVENTS` are
/// skipped on registration.
pub type HookConfig = HashMap<String, Vec<HookDefinition>>;
269
270#[derive(Debug, Default, Clone)]
272pub struct HookRegistry {
273 hooks: HashMap<String, Vec<HookDefinition>>,
274}
275
276impl HookRegistry {
277 pub fn new() -> Self {
279 Self {
280 hooks: HashMap::new(),
281 }
282 }
283
284 pub fn register_from_config(&mut self, config: HookConfig) {
286 for (event, definitions) in config {
287 if !HOOK_EVENTS.contains(&event.as_str()) {
288 continue;
289 }
290 let existing = self.hooks.entry(event).or_insert_with(Vec::new);
291 existing.extend(definitions);
292 }
293 }
294
295 pub fn register(&mut self, event: &str, definition: HookDefinition) {
297 if !HOOK_EVENTS.contains(&event) {
298 return;
299 }
300 let existing = self.hooks.entry(event.to_string()).or_insert_with(Vec::new);
301 existing.push(definition);
302 }
303
304 pub async fn execute(&self, event: &str, mut input: HookInput) -> Vec<HookOutput> {
306 let definitions = match self.hooks.get(event) {
307 Some(d) => d,
308 None => return vec![],
309 };
310
311 input.event = event.to_string();
312 let mut results = Vec::new();
313
314 for def in definitions {
315 if let Some(matcher) = &def.matcher {
317 if let Some(tool_name) = &input.tool_name {
318 if let Ok(re) = regex::Regex::new(matcher) {
319 if !re.is_match(tool_name) {
320 continue;
321 }
322 }
323 }
324 }
325
326 if let Some(command) = &def.command {
327 match execute_shell_hook(command, &input, def.timeout.unwrap_or(30000)).await {
328 Ok(output) => {
329 if let Some(o) = output {
330 results.push(o);
331 }
332 }
333 Err(e) => {
334 eprintln!("[Hook] {} hook failed: {}", event, e);
335 }
336 }
337 }
338 }
341
342 results
343 }
344
345 pub fn has_hooks(&self, event: &str) -> bool {
347 self.hooks
348 .get(event)
349 .map(|h| !h.is_empty())
350 .unwrap_or(false)
351 }
352
353 pub fn clear(&mut self) {
355 self.hooks.clear();
356 }
357}
358
359async fn execute_shell_hook(
361 command: &str,
362 input: &HookInput,
363 timeout_ms: u64,
364) -> Result<Option<HookOutput>, crate::error::AgentError> {
365 let input_json = serde_json::to_string(input).map_err(crate::error::AgentError::Json)?;
366
367 let cmd_str = command.to_string();
369 let event = input.event.clone();
370 let tool_name = input.tool_name.clone();
371 let session_id = input.session_id.clone();
372 let cwd = input.cwd.clone();
373
374 let result = timeout(
375 Duration::from_millis(timeout_ms),
376 tokio::task::spawn_blocking(move || {
377 let mut cmd = Command::new("bash");
378 cmd.args(["-c", &cmd_str])
379 .env("HOOK_EVENT", &event)
380 .env("HOOK_TOOL_NAME", tool_name.as_deref().unwrap_or(""))
381 .env("HOOK_SESSION_ID", session_id.as_deref().unwrap_or(""))
382 .env("HOOK_CWD", cwd.as_deref().unwrap_or(""))
383 .stdin(std::process::Stdio::piped())
384 .stdout(std::process::Stdio::piped())
385 .stderr(std::process::Stdio::piped());
386
387 let mut child = cmd.spawn()?;
388
389 use std::io::Write;
390 if let Some(mut stdin) = child.stdin.take() {
391 stdin.write_all(input_json.as_bytes())?;
392 }
393
394 let output = child.wait_with_output()?;
395
396 if !output.status.success() {
397 return Ok(None);
398 }
399
400 let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
401 if stdout.is_empty() {
402 return Ok(None);
403 }
404
405 if let Ok(hook_output) = serde_json::from_str::<HookOutput>(&stdout) {
407 Ok(Some(hook_output))
408 } else {
409 Ok(Some(HookOutput {
411 message: Some(stdout),
412 permission_update: None,
413 block: None,
414 notification: None,
415 }))
416 }
417 }),
418 )
419 .await;
420
421 match result {
422 Ok(Ok(r)) => r,
423 Ok(Err(e)) => {
424 let err = std::io::Error::new(std::io::ErrorKind::Other, e.to_string());
425 Err(crate::error::AgentError::Io(err))
426 }
427 Err(_) => {
428 let err = std::io::Error::new(std::io::ErrorKind::TimedOut, "Hook timeout");
429 Err(crate::error::AgentError::Io(err))
430 }
431 }
432}
433
434pub fn create_hook_registry(config: Option<HookConfig>) -> HookRegistry {
436 let mut registry = HookRegistry::new();
437 if let Some(c) = config {
438 registry.register_from_config(c);
439 }
440 registry
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a command hook with no matcher and no explicit timeout.
    fn command_hook(command: &str) -> HookDefinition {
        HookDefinition {
            command: Some(command.to_string()),
            timeout: None,
            matcher: None,
        }
    }

    #[test]
    fn test_hook_event_as_str() {
        assert_eq!(HookEvent::PreToolUse.as_str(), "PreToolUse");
        assert_eq!(HookEvent::PostToolUse.as_str(), "PostToolUse");
        assert_eq!(HookEvent::SessionStart.as_str(), "SessionStart");
    }

    #[test]
    fn test_hook_event_from_str() {
        assert_eq!(
            HookEvent::from_str("PreToolUse"),
            Some(HookEvent::PreToolUse)
        );
        assert_eq!(HookEvent::from_str("Invalid"), None);
    }

    #[test]
    fn test_hook_event_names_round_trip() {
        // Every configurable event name must map to a variant that spells it the same way.
        for name in HOOK_EVENTS {
            let event = HookEvent::from_str(name).expect("every HOOK_EVENTS entry parses");
            assert_eq!(event.as_str(), *name);
        }
    }

    #[test]
    fn test_hook_events_constant() {
        assert!(HOOK_EVENTS.contains(&"PreToolUse"));
        assert!(HOOK_EVENTS.contains(&"PostToolUse"));
        assert!(HOOK_EVENTS.contains(&"SessionStart"));
    }

    #[test]
    fn test_hook_registry_new() {
        let registry = HookRegistry::new();
        assert!(!registry.has_hooks("PreToolUse"));
    }

    #[test]
    fn test_hook_registry_register() {
        let mut registry = HookRegistry::new();
        registry.register(
            "PreToolUse",
            HookDefinition {
                command: Some("echo test".to_string()),
                timeout: Some(5000),
                matcher: Some("Read.*".to_string()),
            },
        );
        assert!(registry.has_hooks("PreToolUse"));
    }

    #[test]
    fn test_hook_registry_ignores_unknown_event() {
        let mut registry = HookRegistry::new();
        registry.register("NotAnEvent", command_hook("echo test"));
        assert!(!registry.has_hooks("NotAnEvent"));
    }

    #[test]
    fn test_register_from_config_filters_and_appends() {
        let mut config = HookConfig::new();
        config.insert("PreToolUse".to_string(), vec![command_hook("echo a")]);
        config.insert("Bogus".to_string(), vec![command_hook("echo b")]);

        let mut registry = create_hook_registry(Some(config));
        registry.register("PreToolUse", command_hook("echo c"));

        assert!(!registry.has_hooks("Bogus"));
        assert_eq!(registry.hooks.get("PreToolUse").map(Vec::len), Some(2));
    }

    #[test]
    fn test_hook_registry_clear() {
        let mut registry = HookRegistry::new();
        registry.register(
            "PreToolUse",
            HookDefinition {
                command: Some("echo test".to_string()),
                timeout: None,
                matcher: None,
            },
        );
        registry.clear();
        assert!(!registry.has_hooks("PreToolUse"));
    }

    #[test]
    fn test_hook_definition_deserialize_defaults_timeout() {
        let def: HookDefinition = serde_json::from_str(r#"{"command":"echo hi"}"#).unwrap();
        assert_eq!(def.command.as_deref(), Some("echo hi"));
        assert_eq!(def.timeout, Some(30000));
        assert!(def.matcher.is_none());

        let def: HookDefinition =
            serde_json::from_str(r#"{"command":"x","timeout":500,"matcher":"Bash"}"#).unwrap();
        assert_eq!(def.timeout, Some(500));
        assert_eq!(def.matcher.as_deref(), Some("Bash"));
    }

    #[test]
    fn test_hook_input_new() {
        let input = HookInput::new("PreToolUse");
        assert_eq!(input.event, "PreToolUse");
    }

    #[test]
    fn test_hook_output_serialization() {
        let output = HookOutput {
            message: Some("test message".to_string()),
            permission_update: None,
            block: Some(true),
            notification: None,
        };
        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("test message"));
    }

    #[test]
    fn test_create_hook_registry() {
        let registry = create_hook_registry(None);
        assert!(!registry.has_hooks("PreToolUse"));
    }

    #[tokio::test]
    async fn test_execute_no_hooks() {
        let registry = HookRegistry::new();
        let input = HookInput::new("PreToolUse");
        let results = registry.execute("PreToolUse", input).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_execute_with_invalid_event() {
        let registry = HookRegistry::new();
        let input = HookInput::new("InvalidEvent");
        let results = registry.execute("InvalidEvent", input).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_execute_skips_non_matching_tool() {
        // The matcher rejects the tool, so no process is spawned at all.
        let mut registry = HookRegistry::new();
        registry.register(
            "PreToolUse",
            HookDefinition {
                command: Some("echo hi".to_string()),
                timeout: Some(1000),
                matcher: Some("^Read".to_string()),
            },
        );
        let mut input = HookInput::new("PreToolUse");
        input.tool_name = Some("Bash".to_string());
        assert!(registry.execute("PreToolUse", input).await.is_empty());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_execute_collects_json_and_text_output() {
        // Both hooks consume stdin first so the test does not depend on write/exit timing.
        let mut registry = HookRegistry::new();
        registry.register(
            "PreToolUse",
            HookDefinition {
                command: Some(r#"cat > /dev/null; echo '{"block":true}'"#.to_string()),
                timeout: Some(5000),
                matcher: None,
            },
        );
        registry.register(
            "PreToolUse",
            HookDefinition {
                command: Some("cat > /dev/null; echo plain text".to_string()),
                timeout: Some(5000),
                matcher: None,
            },
        );

        let results = registry
            .execute("PreToolUse", HookInput::new("PreToolUse"))
            .await;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].block, Some(true));
        assert_eq!(results[1].message.as_deref(), Some("plain text"));
    }
}