Skip to main content

imp_core/
hooks.rs

1use std::path::Path;
2use std::sync::Arc;
3
4use glob::Pattern;
5use imp_llm::{AssistantMessage, ContentBlock, Message, ToolResultMessage};
6use serde::{Deserialize, Serialize};
7use tokio::process::Command;
8
/// Outcome notifications emitted by the background task that supervises a
/// non-blocking shell hook.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookBackgroundEvent {
    /// The hook command could not be run, or it exited unsuccessfully.
    NonBlockingHookFailed {
        /// Canonical event name the hook fired for.
        event: String,
        /// The interpolated shell command.
        command: String,
        /// Human-readable failure description.
        error: String,
    },
    /// The task running the hook command did not complete normally.
    NonBlockingHookPanicked {
        /// Canonical event name the hook fired for.
        event: String,
        /// The interpolated shell command.
        command: String,
        /// Text of the join error.
        error: String,
    },
}

impl std::fmt::Display for HookBackgroundEvent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both variants share one message shape; only the verb differs.
        let (verb, event, command, error) = match self {
            Self::NonBlockingHookFailed {
                event,
                command,
                error,
            } => ("failed", event, command, error),
            Self::NonBlockingHookPanicked {
                event,
                command,
                error,
            } => ("panicked", event, command, error),
        };
        write!(
            f,
            "Non-blocking hook {verb} for event '{event}' while running `{command}`: {error}"
        )
    }
}
46
47/// Hook definition from TOML config.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct HookDef {
50    pub event: String,
51    #[serde(rename = "match")]
52    pub match_pattern: Option<String>,
53    pub action: String,
54    pub command: Option<String>,
55    #[serde(default)]
56    pub blocking: bool,
57    pub threshold: Option<f64>,
58}
59
60/// What a hook does when triggered.
61#[derive(Clone)]
62pub enum HookAction {
63    /// Run a shell command with interpolation ({file}, {tool_name}).
64    Shell { command: String },
65    /// A programmatic callback (for Lua or other extensions).
66    Callback(Arc<dyn Fn(&HookEvent<'_>) -> HookResult + Send + Sync>),
67}
68
69impl std::fmt::Debug for HookAction {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        match self {
72            HookAction::Shell { command } => {
73                f.debug_struct("Shell").field("command", command).finish()
74            }
75            HookAction::Callback(_) => f.write_str("Callback(...)"),
76        }
77    }
78}
79
80/// A fully resolved hook definition ready for execution.
81#[derive(Debug, Clone)]
82pub struct HookDefinition {
83    pub event: String,
84    pub match_pattern: Option<String>,
85    pub action: HookAction,
86    pub blocking: bool,
87    pub threshold: Option<f64>,
88}
89
90/// Runtime hook events.
91#[derive(Clone)]
92pub enum HookEvent<'a> {
93    AfterFileWrite {
94        file: &'a Path,
95    },
96    BeforeToolCall {
97        tool_name: &'a str,
98        args: &'a serde_json::Value,
99    },
100    AfterToolCall {
101        tool_name: &'a str,
102        result: &'a ToolResultMessage,
103    },
104    BeforeLlmCall,
105    OnContextThreshold {
106        ratio: f64,
107    },
108    OnSessionStart,
109    OnSessionShutdown,
110    OnAgentStart {
111        prompt: &'a str,
112    },
113    OnAgentEnd {
114        messages: &'a [Message],
115    },
116    OnTurnEnd {
117        index: u32,
118        message: &'a AssistantMessage,
119    },
120}
121
122impl<'a> HookEvent<'a> {
123    /// Return the canonical event name for matching against hook definitions.
124    fn event_name(&self) -> &'static str {
125        match self {
126            HookEvent::AfterFileWrite { .. } => "after_file_write",
127            HookEvent::BeforeToolCall { .. } => "before_tool_call",
128            HookEvent::AfterToolCall { .. } => "after_tool_call",
129            HookEvent::BeforeLlmCall => "before_llm_call",
130            HookEvent::OnContextThreshold { .. } => "on_context_threshold",
131            HookEvent::OnSessionStart => "on_session_start",
132            HookEvent::OnSessionShutdown => "on_session_shutdown",
133            HookEvent::OnAgentStart { .. } => "on_agent_start",
134            HookEvent::OnAgentEnd { .. } => "on_agent_end",
135            HookEvent::OnTurnEnd { .. } => "on_turn_end",
136        }
137    }
138}
139
140/// Result from a hook execution.
141#[derive(Default, Debug)]
142pub struct HookResult {
143    pub block: bool,
144    pub reason: Option<String>,
145    pub modified_content: Option<Vec<ContentBlock>>,
146}
147
148/// Manages and executes hooks.
149pub struct HookRunner {
150    /// TOML-defined hooks (fire first, in config order).
151    toml_hooks: Vec<HookDefinition>,
152    /// Programmatically registered hooks (fire after TOML hooks, in registration order).
153    programmatic_hooks: Vec<HookDefinition>,
154    /// Optional observer for background non-blocking hook failures.
155    background_reporter: Option<Arc<dyn Fn(HookBackgroundEvent) + Send + Sync>>,
156}
157
158impl HookRunner {
159    pub fn new() -> Self {
160        Self {
161            toml_hooks: Vec::new(),
162            programmatic_hooks: Vec::new(),
163            background_reporter: None,
164        }
165    }
166
167    /// Add a single TOML hook def (raw from config).
168    pub fn add(&mut self, def: HookDef) {
169        if let Some(resolved) = resolve_hook_def(def) {
170            self.toml_hooks.push(resolved);
171        }
172    }
173
174    /// Load multiple TOML hook defs from config.
175    pub fn load_from_config(&mut self, defs: Vec<HookDef>) {
176        for def in defs {
177            self.add(def);
178        }
179    }
180
181    /// Register a programmatic hook (for Lua or other extensions).
182    pub fn register(&mut self, hook: HookDefinition) {
183        self.programmatic_hooks.push(hook);
184    }
185
186    /// Returns the total number of registered hooks (TOML + programmatic).
187    pub fn len(&self) -> usize {
188        self.toml_hooks.len() + self.programmatic_hooks.len()
189    }
190
191    /// Returns true if no hooks are registered.
192    pub fn is_empty(&self) -> bool {
193        self.toml_hooks.is_empty() && self.programmatic_hooks.is_empty()
194    }
195
196    /// Register an observer for background non-blocking hook failures.
197    pub fn set_background_reporter(
198        &mut self,
199        reporter: Arc<dyn Fn(HookBackgroundEvent) + Send + Sync>,
200    ) {
201        self.background_reporter = Some(reporter);
202    }
203
204    /// Register a callback hook for a specific event.
205    pub fn register_callback(
206        &mut self,
207        event: &str,
208        callback: Arc<dyn Fn(&HookEvent<'_>) -> HookResult + Send + Sync>,
209    ) {
210        self.programmatic_hooks.push(HookDefinition {
211            event: event.to_string(),
212            match_pattern: None,
213            action: HookAction::Callback(callback),
214            blocking: true,
215            threshold: None,
216        });
217    }
218
219    /// Fire a hook event and collect results.
220    ///
221    /// Execution order: TOML hooks first (config order), then programmatic hooks (registration order).
222    /// Blocking hooks execute sequentially and await completion.
223    /// Non-blocking hooks are spawned as background tokio tasks.
224    pub async fn fire(&self, event: &HookEvent<'_>) -> Vec<HookResult> {
225        let mut results = Vec::new();
226
227        // TOML hooks first, then programmatic hooks
228        let all_hooks = self.toml_hooks.iter().chain(self.programmatic_hooks.iter());
229
230        for hook in all_hooks {
231            if !matches_event(hook, event) {
232                continue;
233            }
234
235            if hook.blocking {
236                let result = execute_hook(hook, event).await;
237                results.push(result);
238            } else {
239                // Keep non-blocking hooks asynchronous, but supervise failures.
240                if let HookAction::Shell { command } = &hook.action {
241                    let cmd = interpolate_command(command, event);
242                    run_non_blocking_shell_hook(
243                        hook_event_label(event),
244                        cmd,
245                        self.background_reporter.clone(),
246                    );
247                }
248                // Non-blocking hooks don't contribute results
249            }
250        }
251
252        results
253    }
254}
255
256impl Default for HookRunner {
257    fn default() -> Self {
258        Self::new()
259    }
260}
261
262fn hook_event_label(event: &HookEvent<'_>) -> String {
263    event.event_name().to_string()
264}
265
266fn report_non_blocking_hook_outcome(
267    join_result: Result<std::io::Result<std::process::Output>, tokio::task::JoinError>,
268    event_name: String,
269    command_for_report: String,
270    reporter: Arc<dyn Fn(HookBackgroundEvent) + Send + Sync>,
271) {
272    match join_result {
273        Ok(Ok(output)) => {
274            if !output.status.success() {
275                let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
276                let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
277                let error = if !stderr.is_empty() {
278                    stderr
279                } else if !stdout.is_empty() {
280                    stdout
281                } else {
282                    format!(
283                        "command exited with status {}",
284                        output
285                            .status
286                            .code()
287                            .map(|code| code.to_string())
288                            .unwrap_or_else(|| "terminated by signal".into())
289                    )
290                };
291                reporter(HookBackgroundEvent::NonBlockingHookFailed {
292                    event: event_name,
293                    command: command_for_report,
294                    error,
295                });
296            }
297        }
298        Ok(Err(error)) => reporter(HookBackgroundEvent::NonBlockingHookFailed {
299            event: event_name,
300            command: command_for_report,
301            error: error.to_string(),
302        }),
303        Err(join_error) => reporter(HookBackgroundEvent::NonBlockingHookPanicked {
304            event: event_name,
305            command: command_for_report,
306            error: join_error.to_string(),
307        }),
308    }
309}
310
311fn run_non_blocking_shell_hook(
312    event_name: String,
313    command: String,
314    reporter: Option<Arc<dyn Fn(HookBackgroundEvent) + Send + Sync>>,
315) {
316    tokio::spawn(async move {
317        let command_for_run = command.clone();
318        let command_for_report = command;
319        let join_result = tokio::spawn(async move {
320            Command::new("sh")
321                .arg("-c")
322                .arg(&command_for_run)
323                .stdin(std::process::Stdio::null())
324                .output()
325                .await
326        })
327        .await;
328
329        if let Some(reporter) = reporter {
330            report_non_blocking_hook_outcome(join_result, event_name, command_for_report, reporter);
331        }
332    });
333}
334
335fn resolve_hook_def(def: HookDef) -> Option<HookDefinition> {
336    let action = match def.action.as_str() {
337        "shell" => {
338            let command = def.command?;
339            HookAction::Shell { command }
340        }
341        _ => return None,
342    };
343
344    Some(HookDefinition {
345        event: def.event,
346        match_pattern: def.match_pattern,
347        action,
348        blocking: def.blocking,
349        threshold: def.threshold,
350    })
351}
352
353/// Check if a hook definition matches the given event.
354fn matches_event(hook: &HookDefinition, event: &HookEvent<'_>) -> bool {
355    // Event name must match
356    if hook.event != event.event_name() {
357        return false;
358    }
359
360    // Check match_pattern if present
361    if let Some(pattern) = &hook.match_pattern {
362        match event {
363            HookEvent::AfterFileWrite { file } => {
364                let file_str = file.to_string_lossy();
365                // Try glob matching against the full path and filename
366                if let Ok(glob) = Pattern::new(pattern) {
367                    let file_name = file
368                        .file_name()
369                        .map(|n| n.to_string_lossy().to_string())
370                        .unwrap_or_default();
371                    if !glob.matches(&file_str) && !glob.matches(&file_name) {
372                        return false;
373                    }
374                } else {
375                    return false;
376                }
377            }
378            HookEvent::BeforeToolCall { tool_name, .. }
379            | HookEvent::AfterToolCall { tool_name, .. } => {
380                if pattern != *tool_name {
381                    // Also try glob matching on tool name
382                    if let Ok(glob) = Pattern::new(pattern) {
383                        if !glob.matches(tool_name) {
384                            return false;
385                        }
386                    } else {
387                        return false;
388                    }
389                }
390            }
391            _ => {
392                // Other events ignore match_pattern
393            }
394        }
395    }
396
397    // Check threshold for OnContextThreshold
398    if let HookEvent::OnContextThreshold { ratio } = event {
399        if let Some(threshold) = hook.threshold {
400            if *ratio < threshold {
401                return false;
402            }
403        }
404    }
405
406    true
407}
408
409/// Interpolate variables into a shell command string.
410fn interpolate_command(command: &str, event: &HookEvent<'_>) -> String {
411    let mut result = command.to_string();
412
413    match event {
414        HookEvent::AfterFileWrite { file } => {
415            result = replace_placeholder(&result, "file", &file.to_string_lossy());
416        }
417        HookEvent::BeforeToolCall { tool_name, .. } => {
418            result = replace_placeholder(&result, "tool_name", tool_name);
419        }
420        HookEvent::AfterToolCall {
421            tool_name,
422            result: tool_result,
423        } => {
424            result = replace_placeholder(&result, "tool_name", tool_name);
425            result = replace_placeholder(
426                &result,
427                "is_error",
428                if tool_result.is_error {
429                    "true"
430                } else {
431                    "false"
432                },
433            );
434            // Extract exit_code from details if present (bash tool sets this)
435            let exit_code = tool_result
436                .details
437                .get("exit_code")
438                .and_then(|v| v.as_i64())
439                .map(|c| c.to_string())
440                .unwrap_or_default();
441            result = replace_placeholder(&result, "exit_code", &exit_code);
442            // First line of output for summary
443            let output_first = tool_result
444                .content
445                .iter()
446                .filter_map(|b| match b {
447                    imp_llm::ContentBlock::Text { text } => Some(text.as_str()),
448                    _ => None,
449                })
450                .next()
451                .and_then(|t| t.lines().next())
452                .unwrap_or("");
453            result = replace_placeholder(&result, "output_first_line", output_first);
454            // Extract command from details (bash tool stores it)
455            let command = tool_result
456                .details
457                .get("command")
458                .and_then(|v| v.as_str())
459                .unwrap_or("");
460            result = replace_placeholder(&result, "command", command);
461        }
462        HookEvent::OnContextThreshold { ratio } => {
463            result = replace_placeholder(&result, "ratio", &ratio.to_string());
464        }
465        HookEvent::OnTurnEnd { index, .. } => {
466            result = replace_placeholder(&result, "index", &index.to_string());
467        }
468        _ => {}
469    }
470
471    result
472}
473
/// Replace `{name}` in `template` with `value`, quoting it to suit its shell context.
///
/// `'{name}'` becomes a single-quoted `value` and `"{name}"` becomes a double-quoted
/// `value`. Any other `{name}` is replaced verbatim. Quoted occurrences are first
/// parked behind NUL-delimited sentinels so that the bare substitution cannot touch them.
fn replace_placeholder(template: &str, name: &str, value: &str) -> String {
    let token = format!("{{{name}}}");
    let single_slot = format!("\u{0}__imp_hook_single_{name}__\u{0}");
    let double_slot = format!("\u{0}__imp_hook_double_{name}__\u{0}");
    let single_quoted_token = format!("'{token}'");
    let double_quoted_token = format!("\"{token}\"");
    let single_value = shell_single_quote(value);
    let double_value = shell_double_quote(value);

    // Order matters: park quoted forms, expand bare tokens, then unpark.
    let steps: [(&str, &str); 5] = [
        (single_quoted_token.as_str(), single_slot.as_str()),
        (double_quoted_token.as_str(), double_slot.as_str()),
        (token.as_str(), value),
        (single_slot.as_str(), single_value.as_str()),
        (double_slot.as_str(), double_value.as_str()),
    ];
    steps
        .iter()
        .fold(template.to_string(), |acc, &(from, to)| acc.replace(from, to))
}

/// Wrap `value` in single quotes; each embedded `'` becomes `'\''`.
fn shell_single_quote(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('\'');
    for ch in value.chars() {
        if ch == '\'' {
            quoted.push_str("'\\''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}

/// Wrap `value` in double quotes, backslash-escaping `\`, `"`, `$` and `` ` ``.
fn shell_double_quote(value: &str) -> String {
    let body: String = value
        .chars()
        .flat_map(|ch| {
            let escape = if matches!(ch, '\\' | '"' | '$' | '`') {
                Some('\\')
            } else {
                None
            };
            escape.into_iter().chain(std::iter::once(ch))
        })
        .collect();
    format!("\"{body}\"")
}
504
505/// Execute a single hook and return its result.
506async fn execute_hook(hook: &HookDefinition, event: &HookEvent<'_>) -> HookResult {
507    match &hook.action {
508        HookAction::Shell { command } => {
509            let cmd = interpolate_command(command, event);
510            match Command::new("sh")
511                .arg("-c")
512                .arg(&cmd)
513                .stdin(std::process::Stdio::null())
514                .output()
515                .await
516            {
517                Ok(output) => {
518                    let stdout = String::from_utf8_lossy(&output.stdout).to_string();
519                    let stderr = String::from_utf8_lossy(&output.stderr).to_string();
520
521                    // A non-zero exit code on a BeforeToolCall hook means "block"
522                    let block = matches!(event, HookEvent::BeforeToolCall { .. })
523                        && !output.status.success();
524
525                    let reason = if block {
526                        Some(if stderr.is_empty() {
527                            stdout.clone()
528                        } else {
529                            stderr
530                        })
531                    } else {
532                        None
533                    };
534
535                    // For AfterToolCall, stdout is treated as modified content
536                    let modified_content = if matches!(event, HookEvent::AfterToolCall { .. })
537                        && !stdout.trim().is_empty()
538                        && output.status.success()
539                    {
540                        Some(vec![ContentBlock::Text {
541                            text: stdout.trim().to_string(),
542                        }])
543                    } else {
544                        None
545                    };
546
547                    HookResult {
548                        block,
549                        reason,
550                        modified_content,
551                    }
552                }
553                Err(e) => HookResult {
554                    block: false,
555                    reason: Some(format!("Hook command failed: {e}")),
556                    modified_content: None,
557                },
558            }
559        }
560        HookAction::Callback(cb) => cb(event),
561    }
562}
563
564#[cfg(test)]
565mod tests {
566    use super::*;
567    use std::path::PathBuf;
568    use std::sync::Mutex;
569
570    #[test]
571    fn hook_def_toml_parsing() {
572        let toml_str = r#"
573[[hooks]]
574event = "after_file_write"
575match = "*.rs"
576action = "shell"
577command = "rustfmt {file}"
578blocking = true
579
580[[hooks]]
581event = "on_context_threshold"
582action = "shell"
583command = "echo threshold"
584threshold = 0.8
585"#;
586
587        #[derive(Deserialize)]
588        struct Wrapper {
589            hooks: Vec<HookDef>,
590        }
591
592        let parsed: Wrapper = toml::from_str(toml_str).expect("TOML parsing failed");
593        assert_eq!(parsed.hooks.len(), 2);
594
595        let h0 = &parsed.hooks[0];
596        assert_eq!(h0.event, "after_file_write");
597        assert_eq!(h0.match_pattern.as_deref(), Some("*.rs"));
598        assert_eq!(h0.action, "shell");
599        assert_eq!(h0.command.as_deref(), Some("rustfmt {file}"));
600        assert!(h0.blocking);
601        assert!(h0.threshold.is_none());
602
603        let h1 = &parsed.hooks[1];
604        assert_eq!(h1.event, "on_context_threshold");
605        assert!(h1.match_pattern.is_none());
606        assert_eq!(h1.threshold, Some(0.8));
607    }
608
609    #[test]
610    fn hook_interpolation_file() {
611        let event = HookEvent::AfterFileWrite {
612            file: Path::new("/tmp/test.rs"),
613        };
614        let result = interpolate_command("rustfmt {file}", &event);
615        assert_eq!(result, "rustfmt /tmp/test.rs");
616    }
617
618    #[test]
619    fn hook_interpolation_tool_name() {
620        let args = serde_json::json!({"path": "/tmp"});
621        let event = HookEvent::BeforeToolCall {
622            tool_name: "bash",
623            args: &args,
624        };
625        let result = interpolate_command("echo {tool_name}", &event);
626        assert_eq!(result, "echo bash");
627    }
628
629    #[test]
630    fn hook_interpolation_quoted_placeholder() {
631        let command_text = "pwd && egrep '(^|/)(README|VISION)\\.md$' && printf '$HOME'";
632        let result_msg = ToolResultMessage {
633            tool_call_id: "call_quoted".into(),
634            tool_name: "bash".into(),
635            content: vec![ContentBlock::Text { text: "ok".into() }],
636            is_error: true,
637            details: serde_json::json!({
638                "exit_code": 2,
639                "command": command_text,
640            }),
641            timestamp: 0,
642        };
643        let event = HookEvent::AfterToolCall {
644            tool_name: "bash",
645            result: &result_msg,
646        };
647
648        let interpolated = interpolate_command(
649            "hook '{is_error}' '{exit_code}' '{command}' \"{command}\" {command}",
650            &event,
651        );
652
653        assert_eq!(
654            interpolated,
655            format!(
656                "hook 'true' '2' {} {} {}",
657                shell_single_quote(command_text),
658                shell_double_quote(command_text),
659                command_text
660            )
661        );
662    }
663
664    #[test]
665    fn hook_interpolation_ratio() {
666        let event = HookEvent::OnContextThreshold { ratio: 0.75 };
667        let result = interpolate_command("echo ratio={ratio}", &event);
668        assert_eq!(result, "echo ratio=0.75");
669    }
670
671    #[test]
672    fn hook_event_name_mapping() {
673        let path = PathBuf::from("/tmp/test.rs");
674        assert_eq!(
675            HookEvent::AfterFileWrite { file: &path }.event_name(),
676            "after_file_write"
677        );
678        assert_eq!(HookEvent::BeforeLlmCall.event_name(), "before_llm_call");
679        assert_eq!(HookEvent::OnSessionStart.event_name(), "on_session_start");
680        assert_eq!(
681            HookEvent::OnSessionShutdown.event_name(),
682            "on_session_shutdown"
683        );
684        assert_eq!(
685            HookEvent::OnContextThreshold { ratio: 0.5 }.event_name(),
686            "on_context_threshold"
687        );
688    }
689
690    #[test]
691    fn hook_matches_event_name() {
692        let hook = HookDefinition {
693            event: "after_file_write".into(),
694            match_pattern: None,
695            action: HookAction::Shell {
696                command: "echo hi".into(),
697            },
698            blocking: false,
699            threshold: None,
700        };
701        let path = PathBuf::from("/tmp/test.rs");
702        let event = HookEvent::AfterFileWrite { file: &path };
703        assert!(matches_event(&hook, &event));
704
705        let wrong_event = HookEvent::BeforeLlmCall;
706        assert!(!matches_event(&hook, &wrong_event));
707    }
708
709    #[test]
710    fn hook_matches_file_glob() {
711        let hook = HookDefinition {
712            event: "after_file_write".into(),
713            match_pattern: Some("*.rs".into()),
714            action: HookAction::Shell {
715                command: "echo hi".into(),
716            },
717            blocking: false,
718            threshold: None,
719        };
720
721        let rs_path = PathBuf::from("/tmp/test.rs");
722        let rs_event = HookEvent::AfterFileWrite { file: &rs_path };
723        assert!(matches_event(&hook, &rs_event));
724
725        let py_path = PathBuf::from("/tmp/test.py");
726        let py_event = HookEvent::AfterFileWrite { file: &py_path };
727        assert!(!matches_event(&hook, &py_event));
728    }
729
730    #[test]
731    fn hook_matches_tool_name() {
732        let hook = HookDefinition {
733            event: "before_tool_call".into(),
734            match_pattern: Some("bash".into()),
735            action: HookAction::Shell {
736                command: "echo hi".into(),
737            },
738            blocking: true,
739            threshold: None,
740        };
741
742        let args = serde_json::json!({});
743        let match_event = HookEvent::BeforeToolCall {
744            tool_name: "bash",
745            args: &args,
746        };
747        assert!(matches_event(&hook, &match_event));
748
749        let no_match_event = HookEvent::BeforeToolCall {
750            tool_name: "read",
751            args: &args,
752        };
753        assert!(!matches_event(&hook, &no_match_event));
754    }
755
756    #[test]
757    fn hook_threshold_filtering() {
758        let hook = HookDefinition {
759            event: "on_context_threshold".into(),
760            match_pattern: None,
761            action: HookAction::Shell {
762                command: "echo hi".into(),
763            },
764            blocking: true,
765            threshold: Some(0.8),
766        };
767
768        // Below threshold — should not match
769        let below = HookEvent::OnContextThreshold { ratio: 0.5 };
770        assert!(!matches_event(&hook, &below));
771
772        // At threshold — should match
773        let at = HookEvent::OnContextThreshold { ratio: 0.8 };
774        assert!(matches_event(&hook, &at));
775
776        // Above threshold — should match
777        let above = HookEvent::OnContextThreshold { ratio: 0.95 };
778        assert!(matches_event(&hook, &above));
779    }
780
781    #[test]
782    fn hook_resolve_shell() {
783        let def = HookDef {
784            event: "after_file_write".into(),
785            match_pattern: Some("*.rs".into()),
786            action: "shell".into(),
787            command: Some("rustfmt {file}".into()),
788            blocking: true,
789            threshold: None,
790        };
791        let resolved = resolve_hook_def(def).expect("should resolve");
792        assert_eq!(resolved.event, "after_file_write");
793        assert!(resolved.blocking);
794        assert!(matches!(resolved.action, HookAction::Shell { .. }));
795    }
796
797    #[test]
798    fn hook_resolve_missing_command_returns_none() {
799        let def = HookDef {
800            event: "after_file_write".into(),
801            match_pattern: None,
802            action: "shell".into(),
803            command: None,
804            blocking: false,
805            threshold: None,
806        };
807        assert!(resolve_hook_def(def).is_none());
808    }
809
810    #[test]
811    fn hook_resolve_unknown_action_returns_none() {
812        let def = HookDef {
813            event: "after_file_write".into(),
814            match_pattern: None,
815            action: "unknown".into(),
816            command: Some("echo".into()),
817            blocking: false,
818            threshold: None,
819        };
820        assert!(resolve_hook_def(def).is_none());
821    }
822
823    #[tokio::test]
824    async fn hook_blocking_shell_executes() {
825        let mut runner = HookRunner::new();
826        runner.load_from_config(vec![HookDef {
827            event: "after_file_write".into(),
828            match_pattern: None,
829            action: "shell".into(),
830            command: Some("echo hello".into()),
831            blocking: true,
832            threshold: None,
833        }]);
834
835        let path = PathBuf::from("/tmp/test.txt");
836        let event = HookEvent::AfterFileWrite { file: &path };
837        let results = runner.fire(&event).await;
838        assert_eq!(results.len(), 1);
839        assert!(!results[0].block);
840    }
841
842    #[tokio::test]
843    async fn hook_non_blocking_fires_and_forgets() {
844        let mut runner = HookRunner::new();
845        runner.load_from_config(vec![HookDef {
846            event: "on_session_start".into(),
847            match_pattern: None,
848            action: "shell".into(),
849            command: Some("echo non-blocking".into()),
850            blocking: false,
851            threshold: None,
852        }]);
853
854        let event = HookEvent::OnSessionStart;
855        let started = std::time::Instant::now();
856        let results = runner.fire(&event).await;
857        // Non-blocking hooks don't return results
858        assert!(results.is_empty());
859        assert!(started.elapsed() < std::time::Duration::from_secs(1));
860    }
861
862    #[tokio::test]
863    async fn hook_non_blocking_failure_is_reported() {
864        let mut runner = HookRunner::new();
865        let reported = Arc::new(Mutex::new(Vec::new()));
866        let reported_clone = Arc::clone(&reported);
867        runner.set_background_reporter(Arc::new(move |event| {
868            reported_clone.lock().unwrap().push(event);
869        }));
870        runner.load_from_config(vec![HookDef {
871            event: "on_session_start".into(),
872            match_pattern: None,
873            action: "shell".into(),
874            command: Some("exit 7".into()),
875            blocking: false,
876            threshold: None,
877        }]);
878
879        let event = HookEvent::OnSessionStart;
880        let results = runner.fire(&event).await;
881        assert!(results.is_empty());
882
883        for _ in 0..20 {
884            if !reported.lock().unwrap().is_empty() {
885                break;
886            }
887            tokio::time::sleep(std::time::Duration::from_millis(25)).await;
888        }
889
890        let reported = reported.lock().unwrap();
891        assert_eq!(reported.len(), 1);
892        match &reported[0] {
893            HookBackgroundEvent::NonBlockingHookFailed { event, command, .. } => {
894                assert_eq!(event, "on_session_start");
895                assert_eq!(command, "exit 7");
896            }
897            other => panic!("expected non-blocking hook failure, got {other:?}"),
898        }
899    }
900
901    #[tokio::test]
902    async fn hook_after_tool_call_nonblocking_quoted_command() {
903        let temp = tempfile::tempdir().unwrap();
904        let output_path = temp.path().join("hook-args.txt");
905        let script_path = temp.path().join("capture.sh");
906        std::fs::write(
907            &script_path,
908            format!(
909                "#!/bin/sh\nprintf '%s\\n%s\\n%s\\n' \"$1\" \"$2\" \"$3\" > {}\n",
910                output_path.display()
911            ),
912        )
913        .unwrap();
914        #[cfg(unix)]
915        {
916            use std::os::unix::fs::PermissionsExt;
917            let mut perms = std::fs::metadata(&script_path).unwrap().permissions();
918            perms.set_mode(0o755);
919            std::fs::set_permissions(&script_path, perms).unwrap();
920        }
921
922        let mut runner = HookRunner::new();
923        runner.load_from_config(vec![HookDef {
924            event: "after_tool_call".into(),
925            match_pattern: Some("bash".into()),
926            action: "shell".into(),
927            command: Some(format!(
928                "{} '{{is_error}}' '{{exit_code}}' '{{command}}'",
929                script_path.display()
930            )),
931            blocking: false,
932            threshold: None,
933        }]);
934
935        let original_command = "pwd && egrep '(^|/)(README|VISION)\\.md$' | sort && printf '$HOME'";
936        let result_msg = ToolResultMessage {
937            tool_call_id: "call_1".into(),
938            tool_name: "bash".into(),
939            content: vec![ContentBlock::Text {
940                text: "failed".into(),
941            }],
942            is_error: true,
943            details: serde_json::json!({
944                "exit_code": 2,
945                "command": original_command,
946            }),
947            timestamp: 0,
948        };
949        let event = HookEvent::AfterToolCall {
950            tool_name: "bash",
951            result: &result_msg,
952        };
953
954        let results = runner.fire(&event).await;
955        assert!(results.is_empty());
956
957        for _ in 0..40 {
958            if output_path.exists() {
959                break;
960            }
961            tokio::time::sleep(std::time::Duration::from_millis(25)).await;
962        }
963
964        let captured = std::fs::read_to_string(&output_path).unwrap();
965        let mut lines = captured.lines();
966        assert_eq!(lines.next(), Some("true"));
967        assert_eq!(lines.next(), Some("2"));
968        assert_eq!(lines.next(), Some(original_command));
969    }
970
971    #[test]
972    fn report_non_blocking_hook_outcome_maps_join_failure_to_panic_event() {
973        let reported = Arc::new(Mutex::new(Vec::new()));
974        let reported_clone = Arc::clone(&reported);
975        let reporter: Arc<dyn Fn(HookBackgroundEvent) + Send + Sync> = Arc::new(move |event| {
976            reported_clone.lock().unwrap().push(event);
977        });
978
979        let previous_hook = std::panic::take_hook();
980        std::panic::set_hook(Box::new(|_| {}));
981
982        let runtime = tokio::runtime::Runtime::new().unwrap();
983        let join_error = runtime.block_on(async {
984            tokio::spawn(async move {
985                panic!("intentional join failure for reporting test");
986            })
987            .await
988            .unwrap_err()
989        });
990        drop(runtime);
991
992        let _ = std::panic::take_hook();
993        std::panic::set_hook(previous_hook);
994
995        report_non_blocking_hook_outcome(
996            Err(join_error),
997            "on_session_start".into(),
998            "test command".into(),
999            reporter,
1000        );
1001
1002        let reported = reported.lock().unwrap();
1003        assert_eq!(reported.len(), 1);
1004        match &reported[0] {
1005            HookBackgroundEvent::NonBlockingHookPanicked {
1006                event,
1007                command,
1008                error,
1009            } => {
1010                assert_eq!(event, "on_session_start");
1011                assert_eq!(command, "test command");
1012                assert!(error.contains("panic") || error.contains("cancelled"));
1013            }
1014            other => panic!("expected non-blocking hook panic, got {other:?}"),
1015        }
1016    }
1017
1018    #[tokio::test]
1019    async fn hook_before_tool_call_blocks() {
1020        let mut runner = HookRunner::new();
1021        runner.load_from_config(vec![HookDef {
1022            event: "before_tool_call".into(),
1023            match_pattern: Some("bash".into()),
1024            action: "shell".into(),
1025            command: Some("exit 1".into()),
1026            blocking: true,
1027            threshold: None,
1028        }]);
1029
1030        let args = serde_json::json!({"command": "rm -rf /"});
1031        let event = HookEvent::BeforeToolCall {
1032            tool_name: "bash",
1033            args: &args,
1034        };
1035        let results = runner.fire(&event).await;
1036        assert_eq!(results.len(), 1);
1037        assert!(results[0].block);
1038    }
1039
1040    #[tokio::test]
1041    async fn hook_before_tool_call_allows() {
1042        let mut runner = HookRunner::new();
1043        runner.load_from_config(vec![HookDef {
1044            event: "before_tool_call".into(),
1045            match_pattern: Some("read".into()),
1046            action: "shell".into(),
1047            command: Some("exit 0".into()),
1048            blocking: true,
1049            threshold: None,
1050        }]);
1051
1052        let args = serde_json::json!({});
1053        let event = HookEvent::BeforeToolCall {
1054            tool_name: "read",
1055            args: &args,
1056        };
1057        let results = runner.fire(&event).await;
1058        assert_eq!(results.len(), 1);
1059        assert!(!results[0].block);
1060    }
1061
1062    #[tokio::test]
1063    async fn hook_after_tool_call_modifies_result() {
1064        let mut runner = HookRunner::new();
1065        runner.load_from_config(vec![HookDef {
1066            event: "after_tool_call".into(),
1067            match_pattern: None,
1068            action: "shell".into(),
1069            command: Some("echo modified output".into()),
1070            blocking: true,
1071            threshold: None,
1072        }]);
1073
1074        let result_msg = ToolResultMessage {
1075            tool_call_id: "call_1".into(),
1076            tool_name: "test".into(),
1077            content: vec![ContentBlock::Text {
1078                text: "original".into(),
1079            }],
1080            is_error: false,
1081            details: serde_json::Value::Null,
1082            timestamp: 0,
1083        };
1084        let event = HookEvent::AfterToolCall {
1085            tool_name: "test",
1086            result: &result_msg,
1087        };
1088        let results = runner.fire(&event).await;
1089        assert_eq!(results.len(), 1);
1090        let modified = results[0]
1091            .modified_content
1092            .as_ref()
1093            .expect("should have modified content");
1094        assert_eq!(modified.len(), 1);
1095        if let ContentBlock::Text { text } = &modified[0] {
1096            assert_eq!(text, "modified output");
1097        } else {
1098            panic!("expected Text content block");
1099        }
1100    }
1101
1102    #[tokio::test]
1103    async fn hook_context_threshold_fires_at_correct_ratio() {
1104        let mut runner = HookRunner::new();
1105        runner.load_from_config(vec![HookDef {
1106            event: "on_context_threshold".into(),
1107            match_pattern: None,
1108            action: "shell".into(),
1109            command: Some("echo threshold hit at {ratio}".into()),
1110            blocking: true,
1111            threshold: Some(0.8),
1112        }]);
1113
1114        // Below threshold — no results
1115        let below = HookEvent::OnContextThreshold { ratio: 0.5 };
1116        let results = runner.fire(&below).await;
1117        assert!(results.is_empty());
1118
1119        // At threshold — should fire
1120        let at = HookEvent::OnContextThreshold { ratio: 0.8 };
1121        let results = runner.fire(&at).await;
1122        assert_eq!(results.len(), 1);
1123
1124        // Above threshold — should fire
1125        let above = HookEvent::OnContextThreshold { ratio: 0.95 };
1126        let results = runner.fire(&above).await;
1127        assert_eq!(results.len(), 1);
1128    }
1129
1130    #[tokio::test]
1131    async fn hook_execution_order_toml_first_then_programmatic() {
1132        use std::sync::Mutex;
1133
1134        let order = Arc::new(Mutex::new(Vec::new()));
1135
1136        let mut runner = HookRunner::new();
1137
1138        // TOML hook
1139        runner.load_from_config(vec![HookDef {
1140            event: "on_session_start".into(),
1141            match_pattern: None,
1142            action: "shell".into(),
1143            command: Some("echo toml".into()),
1144            blocking: true,
1145            threshold: None,
1146        }]);
1147
1148        // Programmatic hook
1149        let order_clone = Arc::clone(&order);
1150        runner.register_callback(
1151            "on_session_start",
1152            Arc::new(move |_event| {
1153                order_clone.lock().unwrap().push("programmatic");
1154                HookResult::default()
1155            }),
1156        );
1157
1158        let event = HookEvent::OnSessionStart;
1159        let results = runner.fire(&event).await;
1160
1161        // Both should fire
1162        assert_eq!(results.len(), 2);
1163
1164        // Programmatic should have recorded its execution
1165        let recorded = order.lock().unwrap();
1166        assert_eq!(recorded.len(), 1);
1167        assert_eq!(recorded[0], "programmatic");
1168    }
1169
1170    #[tokio::test]
1171    async fn hook_callback_blocks_tool_call() {
1172        let mut runner = HookRunner::new();
1173        runner.register_callback(
1174            "before_tool_call",
1175            Arc::new(|_event| HookResult {
1176                block: true,
1177                reason: Some("blocked by callback".into()),
1178                modified_content: None,
1179            }),
1180        );
1181
1182        let args = serde_json::json!({});
1183        let event = HookEvent::BeforeToolCall {
1184            tool_name: "bash",
1185            args: &args,
1186        };
1187        let results = runner.fire(&event).await;
1188        assert_eq!(results.len(), 1);
1189        assert!(results[0].block);
1190        assert_eq!(results[0].reason.as_deref(), Some("blocked by callback"));
1191    }
1192
1193    #[tokio::test]
1194    async fn hook_shell_interpolation_in_execution() {
1195        let tmp = tempfile::NamedTempFile::new().unwrap();
1196        let tmp_path = tmp.path().to_path_buf();
1197        let marker_file = tempfile::NamedTempFile::new().unwrap();
1198        let marker_path = marker_file.path().to_string_lossy().to_string();
1199
1200        let mut runner = HookRunner::new();
1201        runner.load_from_config(vec![HookDef {
1202            event: "after_file_write".into(),
1203            match_pattern: None,
1204            action: "shell".into(),
1205            command: Some(format!("echo {{file}} > {marker_path}")),
1206            blocking: true,
1207            threshold: None,
1208        }]);
1209
1210        let event = HookEvent::AfterFileWrite { file: &tmp_path };
1211        runner.fire(&event).await;
1212
1213        // Verify the marker file contains the interpolated path
1214        let content = std::fs::read_to_string(&marker_path).unwrap();
1215        assert!(
1216            content.contains(&tmp_path.to_string_lossy().to_string()),
1217            "Expected marker to contain file path, got: {content}"
1218        );
1219    }
1220
1221    #[test]
1222    fn hook_runner_load_from_config_resolves_all() {
1223        let mut runner = HookRunner::new();
1224        runner.load_from_config(vec![
1225            HookDef {
1226                event: "after_file_write".into(),
1227                match_pattern: Some("*.rs".into()),
1228                action: "shell".into(),
1229                command: Some("rustfmt {file}".into()),
1230                blocking: true,
1231                threshold: None,
1232            },
1233            HookDef {
1234                event: "before_tool_call".into(),
1235                match_pattern: Some("bash".into()),
1236                action: "shell".into(),
1237                command: Some("echo checking".into()),
1238                blocking: true,
1239                threshold: None,
1240            },
1241        ]);
1242        assert_eq!(runner.toml_hooks.len(), 2);
1243    }
1244
1245    #[tokio::test]
1246    async fn hook_unmatched_event_returns_empty() {
1247        let mut runner = HookRunner::new();
1248        runner.load_from_config(vec![HookDef {
1249            event: "on_session_start".into(),
1250            match_pattern: None,
1251            action: "shell".into(),
1252            command: Some("echo hi".into()),
1253            blocking: true,
1254            threshold: None,
1255        }]);
1256
1257        // Fire a different event
1258        let event = HookEvent::BeforeLlmCall;
1259        let results = runner.fire(&event).await;
1260        assert!(results.is_empty());
1261    }
1262}