1use std::path::Path;
2use std::sync::Arc;
3
4use glob::Pattern;
5use imp_llm::{AssistantMessage, ContentBlock, Message, ToolResultMessage};
6use serde::{Deserialize, Serialize};
7use tokio::process::Command;
8
/// Failure notification for a hook that was run without blocking the caller.
///
/// Values are handed to the background reporter installed on the hook runner.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookBackgroundEvent {
    /// The hook command could not be run or exited unsuccessfully.
    NonBlockingHookFailed {
        /// Name of the event that triggered the hook.
        event: String,
        /// Shell command that was executed (after interpolation).
        command: String,
        /// Human-readable description of the failure.
        error: String,
    },
    /// The task running the hook could not be joined (for example, it panicked).
    NonBlockingHookPanicked {
        /// Name of the event that triggered the hook.
        event: String,
        /// Shell command that was executed (after interpolation).
        command: String,
        /// Text of the join error.
        error: String,
    },
}

impl std::fmt::Display for HookBackgroundEvent {
    /// Renders a one-line, user-facing description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HookBackgroundEvent::NonBlockingHookFailed {
                event,
                command,
                error,
            } => f.write_fmt(format_args!(
                "Non-blocking hook failed for event '{event}' while running `{command}`: {error}"
            )),
            HookBackgroundEvent::NonBlockingHookPanicked {
                event,
                command,
                error,
            } => f.write_fmt(format_args!(
                "Non-blocking hook panicked for event '{event}' while running `{command}`: {error}"
            )),
        }
    }
}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct HookDef {
50 pub event: String,
51 #[serde(rename = "match")]
52 pub match_pattern: Option<String>,
53 pub action: String,
54 pub command: Option<String>,
55 #[serde(default)]
56 pub blocking: bool,
57 pub threshold: Option<f64>,
58}
59
60#[derive(Clone)]
62pub enum HookAction {
63 Shell { command: String },
65 Callback(Arc<dyn Fn(&HookEvent<'_>) -> HookResult + Send + Sync>),
67}
68
69impl std::fmt::Debug for HookAction {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 match self {
72 HookAction::Shell { command } => {
73 f.debug_struct("Shell").field("command", command).finish()
74 }
75 HookAction::Callback(_) => f.write_str("Callback(...)"),
76 }
77 }
78}
79
80#[derive(Debug, Clone)]
82pub struct HookDefinition {
83 pub event: String,
84 pub match_pattern: Option<String>,
85 pub action: HookAction,
86 pub blocking: bool,
87 pub threshold: Option<f64>,
88}
89
90#[derive(Clone)]
92pub enum HookEvent<'a> {
93 AfterFileWrite {
94 file: &'a Path,
95 },
96 BeforeToolCall {
97 tool_name: &'a str,
98 args: &'a serde_json::Value,
99 },
100 AfterToolCall {
101 tool_name: &'a str,
102 result: &'a ToolResultMessage,
103 },
104 BeforeLlmCall,
105 OnContextThreshold {
106 ratio: f64,
107 },
108 OnSessionStart,
109 OnSessionShutdown,
110 OnAgentStart {
111 prompt: &'a str,
112 },
113 OnAgentEnd {
114 messages: &'a [Message],
115 },
116 OnTurnEnd {
117 index: u32,
118 message: &'a AssistantMessage,
119 },
120}
121
122impl<'a> HookEvent<'a> {
123 fn event_name(&self) -> &'static str {
125 match self {
126 HookEvent::AfterFileWrite { .. } => "after_file_write",
127 HookEvent::BeforeToolCall { .. } => "before_tool_call",
128 HookEvent::AfterToolCall { .. } => "after_tool_call",
129 HookEvent::BeforeLlmCall => "before_llm_call",
130 HookEvent::OnContextThreshold { .. } => "on_context_threshold",
131 HookEvent::OnSessionStart => "on_session_start",
132 HookEvent::OnSessionShutdown => "on_session_shutdown",
133 HookEvent::OnAgentStart { .. } => "on_agent_start",
134 HookEvent::OnAgentEnd { .. } => "on_agent_end",
135 HookEvent::OnTurnEnd { .. } => "on_turn_end",
136 }
137 }
138}
139
140#[derive(Default, Debug)]
142pub struct HookResult {
143 pub block: bool,
144 pub reason: Option<String>,
145 pub modified_content: Option<Vec<ContentBlock>>,
146}
147
148pub struct HookRunner {
150 toml_hooks: Vec<HookDefinition>,
152 programmatic_hooks: Vec<HookDefinition>,
154 background_reporter: Option<Arc<dyn Fn(HookBackgroundEvent) + Send + Sync>>,
156}
157
158impl HookRunner {
159 pub fn new() -> Self {
160 Self {
161 toml_hooks: Vec::new(),
162 programmatic_hooks: Vec::new(),
163 background_reporter: None,
164 }
165 }
166
167 pub fn add(&mut self, def: HookDef) {
169 if let Some(resolved) = resolve_hook_def(def) {
170 self.toml_hooks.push(resolved);
171 }
172 }
173
174 pub fn load_from_config(&mut self, defs: Vec<HookDef>) {
176 for def in defs {
177 self.add(def);
178 }
179 }
180
181 pub fn register(&mut self, hook: HookDefinition) {
183 self.programmatic_hooks.push(hook);
184 }
185
186 pub fn len(&self) -> usize {
188 self.toml_hooks.len() + self.programmatic_hooks.len()
189 }
190
191 pub fn is_empty(&self) -> bool {
193 self.toml_hooks.is_empty() && self.programmatic_hooks.is_empty()
194 }
195
196 pub fn set_background_reporter(
198 &mut self,
199 reporter: Arc<dyn Fn(HookBackgroundEvent) + Send + Sync>,
200 ) {
201 self.background_reporter = Some(reporter);
202 }
203
204 pub fn register_callback(
206 &mut self,
207 event: &str,
208 callback: Arc<dyn Fn(&HookEvent<'_>) -> HookResult + Send + Sync>,
209 ) {
210 self.programmatic_hooks.push(HookDefinition {
211 event: event.to_string(),
212 match_pattern: None,
213 action: HookAction::Callback(callback),
214 blocking: true,
215 threshold: None,
216 });
217 }
218
219 pub async fn fire(&self, event: &HookEvent<'_>) -> Vec<HookResult> {
225 let mut results = Vec::new();
226
227 let all_hooks = self.toml_hooks.iter().chain(self.programmatic_hooks.iter());
229
230 for hook in all_hooks {
231 if !matches_event(hook, event) {
232 continue;
233 }
234
235 if hook.blocking {
236 let result = execute_hook(hook, event).await;
237 results.push(result);
238 } else {
239 if let HookAction::Shell { command } = &hook.action {
241 let cmd = interpolate_command(command, event);
242 run_non_blocking_shell_hook(
243 hook_event_label(event),
244 cmd,
245 self.background_reporter.clone(),
246 );
247 }
248 }
250 }
251
252 results
253 }
254}
255
256impl Default for HookRunner {
257 fn default() -> Self {
258 Self::new()
259 }
260}
261
262fn hook_event_label(event: &HookEvent<'_>) -> String {
263 event.event_name().to_string()
264}
265
266fn report_non_blocking_hook_outcome(
267 join_result: Result<std::io::Result<std::process::Output>, tokio::task::JoinError>,
268 event_name: String,
269 command_for_report: String,
270 reporter: Arc<dyn Fn(HookBackgroundEvent) + Send + Sync>,
271) {
272 match join_result {
273 Ok(Ok(output)) => {
274 if !output.status.success() {
275 let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
276 let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
277 let error = if !stderr.is_empty() {
278 stderr
279 } else if !stdout.is_empty() {
280 stdout
281 } else {
282 format!(
283 "command exited with status {}",
284 output
285 .status
286 .code()
287 .map(|code| code.to_string())
288 .unwrap_or_else(|| "terminated by signal".into())
289 )
290 };
291 reporter(HookBackgroundEvent::NonBlockingHookFailed {
292 event: event_name,
293 command: command_for_report,
294 error,
295 });
296 }
297 }
298 Ok(Err(error)) => reporter(HookBackgroundEvent::NonBlockingHookFailed {
299 event: event_name,
300 command: command_for_report,
301 error: error.to_string(),
302 }),
303 Err(join_error) => reporter(HookBackgroundEvent::NonBlockingHookPanicked {
304 event: event_name,
305 command: command_for_report,
306 error: join_error.to_string(),
307 }),
308 }
309}
310
311fn run_non_blocking_shell_hook(
312 event_name: String,
313 command: String,
314 reporter: Option<Arc<dyn Fn(HookBackgroundEvent) + Send + Sync>>,
315) {
316 tokio::spawn(async move {
317 let command_for_run = command.clone();
318 let command_for_report = command;
319 let join_result = tokio::spawn(async move {
320 Command::new("sh")
321 .arg("-c")
322 .arg(&command_for_run)
323 .stdin(std::process::Stdio::null())
324 .output()
325 .await
326 })
327 .await;
328
329 if let Some(reporter) = reporter {
330 report_non_blocking_hook_outcome(join_result, event_name, command_for_report, reporter);
331 }
332 });
333}
334
335fn resolve_hook_def(def: HookDef) -> Option<HookDefinition> {
336 let action = match def.action.as_str() {
337 "shell" => {
338 let command = def.command?;
339 HookAction::Shell { command }
340 }
341 _ => return None,
342 };
343
344 Some(HookDefinition {
345 event: def.event,
346 match_pattern: def.match_pattern,
347 action,
348 blocking: def.blocking,
349 threshold: def.threshold,
350 })
351}
352
353fn matches_event(hook: &HookDefinition, event: &HookEvent<'_>) -> bool {
355 if hook.event != event.event_name() {
357 return false;
358 }
359
360 if let Some(pattern) = &hook.match_pattern {
362 match event {
363 HookEvent::AfterFileWrite { file } => {
364 let file_str = file.to_string_lossy();
365 if let Ok(glob) = Pattern::new(pattern) {
367 let file_name = file
368 .file_name()
369 .map(|n| n.to_string_lossy().to_string())
370 .unwrap_or_default();
371 if !glob.matches(&file_str) && !glob.matches(&file_name) {
372 return false;
373 }
374 } else {
375 return false;
376 }
377 }
378 HookEvent::BeforeToolCall { tool_name, .. }
379 | HookEvent::AfterToolCall { tool_name, .. } => {
380 if pattern != *tool_name {
381 if let Ok(glob) = Pattern::new(pattern) {
383 if !glob.matches(tool_name) {
384 return false;
385 }
386 } else {
387 return false;
388 }
389 }
390 }
391 _ => {
392 }
394 }
395 }
396
397 if let HookEvent::OnContextThreshold { ratio } = event {
399 if let Some(threshold) = hook.threshold {
400 if *ratio < threshold {
401 return false;
402 }
403 }
404 }
405
406 true
407}
408
409fn interpolate_command(command: &str, event: &HookEvent<'_>) -> String {
411 let mut result = command.to_string();
412
413 match event {
414 HookEvent::AfterFileWrite { file } => {
415 result = replace_placeholder(&result, "file", &file.to_string_lossy());
416 }
417 HookEvent::BeforeToolCall { tool_name, .. } => {
418 result = replace_placeholder(&result, "tool_name", tool_name);
419 }
420 HookEvent::AfterToolCall {
421 tool_name,
422 result: tool_result,
423 } => {
424 result = replace_placeholder(&result, "tool_name", tool_name);
425 result = replace_placeholder(
426 &result,
427 "is_error",
428 if tool_result.is_error {
429 "true"
430 } else {
431 "false"
432 },
433 );
434 let exit_code = tool_result
436 .details
437 .get("exit_code")
438 .and_then(|v| v.as_i64())
439 .map(|c| c.to_string())
440 .unwrap_or_default();
441 result = replace_placeholder(&result, "exit_code", &exit_code);
442 let output_first = tool_result
444 .content
445 .iter()
446 .filter_map(|b| match b {
447 imp_llm::ContentBlock::Text { text } => Some(text.as_str()),
448 _ => None,
449 })
450 .next()
451 .and_then(|t| t.lines().next())
452 .unwrap_or("");
453 result = replace_placeholder(&result, "output_first_line", output_first);
454 let command = tool_result
456 .details
457 .get("command")
458 .and_then(|v| v.as_str())
459 .unwrap_or("");
460 result = replace_placeholder(&result, "command", command);
461 }
462 HookEvent::OnContextThreshold { ratio } => {
463 result = replace_placeholder(&result, "ratio", &ratio.to_string());
464 }
465 HookEvent::OnTurnEnd { index, .. } => {
466 result = replace_placeholder(&result, "index", &index.to_string());
467 }
468 _ => {}
469 }
470
471 result
472}
473
474fn replace_placeholder(template: &str, name: &str, value: &str) -> String {
475 let raw = format!("{{{name}}}");
476 let single_marker = format!("\u{0}__imp_hook_single_{name}__\u{0}");
477 let double_marker = format!("\u{0}__imp_hook_double_{name}__\u{0}");
478
479 let mut result = template.replace(&format!("'{raw}'"), &single_marker);
480 result = result.replace(&format!("\"{raw}\""), &double_marker);
481 result = result.replace(&raw, value);
482 result = result.replace(&single_marker, &shell_single_quote(value));
483 result = result.replace(&double_marker, &shell_double_quote(value));
484 result
485}
486
/// Wraps `value` in single quotes for a POSIX shell.
///
/// Each embedded single quote is written as `'\''`: close the quoted string, emit an escaped
/// quote, and reopen the quoted string.
fn shell_single_quote(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('\'');
    for (position, piece) in value.split('\'').enumerate() {
        // Every piece after the first was preceded by a single quote in the input.
        if position > 0 {
            quoted.push_str("'\\''");
        }
        quoted.push_str(piece);
    }
    quoted.push('\'');
    quoted
}
490
/// Wraps `value` in double quotes for a POSIX shell.
///
/// Backslash, double quote, dollar sign and backtick are each prefixed with a backslash;
/// every other character is copied unchanged.
fn shell_double_quote(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('"');
    for ch in value.chars() {
        if matches!(ch, '\\' | '"' | '$' | '`') {
            quoted.push('\\');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
504
505async fn execute_hook(hook: &HookDefinition, event: &HookEvent<'_>) -> HookResult {
507 match &hook.action {
508 HookAction::Shell { command } => {
509 let cmd = interpolate_command(command, event);
510 match Command::new("sh")
511 .arg("-c")
512 .arg(&cmd)
513 .stdin(std::process::Stdio::null())
514 .output()
515 .await
516 {
517 Ok(output) => {
518 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
519 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
520
521 let block = matches!(event, HookEvent::BeforeToolCall { .. })
523 && !output.status.success();
524
525 let reason = if block {
526 Some(if stderr.is_empty() {
527 stdout.clone()
528 } else {
529 stderr
530 })
531 } else {
532 None
533 };
534
535 let modified_content = if matches!(event, HookEvent::AfterToolCall { .. })
537 && !stdout.trim().is_empty()
538 && output.status.success()
539 {
540 Some(vec![ContentBlock::Text {
541 text: stdout.trim().to_string(),
542 }])
543 } else {
544 None
545 };
546
547 HookResult {
548 block,
549 reason,
550 modified_content,
551 }
552 }
553 Err(e) => HookResult {
554 block: false,
555 reason: Some(format!("Hook command failed: {e}")),
556 modified_content: None,
557 },
558 }
559 }
560 HookAction::Callback(cb) => cb(event),
561 }
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::Mutex;

    /// Builds a `shell` hook definition without match pattern or threshold.
    fn shell_def(event: &str, command: &str, blocking: bool) -> HookDef {
        HookDef {
            event: event.into(),
            match_pattern: None,
            action: "shell".into(),
            command: Some(command.into()),
            blocking,
            threshold: None,
        }
    }

    /// Builds a resolved `echo hi` shell hook for exercising `matches_event`.
    fn echo_hook(
        event: &str,
        match_pattern: Option<&str>,
        blocking: bool,
        threshold: Option<f64>,
    ) -> HookDefinition {
        HookDefinition {
            event: event.into(),
            match_pattern: match_pattern.map(Into::into),
            action: HookAction::Shell {
                command: "echo hi".into(),
            },
            blocking,
            threshold,
        }
    }

    /// Creates a runner preloaded with `defs` through `load_from_config`.
    fn runner_with(defs: Vec<HookDef>) -> HookRunner {
        let mut runner = HookRunner::new();
        runner.load_from_config(defs);
        runner
    }

    /// Builds a tool result holding a single text block and a zero timestamp.
    fn text_result(
        call_id: &'static str,
        tool_name: &'static str,
        text: &'static str,
        is_error: bool,
        details: serde_json::Value,
    ) -> ToolResultMessage {
        ToolResultMessage {
            tool_call_id: call_id.into(),
            tool_name: tool_name.into(),
            content: vec![ContentBlock::Text { text: text.into() }],
            is_error,
            details,
            timestamp: 0,
        }
    }

    #[test]
    fn hook_def_toml_parsing() {
        let toml_str = r#"
[[hooks]]
event = "after_file_write"
match = "*.rs"
action = "shell"
command = "rustfmt {file}"
blocking = true

[[hooks]]
event = "on_context_threshold"
action = "shell"
command = "echo threshold"
threshold = 0.8
"#;

        #[derive(Deserialize)]
        struct Wrapper {
            hooks: Vec<HookDef>,
        }

        let parsed: Wrapper = toml::from_str(toml_str).expect("TOML parsing failed");
        assert_eq!(parsed.hooks.len(), 2);

        let first = &parsed.hooks[0];
        assert_eq!(first.event, "after_file_write");
        assert_eq!(first.match_pattern.as_deref(), Some("*.rs"));
        assert_eq!(first.action, "shell");
        assert_eq!(first.command.as_deref(), Some("rustfmt {file}"));
        assert!(first.blocking);
        assert!(first.threshold.is_none());

        let second = &parsed.hooks[1];
        assert_eq!(second.event, "on_context_threshold");
        assert!(second.match_pattern.is_none());
        assert_eq!(second.threshold, Some(0.8));
    }

    #[test]
    fn hook_interpolation_file() {
        let event = HookEvent::AfterFileWrite {
            file: Path::new("/tmp/test.rs"),
        };
        assert_eq!(
            interpolate_command("rustfmt {file}", &event),
            "rustfmt /tmp/test.rs"
        );
    }

    #[test]
    fn hook_interpolation_tool_name() {
        let args = serde_json::json!({"path": "/tmp"});
        let event = HookEvent::BeforeToolCall {
            tool_name: "bash",
            args: &args,
        };
        assert_eq!(interpolate_command("echo {tool_name}", &event), "echo bash");
    }

    #[test]
    fn hook_interpolation_quoted_placeholder() {
        let command_text = "pwd && egrep '(^|/)(README|VISION)\\.md$' && printf '$HOME'";
        let result_msg = text_result(
            "call_quoted",
            "bash",
            "ok",
            true,
            serde_json::json!({
                "exit_code": 2,
                "command": command_text,
            }),
        );
        let event = HookEvent::AfterToolCall {
            tool_name: "bash",
            result: &result_msg,
        };

        let interpolated = interpolate_command(
            "hook '{is_error}' '{exit_code}' '{command}' \"{command}\" {command}",
            &event,
        );

        let expected = format!(
            "hook 'true' '2' {} {} {}",
            shell_single_quote(command_text),
            shell_double_quote(command_text),
            command_text
        );
        assert_eq!(interpolated, expected);
    }

    #[test]
    fn hook_interpolation_ratio() {
        let event = HookEvent::OnContextThreshold { ratio: 0.75 };
        assert_eq!(
            interpolate_command("echo ratio={ratio}", &event),
            "echo ratio=0.75"
        );
    }

    #[test]
    fn hook_event_name_mapping() {
        let path = PathBuf::from("/tmp/test.rs");
        assert_eq!(
            HookEvent::AfterFileWrite { file: &path }.event_name(),
            "after_file_write"
        );
        assert_eq!(HookEvent::BeforeLlmCall.event_name(), "before_llm_call");
        assert_eq!(HookEvent::OnSessionStart.event_name(), "on_session_start");
        assert_eq!(
            HookEvent::OnSessionShutdown.event_name(),
            "on_session_shutdown"
        );
        assert_eq!(
            HookEvent::OnContextThreshold { ratio: 0.5 }.event_name(),
            "on_context_threshold"
        );
    }

    #[test]
    fn hook_matches_event_name() {
        let hook = echo_hook("after_file_write", None, false, None);

        let path = PathBuf::from("/tmp/test.rs");
        assert!(matches_event(
            &hook,
            &HookEvent::AfterFileWrite { file: &path }
        ));
        assert!(!matches_event(&hook, &HookEvent::BeforeLlmCall));
    }

    #[test]
    fn hook_matches_file_glob() {
        let hook = echo_hook("after_file_write", Some("*.rs"), false, None);

        let rs_path = PathBuf::from("/tmp/test.rs");
        assert!(matches_event(
            &hook,
            &HookEvent::AfterFileWrite { file: &rs_path }
        ));

        let py_path = PathBuf::from("/tmp/test.py");
        assert!(!matches_event(
            &hook,
            &HookEvent::AfterFileWrite { file: &py_path }
        ));
    }

    #[test]
    fn hook_matches_tool_name() {
        let hook = echo_hook("before_tool_call", Some("bash"), true, None);
        let args = serde_json::json!({});

        let matching = HookEvent::BeforeToolCall {
            tool_name: "bash",
            args: &args,
        };
        assert!(matches_event(&hook, &matching));

        let other = HookEvent::BeforeToolCall {
            tool_name: "read",
            args: &args,
        };
        assert!(!matches_event(&hook, &other));
    }

    #[test]
    fn hook_threshold_filtering() {
        let hook = echo_hook("on_context_threshold", None, true, Some(0.8));

        // Below the threshold the hook is filtered out; at or above it the hook matches.
        assert!(!matches_event(
            &hook,
            &HookEvent::OnContextThreshold { ratio: 0.5 }
        ));
        assert!(matches_event(
            &hook,
            &HookEvent::OnContextThreshold { ratio: 0.8 }
        ));
        assert!(matches_event(
            &hook,
            &HookEvent::OnContextThreshold { ratio: 0.95 }
        ));
    }

    #[test]
    fn hook_resolve_shell() {
        let def = HookDef {
            match_pattern: Some("*.rs".into()),
            ..shell_def("after_file_write", "rustfmt {file}", true)
        };
        let resolved = resolve_hook_def(def).expect("should resolve");
        assert_eq!(resolved.event, "after_file_write");
        assert!(resolved.blocking);
        assert!(matches!(resolved.action, HookAction::Shell { .. }));
    }

    #[test]
    fn hook_resolve_missing_command_returns_none() {
        let def = HookDef {
            command: None,
            ..shell_def("after_file_write", "unused", false)
        };
        assert!(resolve_hook_def(def).is_none());
    }

    #[test]
    fn hook_resolve_unknown_action_returns_none() {
        let def = HookDef {
            action: "unknown".into(),
            ..shell_def("after_file_write", "echo", false)
        };
        assert!(resolve_hook_def(def).is_none());
    }

    #[tokio::test]
    async fn hook_blocking_shell_executes() {
        let runner = runner_with(vec![shell_def("after_file_write", "echo hello", true)]);

        let path = PathBuf::from("/tmp/test.txt");
        let results = runner.fire(&HookEvent::AfterFileWrite { file: &path }).await;
        assert_eq!(results.len(), 1);
        assert!(!results[0].block);
    }

    #[tokio::test]
    async fn hook_non_blocking_fires_and_forgets() {
        let runner = runner_with(vec![shell_def(
            "on_session_start",
            "echo non-blocking",
            false,
        )]);

        let started = std::time::Instant::now();
        let results = runner.fire(&HookEvent::OnSessionStart).await;
        assert!(results.is_empty());
        // Firing must return without waiting for the background command.
        assert!(started.elapsed() < std::time::Duration::from_secs(1));
    }

    #[tokio::test]
    async fn hook_non_blocking_failure_is_reported() {
        let mut runner = runner_with(vec![shell_def("on_session_start", "exit 7", false)]);
        let reported = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&reported);
        runner.set_background_reporter(Arc::new(move |event| {
            sink.lock().unwrap().push(event);
        }));

        let results = runner.fire(&HookEvent::OnSessionStart).await;
        assert!(results.is_empty());

        // Poll briefly: the report arrives from a background task.
        for _ in 0..20 {
            if !reported.lock().unwrap().is_empty() {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(25)).await;
        }

        let reported = reported.lock().unwrap();
        assert_eq!(reported.len(), 1);
        match &reported[0] {
            HookBackgroundEvent::NonBlockingHookFailed { event, command, .. } => {
                assert_eq!(event, "on_session_start");
                assert_eq!(command, "exit 7");
            }
            other => panic!("expected non-blocking hook failure, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn hook_after_tool_call_nonblocking_quoted_command() {
        let temp = tempfile::tempdir().unwrap();
        let output_path = temp.path().join("hook-args.txt");
        let script_path = temp.path().join("capture.sh");
        std::fs::write(
            &script_path,
            format!(
                "#!/bin/sh\nprintf '%s\\n%s\\n%s\\n' \"$1\" \"$2\" \"$3\" > {}\n",
                output_path.display()
            ),
        )
        .unwrap();
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mut perms = std::fs::metadata(&script_path).unwrap().permissions();
            perms.set_mode(0o755);
            std::fs::set_permissions(&script_path, perms).unwrap();
        }

        let hook_command = format!(
            "{} '{{is_error}}' '{{exit_code}}' '{{command}}'",
            script_path.display()
        );
        let runner = runner_with(vec![HookDef {
            match_pattern: Some("bash".into()),
            ..shell_def("after_tool_call", &hook_command, false)
        }]);

        let original_command =
            "pwd && egrep '(^|/)(README|VISION)\\.md$' | sort && printf '$HOME'";
        let result_msg = text_result(
            "call_1",
            "bash",
            "failed",
            true,
            serde_json::json!({
                "exit_code": 2,
                "command": original_command,
            }),
        );
        let event = HookEvent::AfterToolCall {
            tool_name: "bash",
            result: &result_msg,
        };

        let results = runner.fire(&event).await;
        assert!(results.is_empty());

        // Poll briefly: the script writes its arguments from a background task.
        for _ in 0..40 {
            if output_path.exists() {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(25)).await;
        }

        let captured = std::fs::read_to_string(&output_path).unwrap();
        let mut lines = captured.lines();
        assert_eq!(lines.next(), Some("true"));
        assert_eq!(lines.next(), Some("2"));
        assert_eq!(lines.next(), Some(original_command));
    }

    #[test]
    fn report_non_blocking_hook_outcome_maps_join_failure_to_panic_event() {
        let reported = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&reported);
        let reporter: Arc<dyn Fn(HookBackgroundEvent) + Send + Sync> = Arc::new(move |event| {
            sink.lock().unwrap().push(event);
        });

        // Silence the default panic output while the intentional panic is produced.
        let previous_hook = std::panic::take_hook();
        std::panic::set_hook(Box::new(|_| {}));

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let join_error = runtime.block_on(async {
            tokio::spawn(async move {
                panic!("intentional join failure for reporting test");
            })
            .await
            .unwrap_err()
        });
        drop(runtime);

        let _ = std::panic::take_hook();
        std::panic::set_hook(previous_hook);

        report_non_blocking_hook_outcome(
            Err(join_error),
            "on_session_start".into(),
            "test command".into(),
            reporter,
        );

        let reported = reported.lock().unwrap();
        assert_eq!(reported.len(), 1);
        match &reported[0] {
            HookBackgroundEvent::NonBlockingHookPanicked {
                event,
                command,
                error,
            } => {
                assert_eq!(event, "on_session_start");
                assert_eq!(command, "test command");
                assert!(error.contains("panic") || error.contains("cancelled"));
            }
            other => panic!("expected non-blocking hook panic, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn hook_before_tool_call_blocks() {
        let runner = runner_with(vec![HookDef {
            match_pattern: Some("bash".into()),
            ..shell_def("before_tool_call", "exit 1", true)
        }]);

        let args = serde_json::json!({"command": "rm -rf /"});
        let event = HookEvent::BeforeToolCall {
            tool_name: "bash",
            args: &args,
        };
        let results = runner.fire(&event).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].block);
    }

    #[tokio::test]
    async fn hook_before_tool_call_allows() {
        let runner = runner_with(vec![HookDef {
            match_pattern: Some("read".into()),
            ..shell_def("before_tool_call", "exit 0", true)
        }]);

        let args = serde_json::json!({});
        let event = HookEvent::BeforeToolCall {
            tool_name: "read",
            args: &args,
        };
        let results = runner.fire(&event).await;
        assert_eq!(results.len(), 1);
        assert!(!results[0].block);
    }

    #[tokio::test]
    async fn hook_after_tool_call_modifies_result() {
        let runner = runner_with(vec![shell_def(
            "after_tool_call",
            "echo modified output",
            true,
        )]);

        let result_msg = text_result(
            "call_1",
            "test",
            "original",
            false,
            serde_json::Value::Null,
        );
        let event = HookEvent::AfterToolCall {
            tool_name: "test",
            result: &result_msg,
        };
        let results = runner.fire(&event).await;
        assert_eq!(results.len(), 1);

        let modified = results[0]
            .modified_content
            .as_ref()
            .expect("should have modified content");
        assert_eq!(modified.len(), 1);
        match &modified[0] {
            ContentBlock::Text { text } => assert_eq!(text, "modified output"),
            _ => panic!("expected Text content block"),
        }
    }

    #[tokio::test]
    async fn hook_context_threshold_fires_at_correct_ratio() {
        let runner = runner_with(vec![HookDef {
            threshold: Some(0.8),
            ..shell_def("on_context_threshold", "echo threshold hit at {ratio}", true)
        }]);

        let below = runner
            .fire(&HookEvent::OnContextThreshold { ratio: 0.5 })
            .await;
        assert!(below.is_empty());

        let at = runner
            .fire(&HookEvent::OnContextThreshold { ratio: 0.8 })
            .await;
        assert_eq!(at.len(), 1);

        let above = runner
            .fire(&HookEvent::OnContextThreshold { ratio: 0.95 })
            .await;
        assert_eq!(above.len(), 1);
    }

    #[tokio::test]
    async fn hook_execution_order_toml_first_then_programmatic() {
        let order = Arc::new(Mutex::new(Vec::new()));

        let mut runner = runner_with(vec![shell_def("on_session_start", "echo toml", true)]);

        let recorder = Arc::clone(&order);
        runner.register_callback(
            "on_session_start",
            Arc::new(move |_event| {
                recorder.lock().unwrap().push("programmatic");
                HookResult::default()
            }),
        );

        let results = runner.fire(&HookEvent::OnSessionStart).await;

        // One result from the configuration hook and one from the callback.
        assert_eq!(results.len(), 2);

        let recorded = order.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0], "programmatic");
    }

    #[tokio::test]
    async fn hook_callback_blocks_tool_call() {
        let mut runner = HookRunner::new();
        runner.register_callback(
            "before_tool_call",
            Arc::new(|_event| HookResult {
                block: true,
                reason: Some("blocked by callback".into()),
                modified_content: None,
            }),
        );

        let args = serde_json::json!({});
        let event = HookEvent::BeforeToolCall {
            tool_name: "bash",
            args: &args,
        };
        let results = runner.fire(&event).await;
        assert_eq!(results.len(), 1);
        assert!(results[0].block);
        assert_eq!(results[0].reason.as_deref(), Some("blocked by callback"));
    }

    #[tokio::test]
    async fn hook_shell_interpolation_in_execution() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let tmp_path = tmp.path().to_path_buf();
        let marker_file = tempfile::NamedTempFile::new().unwrap();
        let marker_path = marker_file.path().to_string_lossy().to_string();

        let runner = runner_with(vec![shell_def(
            "after_file_write",
            &format!("echo {{file}} > {marker_path}"),
            true,
        )]);

        runner
            .fire(&HookEvent::AfterFileWrite { file: &tmp_path })
            .await;

        let content = std::fs::read_to_string(&marker_path).unwrap();
        assert!(
            content.contains(&tmp_path.to_string_lossy().to_string()),
            "Expected marker to contain file path, got: {content}"
        );
    }

    #[test]
    fn hook_runner_load_from_config_resolves_all() {
        let runner = runner_with(vec![
            HookDef {
                match_pattern: Some("*.rs".into()),
                ..shell_def("after_file_write", "rustfmt {file}", true)
            },
            HookDef {
                match_pattern: Some("bash".into()),
                ..shell_def("before_tool_call", "echo checking", true)
            },
        ]);
        assert_eq!(runner.toml_hooks.len(), 2);
    }

    #[tokio::test]
    async fn hook_unmatched_event_returns_empty() {
        let runner = runner_with(vec![shell_def("on_session_start", "echo hi", true)]);

        let results = runner.fire(&HookEvent::BeforeLlmCall).await;
        assert!(results.is_empty());
    }
}