// harn_vm/orchestration/hooks.rs
use std::cell::RefCell;
use std::rc::Rc;

6#[derive(Clone, Debug)]
8pub enum PreToolAction {
9 Allow,
11 Deny(String),
13 Modify(serde_json::Value),
15}
16
/// Verdict a post-tool hook returns after a tool has produced its result text.
#[derive(Debug, Clone)]
pub enum PostToolAction {
    /// Leave the result text as it is.
    Pass,
    /// Replace the result text with this string.
    Modify(String),
}

26pub type PreToolHookFn = Rc<dyn Fn(&str, &serde_json::Value) -> PreToolAction>;
28pub type PostToolHookFn = Rc<dyn Fn(&str, &str) -> PostToolAction>;
29
30#[derive(Clone)]
32pub struct ToolHook {
33 pub pattern: String,
35 pub pre: Option<PreToolHookFn>,
37 pub post: Option<PostToolHookFn>,
39}
40
41impl std::fmt::Debug for ToolHook {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_struct("ToolHook")
44 .field("pattern", &self.pattern)
45 .field("has_pre", &self.pre.is_some())
46 .field("has_post", &self.post.is_some())
47 .finish()
48 }
49}
50
51thread_local! {
52 pub(super) static TOOL_HOOKS: RefCell<Vec<ToolHook>> = const { RefCell::new(Vec::new()) };
53}
54
/// Returns `true` if `name` matches the glob `pattern`.
///
/// `*` matches any run of characters (including the empty run). It may appear any number of
/// times and anywhere in the pattern: `"*"`, `"read_*"`, `"*_file"`, `"*file*"`, `"fs_*_v2"`.
/// Every other character matches itself literally; there is no escaping and no `?`.
/// A pattern containing no `*` must equal `name` exactly.
pub(crate) fn glob_match(pattern: &str, name: &str) -> bool {
    // Split off the text after the last `*`; with no `*` at all this is a literal compare.
    let Some((before_last, tail)) = pattern.rsplit_once('*') else {
        return pattern == name;
    };
    // `head` must be a literal prefix; `middle` holds the `*`-separated interior segments.
    let (head, middle) = before_last.split_once('*').unwrap_or((before_last, ""));
    let Some(mut rest) = name.strip_prefix(head) else {
        return false;
    };
    for segment in middle.split('*') {
        // Consuming the leftmost occurrence is always safe: it leaves the longest possible
        // remainder for the segments (and tail) that follow.
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }
    // The tail must sit at the very end, within text no earlier segment consumed.
    rest.ends_with(tail)
}

68pub fn register_tool_hook(hook: ToolHook) {
69 TOOL_HOOKS.with(|hooks| hooks.borrow_mut().push(hook));
70}
71
72pub fn clear_tool_hooks() {
73 TOOL_HOOKS.with(|hooks| hooks.borrow_mut().clear());
74}
75
76pub fn run_pre_tool_hooks(tool_name: &str, args: &serde_json::Value) -> PreToolAction {
78 TOOL_HOOKS.with(|hooks| {
79 let hooks = hooks.borrow();
80 let mut current_args = args.clone();
81 for hook in hooks.iter() {
82 if !glob_match(&hook.pattern, tool_name) {
83 continue;
84 }
85 if let Some(ref pre) = hook.pre {
86 match pre(tool_name, ¤t_args) {
87 PreToolAction::Allow => {}
88 PreToolAction::Deny(reason) => return PreToolAction::Deny(reason),
89 PreToolAction::Modify(new_args) => {
90 current_args = new_args;
91 }
92 }
93 }
94 }
95 if current_args != *args {
96 PreToolAction::Modify(current_args)
97 } else {
98 PreToolAction::Allow
99 }
100 })
101}
102
103pub fn run_post_tool_hooks(tool_name: &str, result: &str) -> String {
105 TOOL_HOOKS.with(|hooks| {
106 let hooks = hooks.borrow();
107 let mut current = result.to_string();
108 for hook in hooks.iter() {
109 if !glob_match(&hook.pattern, tool_name) {
110 continue;
111 }
112 if let Some(ref post) = hook.post {
113 match post(tool_name, ¤t) {
114 PostToolAction::Pass => {}
115 PostToolAction::Modify(new_result) => {
116 current = new_result;
117 }
118 }
119 }
120 }
121 current
122 })
123}