Skip to main content

harn_vm/orchestration/
hooks.rs

1//! Runtime lifecycle hooks — tool, agent-turn, and worker interception.
2
3use std::cell::RefCell;
4use std::rc::Rc;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8
9use crate::agent_events::WorkerEvent;
10use crate::value::{VmClosure, VmError, VmValue};
11
12/// Manifest / runtime hook event names.
13#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
14pub enum HookEvent {
15    #[serde(rename = "PreToolUse")]
16    PreToolUse,
17    #[serde(rename = "PostToolUse")]
18    PostToolUse,
19    #[serde(rename = "PreAgentTurn")]
20    PreAgentTurn,
21    #[serde(rename = "PostAgentTurn")]
22    PostAgentTurn,
23    #[serde(rename = "WorkerSpawned")]
24    WorkerSpawned,
25    #[serde(rename = "WorkerProgressed")]
26    WorkerProgressed,
27    #[serde(rename = "WorkerWaitingForInput")]
28    WorkerWaitingForInput,
29    #[serde(rename = "WorkerCompleted")]
30    WorkerCompleted,
31    #[serde(rename = "WorkerFailed")]
32    WorkerFailed,
33    #[serde(rename = "WorkerCancelled")]
34    WorkerCancelled,
35}
36
37impl HookEvent {
38    pub fn as_str(self) -> &'static str {
39        match self {
40            Self::PreToolUse => "PreToolUse",
41            Self::PostToolUse => "PostToolUse",
42            Self::PreAgentTurn => "PreAgentTurn",
43            Self::PostAgentTurn => "PostAgentTurn",
44            Self::WorkerSpawned => "WorkerSpawned",
45            Self::WorkerProgressed => "WorkerProgressed",
46            Self::WorkerWaitingForInput => "WorkerWaitingForInput",
47            Self::WorkerCompleted => "WorkerCompleted",
48            Self::WorkerFailed => "WorkerFailed",
49            Self::WorkerCancelled => "WorkerCancelled",
50        }
51    }
52
53    pub fn from_worker_event(event: WorkerEvent) -> Self {
54        match event {
55            WorkerEvent::WorkerSpawned => Self::WorkerSpawned,
56            WorkerEvent::WorkerProgressed => Self::WorkerProgressed,
57            WorkerEvent::WorkerWaitingForInput => Self::WorkerWaitingForInput,
58            WorkerEvent::WorkerCompleted => Self::WorkerCompleted,
59            WorkerEvent::WorkerFailed => Self::WorkerFailed,
60            WorkerEvent::WorkerCancelled => Self::WorkerCancelled,
61        }
62    }
63}
64
65/// Action returned by a PreToolUse hook.
66#[derive(Clone, Debug)]
67pub enum PreToolAction {
68    /// Allow the tool call to proceed unchanged.
69    Allow,
70    /// Deny the tool call with an explanation.
71    Deny(String),
72    /// Allow but replace the arguments.
73    Modify(serde_json::Value),
74}
75
/// Decision produced by a `PostToolUse` hook.
#[derive(Clone, Debug)]
pub enum PostToolAction {
    /// Keep the tool result as it is.
    Pass,
    /// Substitute this text for the tool result.
    Modify(String),
}
84
85/// Callback types for legacy tool lifecycle hooks.
86pub type PreToolHookFn = Rc<dyn Fn(&str, &serde_json::Value) -> PreToolAction>;
87pub type PostToolHookFn = Rc<dyn Fn(&str, &str) -> PostToolAction>;
88
89/// A registered tool hook with a name pattern and callbacks.
90#[derive(Clone)]
91pub struct ToolHook {
92    /// Glob-style pattern matched against tool names (e.g. `"*"`, `"exec*"`, `"read_file"`).
93    pub pattern: String,
94    /// Called before tool execution. Return `Deny` to reject, `Modify` to rewrite args.
95    pub pre: Option<PreToolHookFn>,
96    /// Called after tool execution with the result text. Return `Modify` to rewrite.
97    pub post: Option<PostToolHookFn>,
98}
99
100impl std::fmt::Debug for ToolHook {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        f.debug_struct("ToolHook")
103            .field("pattern", &self.pattern)
104            .field("has_pre", &self.pre.is_some())
105            .field("has_post", &self.post.is_some())
106            .finish()
107    }
108}
109
/// How a registered hook decides whether it applies to a particular dispatch.
///
/// The derived `Debug` renders `ToolNameGlob("…")` / `EventExpression("…")`.
#[derive(Clone, Debug)]
enum PatternMatcher {
    /// Glob (see `glob_match`) tested against the tool name; used by native tool hooks.
    /// Never matches when the dispatch has no tool name.
    ToolNameGlob(String),
    /// Expression (see `expression_matches`) evaluated against the JSON event payload;
    /// used by VM hooks.
    EventExpression(String),
}
126
127#[derive(Clone)]
128enum RuntimeHookHandler {
129    NativePreTool(PreToolHookFn),
130    NativePostTool(PostToolHookFn),
131    Vm {
132        handler_name: String,
133        closure: Rc<VmClosure>,
134    },
135}
136
137impl std::fmt::Debug for RuntimeHookHandler {
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        match self {
140            Self::NativePreTool(_) => f.write_str("NativePreTool(..)"),
141            Self::NativePostTool(_) => f.write_str("NativePostTool(..)"),
142            Self::Vm { handler_name, .. } => f
143                .debug_struct("Vm")
144                .field("handler_name", handler_name)
145                .finish(),
146        }
147    }
148}
149
150#[derive(Clone, Debug)]
151struct RuntimeHook {
152    event: HookEvent,
153    matcher: PatternMatcher,
154    handler: RuntimeHookHandler,
155}
156
157thread_local! {
158    static RUNTIME_HOOKS: RefCell<Vec<RuntimeHook>> = const { RefCell::new(Vec::new()) };
159}
160
/// Matches `name` against a glob `pattern` in which `*` stands for any run of
/// characters (including none). All other characters match themselves literally.
///
/// Works with `*` anywhere in the pattern: `"*"`, `"exec*"`, `"*_file"`, `"*file*"`,
/// `"a*b*c"`. A pattern without `*` must equal `name` exactly.
pub(crate) fn glob_match(pattern: &str, name: &str) -> bool {
    let mut segments = pattern.split('*');
    // `split` always yields at least one segment: the literal text before the first `*`.
    let prefix = segments.next().unwrap_or("");
    let Some(mut remaining) = name.strip_prefix(prefix) else {
        return false;
    };
    let mut segments = segments.peekable();
    while let Some(segment) = segments.next() {
        if segments.peek().is_none() {
            // Last segment: anchored to the end of the name.
            return remaining.ends_with(segment);
        }
        // Middle segment: take the leftmost occurrence, which leaves the most text for
        // the segments that follow.
        match remaining.find(segment) {
            Some(index) => remaining = &remaining[index + segment.len()..],
            None => return false,
        }
    }
    // No `*` in the pattern: the literal prefix must have consumed the whole name.
    remaining.is_empty()
}
173
174pub fn register_tool_hook(hook: ToolHook) {
175    if let Some(pre) = hook.pre {
176        RUNTIME_HOOKS.with(|hooks| {
177            hooks.borrow_mut().push(RuntimeHook {
178                event: HookEvent::PreToolUse,
179                matcher: PatternMatcher::ToolNameGlob(hook.pattern.clone()),
180                handler: RuntimeHookHandler::NativePreTool(pre),
181            });
182        });
183    }
184    if let Some(post) = hook.post {
185        RUNTIME_HOOKS.with(|hooks| {
186            hooks.borrow_mut().push(RuntimeHook {
187                event: HookEvent::PostToolUse,
188                matcher: PatternMatcher::ToolNameGlob(hook.pattern),
189                handler: RuntimeHookHandler::NativePostTool(post),
190            });
191        });
192    }
193}
194
195pub fn register_vm_hook(
196    event: HookEvent,
197    pattern: impl Into<String>,
198    handler_name: impl Into<String>,
199    closure: Rc<VmClosure>,
200) {
201    RUNTIME_HOOKS.with(|hooks| {
202        hooks.borrow_mut().push(RuntimeHook {
203            event,
204            matcher: PatternMatcher::EventExpression(pattern.into()),
205            handler: RuntimeHookHandler::Vm {
206                handler_name: handler_name.into(),
207                closure,
208            },
209        });
210    });
211}
212
213pub fn clear_tool_hooks() {
214    RUNTIME_HOOKS.with(|hooks| {
215        hooks
216            .borrow_mut()
217            .retain(|hook| !matches!(hook.event, HookEvent::PreToolUse | HookEvent::PostToolUse));
218    });
219}
220
221pub fn clear_runtime_hooks() {
222    RUNTIME_HOOKS.with(|hooks| hooks.borrow_mut().clear());
223    super::clear_command_policies();
224}
225
226fn value_at_path<'a>(value: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
227    let mut current = value;
228    for segment in path.split('.') {
229        let serde_json::Value::Object(map) = current else {
230            return None;
231        };
232        current = map.get(segment)?;
233    }
234    Some(current)
235}
236
237fn value_truthy(value: &serde_json::Value) -> bool {
238    match value {
239        serde_json::Value::Null => false,
240        serde_json::Value::Bool(value) => *value,
241        serde_json::Value::Number(value) => value
242            .as_i64()
243            .map(|number| number != 0)
244            .or_else(|| value.as_u64().map(|number| number != 0))
245            .or_else(|| value.as_f64().map(|number| number != 0.0))
246            .unwrap_or(false),
247        serde_json::Value::String(value) => !value.is_empty(),
248        serde_json::Value::Array(values) => !values.is_empty(),
249        serde_json::Value::Object(values) => !values.is_empty(),
250    }
251}
252
253fn value_to_pattern_string(value: Option<&serde_json::Value>) -> String {
254    match value {
255        Some(serde_json::Value::String(text)) => text.clone(),
256        Some(other) => other.to_string(),
257        None => String::new(),
258    }
259}
260
/// Trims surrounding whitespace and removes one matching pair of enclosing quotes.
///
/// Double quotes are tried before single quotes. If the trimmed text is not wrapped in
/// a matching pair, it is returned trimmed but otherwise unchanged.
fn strip_quoted(value: &str) -> &str {
    let trimmed = value.trim();
    ['"', '\'']
        .iter()
        .find_map(|&quote| trimmed.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(trimmed)
}
274
275fn expression_matches(pattern: &str, payload: &serde_json::Value) -> bool {
276    let pattern = pattern.trim();
277    if pattern.is_empty() || pattern == "*" {
278        return true;
279    }
280    if let Some((lhs, rhs)) = pattern.split_once("=~") {
281        let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
282        let regex = strip_quoted(rhs);
283        return Regex::new(regex).is_ok_and(|compiled| compiled.is_match(&value));
284    }
285    if let Some((lhs, rhs)) = pattern.split_once("==") {
286        let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
287        return value == strip_quoted(rhs);
288    }
289    if let Some((lhs, rhs)) = pattern.split_once("!=") {
290        let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
291        return value != strip_quoted(rhs);
292    }
293    if pattern.contains('.') {
294        return value_at_path(payload, pattern).is_some_and(value_truthy);
295    }
296    glob_match(
297        pattern,
298        &value_to_pattern_string(value_at_path(payload, "tool.name")),
299    )
300}
301
302fn hook_matches(hook: &RuntimeHook, tool_name: Option<&str>, payload: &serde_json::Value) -> bool {
303    match &hook.matcher {
304        PatternMatcher::ToolNameGlob(pattern) => {
305            tool_name.is_some_and(|candidate| glob_match(pattern, candidate))
306        }
307        PatternMatcher::EventExpression(pattern) => expression_matches(pattern, payload),
308    }
309}
310
311async fn invoke_vm_hook(
312    closure: &Rc<VmClosure>,
313    payload: &serde_json::Value,
314) -> Result<VmValue, VmError> {
315    let Some(mut vm) = crate::vm::clone_async_builtin_child_vm() else {
316        return Err(VmError::Runtime(
317            "runtime hook requires an async builtin VM context".to_string(),
318        ));
319    };
320    let arg = crate::stdlib::json_to_vm_value(payload);
321    vm.call_closure_pub(closure, &[arg]).await
322}
323
324fn parse_pre_tool_result(value: VmValue) -> Result<PreToolAction, VmError> {
325    match value {
326        VmValue::Nil => Ok(PreToolAction::Allow),
327        VmValue::Dict(map) => {
328            if let Some(reason) = map.get("deny") {
329                return Ok(PreToolAction::Deny(reason.display()));
330            }
331            if let Some(args) = map.get("args") {
332                return Ok(PreToolAction::Modify(crate::llm::vm_value_to_json(args)));
333            }
334            Ok(PreToolAction::Allow)
335        }
336        other => Err(VmError::Runtime(format!(
337            "PreToolUse hook must return nil or {{deny, args}}, got {}",
338            other.type_name()
339        ))),
340    }
341}
342
343fn parse_post_tool_result(value: VmValue) -> Result<PostToolAction, VmError> {
344    match value {
345        VmValue::Nil => Ok(PostToolAction::Pass),
346        VmValue::String(text) => Ok(PostToolAction::Modify(text.to_string())),
347        VmValue::Dict(map) => {
348            if let Some(result) = map.get("result") {
349                return Ok(PostToolAction::Modify(result.display()));
350            }
351            Ok(PostToolAction::Pass)
352        }
353        other => Err(VmError::Runtime(format!(
354            "PostToolUse hook must return nil, string, or {{result}}, got {}",
355            other.type_name()
356        ))),
357    }
358}
359
360/// Run all matching PreToolUse hooks. Returns the final action.
361pub async fn run_pre_tool_hooks(
362    tool_name: &str,
363    args: &serde_json::Value,
364) -> Result<PreToolAction, VmError> {
365    let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
366    let mut current_args = args.clone();
367    for hook in hooks
368        .iter()
369        .filter(|hook| hook.event == HookEvent::PreToolUse)
370    {
371        let payload = serde_json::json!({
372            "event": HookEvent::PreToolUse.as_str(),
373            "tool": {
374                "name": tool_name,
375                "args": current_args.clone(),
376            },
377        });
378        if !hook_matches(hook, Some(tool_name), &payload) {
379            continue;
380        }
381        let action = match &hook.handler {
382            RuntimeHookHandler::NativePreTool(pre) => pre(tool_name, &current_args),
383            RuntimeHookHandler::Vm { closure, .. } => {
384                parse_pre_tool_result(invoke_vm_hook(closure, &payload).await?)?
385            }
386            RuntimeHookHandler::NativePostTool(_) => continue,
387        };
388        match action {
389            PreToolAction::Allow => {}
390            PreToolAction::Deny(reason) => return Ok(PreToolAction::Deny(reason)),
391            PreToolAction::Modify(new_args) => {
392                current_args = new_args;
393            }
394        }
395    }
396    if current_args != *args {
397        Ok(PreToolAction::Modify(current_args))
398    } else {
399        Ok(PreToolAction::Allow)
400    }
401}
402
403/// Run all matching PostToolUse hooks. Returns the (possibly modified) result.
404pub async fn run_post_tool_hooks(
405    tool_name: &str,
406    args: &serde_json::Value,
407    result: &str,
408) -> Result<String, VmError> {
409    let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
410    let mut current = result.to_string();
411    for hook in hooks
412        .iter()
413        .filter(|hook| hook.event == HookEvent::PostToolUse)
414    {
415        let payload = serde_json::json!({
416            "event": HookEvent::PostToolUse.as_str(),
417            "tool": {
418                "name": tool_name,
419                "args": args,
420            },
421            "result": {
422                "text": current.clone(),
423            },
424        });
425        if !hook_matches(hook, Some(tool_name), &payload) {
426            continue;
427        }
428        let action = match &hook.handler {
429            RuntimeHookHandler::NativePostTool(post) => post(tool_name, &current),
430            RuntimeHookHandler::Vm { closure, .. } => {
431                parse_post_tool_result(invoke_vm_hook(closure, &payload).await?)?
432            }
433            RuntimeHookHandler::NativePreTool(_) => continue,
434        };
435        match action {
436            PostToolAction::Pass => {}
437            PostToolAction::Modify(new_result) => {
438                current = new_result;
439            }
440        }
441    }
442    Ok(current)
443}
444
445pub async fn run_lifecycle_hooks(
446    event: HookEvent,
447    payload: &serde_json::Value,
448) -> Result<(), VmError> {
449    let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
450    for hook in hooks.iter().filter(|hook| hook.event == event) {
451        if !hook_matches(hook, None, payload) {
452            continue;
453        }
454        if let RuntimeHookHandler::Vm { closure, .. } = &hook.handler {
455            let _ = invoke_vm_hook(closure, payload).await?;
456        }
457    }
458    Ok(())
459}