Skip to main content

harn_vm/orchestration/
hooks.rs

1//! Runtime lifecycle hooks — tool, agent-turn, and worker interception.
2
3use std::cell::RefCell;
4use std::rc::Rc;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8
9use crate::agent_events::WorkerEvent;
10use crate::value::{VmClosure, VmError, VmValue};
11
12/// Manifest / runtime hook event names.
13#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
14pub enum HookEvent {
15    #[serde(rename = "PreToolUse")]
16    PreToolUse,
17    #[serde(rename = "PostToolUse")]
18    PostToolUse,
19    #[serde(rename = "PreAgentTurn")]
20    PreAgentTurn,
21    #[serde(rename = "PostAgentTurn")]
22    PostAgentTurn,
23    #[serde(rename = "WorkerSpawned")]
24    WorkerSpawned,
25    #[serde(rename = "WorkerCompleted")]
26    WorkerCompleted,
27    #[serde(rename = "WorkerFailed")]
28    WorkerFailed,
29    #[serde(rename = "WorkerCancelled")]
30    WorkerCancelled,
31}
32
33impl HookEvent {
34    pub fn as_str(self) -> &'static str {
35        match self {
36            Self::PreToolUse => "PreToolUse",
37            Self::PostToolUse => "PostToolUse",
38            Self::PreAgentTurn => "PreAgentTurn",
39            Self::PostAgentTurn => "PostAgentTurn",
40            Self::WorkerSpawned => "WorkerSpawned",
41            Self::WorkerCompleted => "WorkerCompleted",
42            Self::WorkerFailed => "WorkerFailed",
43            Self::WorkerCancelled => "WorkerCancelled",
44        }
45    }
46
47    pub fn from_worker_event(event: WorkerEvent) -> Self {
48        match event {
49            WorkerEvent::WorkerSpawned => Self::WorkerSpawned,
50            WorkerEvent::WorkerCompleted => Self::WorkerCompleted,
51            WorkerEvent::WorkerFailed => Self::WorkerFailed,
52            WorkerEvent::WorkerCancelled => Self::WorkerCancelled,
53        }
54    }
55}
56
57/// Action returned by a PreToolUse hook.
58#[derive(Clone, Debug)]
59pub enum PreToolAction {
60    /// Allow the tool call to proceed unchanged.
61    Allow,
62    /// Deny the tool call with an explanation.
63    Deny(String),
64    /// Allow but replace the arguments.
65    Modify(serde_json::Value),
66}
67
/// Decision a PostToolUse hook makes about a tool's result text.
#[derive(Debug, Clone)]
pub enum PostToolAction {
    /// Leave the result text as it is.
    Pass,
    /// Swap the result text for the contained string.
    Modify(String),
}
76
77/// Callback types for legacy tool lifecycle hooks.
78pub type PreToolHookFn = Rc<dyn Fn(&str, &serde_json::Value) -> PreToolAction>;
79pub type PostToolHookFn = Rc<dyn Fn(&str, &str) -> PostToolAction>;
80
81/// A registered tool hook with a name pattern and callbacks.
82#[derive(Clone)]
83pub struct ToolHook {
84    /// Glob-style pattern matched against tool names (e.g. `"*"`, `"exec*"`, `"read_file"`).
85    pub pattern: String,
86    /// Called before tool execution. Return `Deny` to reject, `Modify` to rewrite args.
87    pub pre: Option<PreToolHookFn>,
88    /// Called after tool execution with the result text. Return `Modify` to rewrite.
89    pub post: Option<PostToolHookFn>,
90}
91
92impl std::fmt::Debug for ToolHook {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        f.debug_struct("ToolHook")
95            .field("pattern", &self.pattern)
96            .field("has_pre", &self.pre.is_some())
97            .field("has_post", &self.post.is_some())
98            .finish()
99    }
100}
101
/// How a runtime hook decides whether it applies to a particular dispatch.
#[derive(Clone)]
enum PatternMatcher {
    /// Glob tested against the tool name; never matches when no tool name is supplied.
    ToolNameGlob(String),
    /// Expression evaluated against the JSON event payload.
    EventExpression(String),
}
107
108impl std::fmt::Debug for PatternMatcher {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        match self {
111            Self::ToolNameGlob(pattern) => f.debug_tuple("ToolNameGlob").field(pattern).finish(),
112            Self::EventExpression(pattern) => {
113                f.debug_tuple("EventExpression").field(pattern).finish()
114            }
115        }
116    }
117}
118
119#[derive(Clone)]
120enum RuntimeHookHandler {
121    NativePreTool(PreToolHookFn),
122    NativePostTool(PostToolHookFn),
123    Vm {
124        handler_name: String,
125        closure: Rc<VmClosure>,
126    },
127}
128
129impl std::fmt::Debug for RuntimeHookHandler {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        match self {
132            Self::NativePreTool(_) => f.write_str("NativePreTool(..)"),
133            Self::NativePostTool(_) => f.write_str("NativePostTool(..)"),
134            Self::Vm { handler_name, .. } => f
135                .debug_struct("Vm")
136                .field("handler_name", handler_name)
137                .finish(),
138        }
139    }
140}
141
142#[derive(Clone, Debug)]
143struct RuntimeHook {
144    event: HookEvent,
145    matcher: PatternMatcher,
146    handler: RuntimeHookHandler,
147}
148
149thread_local! {
150    static RUNTIME_HOOKS: RefCell<Vec<RuntimeHook>> = const { RefCell::new(Vec::new()) };
151}
152
/// Matches `name` against a glob `pattern` in which `*` stands for any run of
/// characters (including none). All other characters match literally.
///
/// Examples: `"*"` matches everything, `"exec*"` is a prefix test, `"*_file"` is
/// a suffix test, `"*file*"` is a substring test, `"read_*_v2"` anchors both ends,
/// and a pattern without `*` must equal `name` exactly.
///
/// The previous implementation handled only a single leading *or* trailing `*`.
/// `"*file*"` degraded into `starts_with("*file")`, and inner wildcards
/// degraded into exact comparison.
pub(crate) fn glob_match(pattern: &str, name: &str) -> bool {
    // No wildcard at all: plain equality.
    let Some((head, rest)) = pattern.split_once('*') else {
        return pattern == name;
    };
    // `rest` is everything after the first `*`; its final segment (after the last
    // `*`) is the required suffix, and whatever lies between is the middle.
    let (middle, tail) = rest.rsplit_once('*').unwrap_or(("", rest));

    let Some(remaining) = name.strip_prefix(head) else {
        return false;
    };
    // Stripping the suffix from what is left after the prefix guarantees the
    // prefix and suffix never overlap (e.g. `"ab*ba"` must not match `"aba"`).
    let Some(mut remaining) = remaining.strip_suffix(tail) else {
        return false;
    };
    // Middle segments must appear in order; greedy leftmost matching is optimal
    // for `*`-only globs.
    for segment in middle.split('*').filter(|segment| !segment.is_empty()) {
        match remaining.find(segment) {
            Some(index) => remaining = &remaining[index + segment.len()..],
            None => return false,
        }
    }
    true
}
165
166pub fn register_tool_hook(hook: ToolHook) {
167    if let Some(pre) = hook.pre {
168        RUNTIME_HOOKS.with(|hooks| {
169            hooks.borrow_mut().push(RuntimeHook {
170                event: HookEvent::PreToolUse,
171                matcher: PatternMatcher::ToolNameGlob(hook.pattern.clone()),
172                handler: RuntimeHookHandler::NativePreTool(pre),
173            });
174        });
175    }
176    if let Some(post) = hook.post {
177        RUNTIME_HOOKS.with(|hooks| {
178            hooks.borrow_mut().push(RuntimeHook {
179                event: HookEvent::PostToolUse,
180                matcher: PatternMatcher::ToolNameGlob(hook.pattern),
181                handler: RuntimeHookHandler::NativePostTool(post),
182            });
183        });
184    }
185}
186
187pub fn register_vm_hook(
188    event: HookEvent,
189    pattern: impl Into<String>,
190    handler_name: impl Into<String>,
191    closure: Rc<VmClosure>,
192) {
193    RUNTIME_HOOKS.with(|hooks| {
194        hooks.borrow_mut().push(RuntimeHook {
195            event,
196            matcher: PatternMatcher::EventExpression(pattern.into()),
197            handler: RuntimeHookHandler::Vm {
198                handler_name: handler_name.into(),
199                closure,
200            },
201        });
202    });
203}
204
205pub fn clear_tool_hooks() {
206    RUNTIME_HOOKS.with(|hooks| {
207        hooks
208            .borrow_mut()
209            .retain(|hook| !matches!(hook.event, HookEvent::PreToolUse | HookEvent::PostToolUse));
210    });
211}
212
213pub fn clear_runtime_hooks() {
214    RUNTIME_HOOKS.with(|hooks| hooks.borrow_mut().clear());
215}
216
217fn value_at_path<'a>(value: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
218    let mut current = value;
219    for segment in path.split('.') {
220        let serde_json::Value::Object(map) = current else {
221            return None;
222        };
223        current = map.get(segment)?;
224    }
225    Some(current)
226}
227
228fn value_truthy(value: &serde_json::Value) -> bool {
229    match value {
230        serde_json::Value::Null => false,
231        serde_json::Value::Bool(value) => *value,
232        serde_json::Value::Number(value) => value
233            .as_i64()
234            .map(|number| number != 0)
235            .or_else(|| value.as_u64().map(|number| number != 0))
236            .or_else(|| value.as_f64().map(|number| number != 0.0))
237            .unwrap_or(false),
238        serde_json::Value::String(value) => !value.is_empty(),
239        serde_json::Value::Array(values) => !values.is_empty(),
240        serde_json::Value::Object(values) => !values.is_empty(),
241    }
242}
243
244fn value_to_pattern_string(value: Option<&serde_json::Value>) -> String {
245    match value {
246        Some(serde_json::Value::String(text)) => text.clone(),
247        Some(other) => other.to_string(),
248        None => String::new(),
249    }
250}
251
/// Trims `value` and removes one matching pair of surrounding `"` or `'` quotes.
///
/// If the trimmed text is not wrapped in a matching pair, it is returned trimmed
/// but otherwise unchanged.
fn strip_quoted(value: &str) -> &str {
    let trimmed = value.trim();
    ['"', '\'']
        .iter()
        .find_map(|&quote| trimmed.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(trimmed)
}
265
266fn expression_matches(pattern: &str, payload: &serde_json::Value) -> bool {
267    let pattern = pattern.trim();
268    if pattern.is_empty() || pattern == "*" {
269        return true;
270    }
271    if let Some((lhs, rhs)) = pattern.split_once("=~") {
272        let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
273        let regex = strip_quoted(rhs);
274        return Regex::new(regex).is_ok_and(|compiled| compiled.is_match(&value));
275    }
276    if let Some((lhs, rhs)) = pattern.split_once("==") {
277        let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
278        return value == strip_quoted(rhs);
279    }
280    if let Some((lhs, rhs)) = pattern.split_once("!=") {
281        let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
282        return value != strip_quoted(rhs);
283    }
284    if pattern.contains('.') {
285        return value_at_path(payload, pattern).is_some_and(value_truthy);
286    }
287    glob_match(
288        pattern,
289        &value_to_pattern_string(value_at_path(payload, "tool.name")),
290    )
291}
292
293fn hook_matches(hook: &RuntimeHook, tool_name: Option<&str>, payload: &serde_json::Value) -> bool {
294    match &hook.matcher {
295        PatternMatcher::ToolNameGlob(pattern) => {
296            tool_name.is_some_and(|candidate| glob_match(pattern, candidate))
297        }
298        PatternMatcher::EventExpression(pattern) => expression_matches(pattern, payload),
299    }
300}
301
302async fn invoke_vm_hook(
303    closure: &Rc<VmClosure>,
304    payload: &serde_json::Value,
305) -> Result<VmValue, VmError> {
306    let Some(mut vm) = crate::vm::clone_async_builtin_child_vm() else {
307        return Err(VmError::Runtime(
308            "runtime hook requires an async builtin VM context".to_string(),
309        ));
310    };
311    let arg = crate::stdlib::json_to_vm_value(payload);
312    vm.call_closure_pub(closure, &[arg]).await
313}
314
315fn parse_pre_tool_result(value: VmValue) -> Result<PreToolAction, VmError> {
316    match value {
317        VmValue::Nil => Ok(PreToolAction::Allow),
318        VmValue::Dict(map) => {
319            if let Some(reason) = map.get("deny") {
320                return Ok(PreToolAction::Deny(reason.display()));
321            }
322            if let Some(args) = map.get("args") {
323                return Ok(PreToolAction::Modify(crate::llm::vm_value_to_json(args)));
324            }
325            Ok(PreToolAction::Allow)
326        }
327        other => Err(VmError::Runtime(format!(
328            "PreToolUse hook must return nil or {{deny, args}}, got {}",
329            other.type_name()
330        ))),
331    }
332}
333
334fn parse_post_tool_result(value: VmValue) -> Result<PostToolAction, VmError> {
335    match value {
336        VmValue::Nil => Ok(PostToolAction::Pass),
337        VmValue::String(text) => Ok(PostToolAction::Modify(text.to_string())),
338        VmValue::Dict(map) => {
339            if let Some(result) = map.get("result") {
340                return Ok(PostToolAction::Modify(result.display()));
341            }
342            Ok(PostToolAction::Pass)
343        }
344        other => Err(VmError::Runtime(format!(
345            "PostToolUse hook must return nil, string, or {{result}}, got {}",
346            other.type_name()
347        ))),
348    }
349}
350
351/// Run all matching PreToolUse hooks. Returns the final action.
352pub async fn run_pre_tool_hooks(
353    tool_name: &str,
354    args: &serde_json::Value,
355) -> Result<PreToolAction, VmError> {
356    let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
357    let mut current_args = args.clone();
358    for hook in hooks
359        .iter()
360        .filter(|hook| hook.event == HookEvent::PreToolUse)
361    {
362        let payload = serde_json::json!({
363            "event": HookEvent::PreToolUse.as_str(),
364            "tool": {
365                "name": tool_name,
366                "args": current_args.clone(),
367            },
368        });
369        if !hook_matches(hook, Some(tool_name), &payload) {
370            continue;
371        }
372        let action = match &hook.handler {
373            RuntimeHookHandler::NativePreTool(pre) => pre(tool_name, &current_args),
374            RuntimeHookHandler::Vm { closure, .. } => {
375                parse_pre_tool_result(invoke_vm_hook(closure, &payload).await?)?
376            }
377            RuntimeHookHandler::NativePostTool(_) => continue,
378        };
379        match action {
380            PreToolAction::Allow => {}
381            PreToolAction::Deny(reason) => return Ok(PreToolAction::Deny(reason)),
382            PreToolAction::Modify(new_args) => {
383                current_args = new_args;
384            }
385        }
386    }
387    if current_args != *args {
388        Ok(PreToolAction::Modify(current_args))
389    } else {
390        Ok(PreToolAction::Allow)
391    }
392}
393
394/// Run all matching PostToolUse hooks. Returns the (possibly modified) result.
395pub async fn run_post_tool_hooks(
396    tool_name: &str,
397    args: &serde_json::Value,
398    result: &str,
399) -> Result<String, VmError> {
400    let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
401    let mut current = result.to_string();
402    for hook in hooks
403        .iter()
404        .filter(|hook| hook.event == HookEvent::PostToolUse)
405    {
406        let payload = serde_json::json!({
407            "event": HookEvent::PostToolUse.as_str(),
408            "tool": {
409                "name": tool_name,
410                "args": args,
411            },
412            "result": {
413                "text": current.clone(),
414            },
415        });
416        if !hook_matches(hook, Some(tool_name), &payload) {
417            continue;
418        }
419        let action = match &hook.handler {
420            RuntimeHookHandler::NativePostTool(post) => post(tool_name, &current),
421            RuntimeHookHandler::Vm { closure, .. } => {
422                parse_post_tool_result(invoke_vm_hook(closure, &payload).await?)?
423            }
424            RuntimeHookHandler::NativePreTool(_) => continue,
425        };
426        match action {
427            PostToolAction::Pass => {}
428            PostToolAction::Modify(new_result) => {
429                current = new_result;
430            }
431        }
432    }
433    Ok(current)
434}
435
436pub async fn run_lifecycle_hooks(
437    event: HookEvent,
438    payload: &serde_json::Value,
439) -> Result<(), VmError> {
440    let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
441    for hook in hooks.iter().filter(|hook| hook.event == event) {
442        if !hook_matches(hook, None, payload) {
443            continue;
444        }
445        if let RuntimeHookHandler::Vm { closure, .. } = &hook.handler {
446            let _ = invoke_vm_hook(closure, payload).await?;
447        }
448    }
449    Ok(())
450}