Skip to main content

harn_vm/orchestration/
hooks.rs

1//! Tool lifecycle hooks — pre/post-execution interception.
2
3use std::cell::RefCell;
4use std::rc::Rc;
5
6/// Action returned by a PreToolUse hook.
7#[derive(Clone, Debug)]
8pub enum PreToolAction {
9    /// Allow the tool call to proceed unchanged.
10    Allow,
11    /// Deny the tool call with an explanation.
12    Deny(String),
13    /// Allow but replace the arguments.
14    Modify(serde_json::Value),
15}
16
/// Decision a PostToolUse hook makes about a finished tool call's result text.
#[derive(Debug, Clone)]
pub enum PostToolAction {
    /// Keep the result text as it is.
    Pass,
    /// Use this text instead of the current result.
    Modify(String),
}
25
26/// Callback types for tool lifecycle hooks.
27pub type PreToolHookFn = Rc<dyn Fn(&str, &serde_json::Value) -> PreToolAction>;
28pub type PostToolHookFn = Rc<dyn Fn(&str, &str) -> PostToolAction>;
29
30/// A registered tool hook with a name pattern and callbacks.
31#[derive(Clone)]
32pub struct ToolHook {
33    /// Glob-style pattern matched against tool names (e.g. `"*"`, `"exec*"`, `"read_file"`).
34    pub pattern: String,
35    /// Called before tool execution. Return `Deny` to reject, `Modify` to rewrite args.
36    pub pre: Option<PreToolHookFn>,
37    /// Called after tool execution with the result text. Return `Modify` to rewrite.
38    pub post: Option<PostToolHookFn>,
39}
40
41impl std::fmt::Debug for ToolHook {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.debug_struct("ToolHook")
44            .field("pattern", &self.pattern)
45            .field("has_pre", &self.pre.is_some())
46            .field("has_post", &self.post.is_some())
47            .finish()
48    }
49}
50
51thread_local! {
52    pub(super) static TOOL_HOOKS: RefCell<Vec<ToolHook>> = const { RefCell::new(Vec::new()) };
53}
54
/// Match a tool name against a glob-style `pattern`.
///
/// `*` matches any (possibly empty) run of characters. It may appear any number of times
/// and anywhere in the pattern: `"*"`, `"exec*"`, `"*_file"`, `"*file*"`, `"read_*_file"`.
/// Every other character must match literally, so a pattern without `*` is an exact
/// comparison. There is no `?`, character-class, or escape syntax.
pub(crate) fn glob_match(pattern: &str, name: &str) -> bool {
    // No wildcard at all: exact comparison.
    let (head, tail) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return pattern == name,
    };
    // `head` is anchored at the start of `name` and `last` at its end. Whatever lies between
    // the first and the last `*` is a sequence of literal segments that must appear in order.
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));
    // The suffix is stripped from what remains after the prefix, so the two anchors cannot
    // overlap (e.g. `"read_*_file"` must not match `"read_file"`).
    let mut rest = match name.strip_prefix(head).and_then(|r| r.strip_suffix(last)) {
        Some(r) => r,
        None => return false,
    };
    // Leftmost matching of each middle segment is sufficient: it consumes as little of
    // `rest` as possible, leaving the most room for the segments that follow.
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }
    true
}
67
68pub fn register_tool_hook(hook: ToolHook) {
69    TOOL_HOOKS.with(|hooks| hooks.borrow_mut().push(hook));
70}
71
72pub fn clear_tool_hooks() {
73    TOOL_HOOKS.with(|hooks| hooks.borrow_mut().clear());
74}
75
76/// Run all matching PreToolUse hooks. Returns the final action.
77pub fn run_pre_tool_hooks(tool_name: &str, args: &serde_json::Value) -> PreToolAction {
78    TOOL_HOOKS.with(|hooks| {
79        let hooks = hooks.borrow();
80        let mut current_args = args.clone();
81        for hook in hooks.iter() {
82            if !glob_match(&hook.pattern, tool_name) {
83                continue;
84            }
85            if let Some(ref pre) = hook.pre {
86                match pre(tool_name, &current_args) {
87                    PreToolAction::Allow => {}
88                    PreToolAction::Deny(reason) => return PreToolAction::Deny(reason),
89                    PreToolAction::Modify(new_args) => {
90                        current_args = new_args;
91                    }
92                }
93            }
94        }
95        if current_args != *args {
96            PreToolAction::Modify(current_args)
97        } else {
98            PreToolAction::Allow
99        }
100    })
101}
102
103/// Run all matching PostToolUse hooks. Returns the (possibly modified) result.
104pub fn run_post_tool_hooks(tool_name: &str, result: &str) -> String {
105    TOOL_HOOKS.with(|hooks| {
106        let hooks = hooks.borrow();
107        let mut current = result.to_string();
108        for hook in hooks.iter() {
109            if !glob_match(&hook.pattern, tool_name) {
110                continue;
111            }
112            if let Some(ref post) = hook.post {
113                match post(tool_name, &current) {
114                    PostToolAction::Pass => {}
115                    PostToolAction::Modify(new_result) => {
116                        current = new_result;
117                    }
118                }
119            }
120        }
121        current
122    })
123}