Skip to main content

harn_vm/orchestration/
hooks.rs

1//! Runtime lifecycle hooks — tool, agent-turn, and worker interception.
2
3use std::cell::RefCell;
4use std::rc::Rc;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8
9use crate::agent_events::WorkerEvent;
10use crate::value::{VmClosure, VmError, VmValue};
11
12/// Manifest / runtime hook event names.
13#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
14pub enum HookEvent {
15    #[serde(rename = "PreToolUse")]
16    PreToolUse,
17    #[serde(rename = "PostToolUse")]
18    PostToolUse,
19    #[serde(rename = "PreAgentTurn")]
20    PreAgentTurn,
21    #[serde(rename = "PostAgentTurn")]
22    PostAgentTurn,
23    #[serde(rename = "WorkerSpawned")]
24    WorkerSpawned,
25    #[serde(rename = "WorkerProgressed")]
26    WorkerProgressed,
27    #[serde(rename = "WorkerWaitingForInput")]
28    WorkerWaitingForInput,
29    #[serde(rename = "WorkerCompleted")]
30    WorkerCompleted,
31    #[serde(rename = "WorkerFailed")]
32    WorkerFailed,
33    #[serde(rename = "WorkerCancelled")]
34    WorkerCancelled,
35    #[serde(rename = "PreStep")]
36    PreStep,
37    #[serde(rename = "PostStep")]
38    PostStep,
39    #[serde(rename = "OnBudgetThreshold")]
40    OnBudgetThreshold,
41    #[serde(rename = "OnApprovalRequested")]
42    OnApprovalRequested,
43    #[serde(rename = "OnHandoffEmitted")]
44    OnHandoffEmitted,
45    #[serde(rename = "OnPersonaPaused")]
46    OnPersonaPaused,
47    #[serde(rename = "OnPersonaResumed")]
48    OnPersonaResumed,
49}
50
51impl HookEvent {
52    pub fn as_str(self) -> &'static str {
53        match self {
54            Self::PreToolUse => "PreToolUse",
55            Self::PostToolUse => "PostToolUse",
56            Self::PreAgentTurn => "PreAgentTurn",
57            Self::PostAgentTurn => "PostAgentTurn",
58            Self::WorkerSpawned => "WorkerSpawned",
59            Self::WorkerProgressed => "WorkerProgressed",
60            Self::WorkerWaitingForInput => "WorkerWaitingForInput",
61            Self::WorkerCompleted => "WorkerCompleted",
62            Self::WorkerFailed => "WorkerFailed",
63            Self::WorkerCancelled => "WorkerCancelled",
64            Self::PreStep => "PreStep",
65            Self::PostStep => "PostStep",
66            Self::OnBudgetThreshold => "OnBudgetThreshold",
67            Self::OnApprovalRequested => "OnApprovalRequested",
68            Self::OnHandoffEmitted => "OnHandoffEmitted",
69            Self::OnPersonaPaused => "OnPersonaPaused",
70            Self::OnPersonaResumed => "OnPersonaResumed",
71        }
72    }
73
74    pub fn from_worker_event(event: WorkerEvent) -> Self {
75        match event {
76            WorkerEvent::WorkerSpawned => Self::WorkerSpawned,
77            WorkerEvent::WorkerProgressed => Self::WorkerProgressed,
78            WorkerEvent::WorkerWaitingForInput => Self::WorkerWaitingForInput,
79            WorkerEvent::WorkerCompleted => Self::WorkerCompleted,
80            WorkerEvent::WorkerFailed => Self::WorkerFailed,
81            WorkerEvent::WorkerCancelled => Self::WorkerCancelled,
82        }
83    }
84}
85
86/// Action returned by a PreToolUse hook.
87#[derive(Clone, Debug)]
88pub enum PreToolAction {
89    /// Allow the tool call to proceed unchanged.
90    Allow,
91    /// Deny the tool call with an explanation.
92    Deny(String),
93    /// Allow but replace the arguments.
94    Modify(serde_json::Value),
95}
96
/// Verdict produced by a PostToolUse hook.
#[derive(Clone, Debug)]
pub enum PostToolAction {
    /// Keep the current result text as-is.
    Pass,
    /// Swap the result text for this string.
    Modify(String),
}
105
106/// Callback types for legacy tool lifecycle hooks.
107pub type PreToolHookFn = Rc<dyn Fn(&str, &serde_json::Value) -> PreToolAction>;
108pub type PostToolHookFn = Rc<dyn Fn(&str, &str) -> PostToolAction>;
109
110/// A registered tool hook with a name pattern and callbacks.
111#[derive(Clone)]
112pub struct ToolHook {
113    /// Glob-style pattern matched against tool names (e.g. `"*"`, `"exec*"`, `"read_file"`).
114    pub pattern: String,
115    /// Called before tool execution. Return `Deny` to reject, `Modify` to rewrite args.
116    pub pre: Option<PreToolHookFn>,
117    /// Called after tool execution with the result text. Return `Modify` to rewrite.
118    pub post: Option<PostToolHookFn>,
119}
120
121impl std::fmt::Debug for ToolHook {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        f.debug_struct("ToolHook")
124            .field("pattern", &self.pattern)
125            .field("has_pre", &self.pre.is_some())
126            .field("has_post", &self.post.is_some())
127            .finish()
128    }
129}
130
/// How a runtime hook decides whether it applies to a given invocation.
///
/// `Debug` is derived: it renders exactly as the previous hand-written impl did
/// (`ToolNameGlob("...")` / `EventExpression("...")`).
#[derive(Clone, Debug)]
enum PatternMatcher {
    /// Glob tested against the tool name; used by native tool hooks.
    ToolNameGlob(String),
    /// Expression evaluated against the event payload; used by VM hooks.
    EventExpression(String),
}
147
148#[derive(Clone)]
149enum RuntimeHookHandler {
150    NativePreTool(PreToolHookFn),
151    NativePostTool(PostToolHookFn),
152    Vm {
153        handler_name: String,
154        closure: Rc<VmClosure>,
155    },
156}
157
158impl std::fmt::Debug for RuntimeHookHandler {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        match self {
161            Self::NativePreTool(_) => f.write_str("NativePreTool(..)"),
162            Self::NativePostTool(_) => f.write_str("NativePostTool(..)"),
163            Self::Vm { handler_name, .. } => f
164                .debug_struct("Vm")
165                .field("handler_name", handler_name)
166                .finish(),
167        }
168    }
169}
170
171#[derive(Clone, Debug)]
172struct RuntimeHook {
173    event: HookEvent,
174    matcher: PatternMatcher,
175    handler: RuntimeHookHandler,
176}
177
178#[derive(Clone, Debug)]
179pub struct VmLifecycleHookInvocation {
180    pub closure: Rc<VmClosure>,
181}
182
183thread_local! {
184    static RUNTIME_HOOKS: RefCell<Vec<RuntimeHook>> = const { RefCell::new(Vec::new()) };
185}
186
/// Match `name` against a glob `pattern` in which `*` matches any run of
/// characters (including none) and every other character matches itself.
///
/// Examples: `"*"` matches everything, `"exec*"` matches names starting with
/// `exec`, `"*_file"` names ending in `_file`, `"*shell*"` names containing
/// `shell`, and `"read_*_file"` names with that prefix and suffix. A pattern
/// without `*` must equal `name` exactly.
pub(crate) fn glob_match(pattern: &str, name: &str) -> bool {
    let mut segments = pattern.split('*');
    // `split` always yields at least one item: the literal text before the first
    // `*` (or the whole pattern when it has no `*`).
    let head = segments.next().unwrap_or_default();
    let Some(mut rest) = name.strip_prefix(head) else {
        return false;
    };
    let Some(tail) = segments.next_back() else {
        // No `*` at all: the literal must have consumed the entire name.
        return rest.is_empty();
    };
    // Inner literals must appear in order. Taking the leftmost occurrence is
    // always safe for `*`-only globs because it leaves the most room for the rest.
    for segment in segments {
        match rest.find(segment) {
            Some(index) => rest = &rest[index + segment.len()..],
            None => return false,
        }
    }
    // The tail must close out the name without overlapping consumed text.
    rest.ends_with(tail)
}
199
200pub fn register_tool_hook(hook: ToolHook) {
201    if let Some(pre) = hook.pre {
202        RUNTIME_HOOKS.with(|hooks| {
203            hooks.borrow_mut().push(RuntimeHook {
204                event: HookEvent::PreToolUse,
205                matcher: PatternMatcher::ToolNameGlob(hook.pattern.clone()),
206                handler: RuntimeHookHandler::NativePreTool(pre),
207            });
208        });
209    }
210    if let Some(post) = hook.post {
211        RUNTIME_HOOKS.with(|hooks| {
212            hooks.borrow_mut().push(RuntimeHook {
213                event: HookEvent::PostToolUse,
214                matcher: PatternMatcher::ToolNameGlob(hook.pattern),
215                handler: RuntimeHookHandler::NativePostTool(post),
216            });
217        });
218    }
219}
220
221pub fn register_vm_hook(
222    event: HookEvent,
223    pattern: impl Into<String>,
224    handler_name: impl Into<String>,
225    closure: Rc<VmClosure>,
226) {
227    RUNTIME_HOOKS.with(|hooks| {
228        hooks.borrow_mut().push(RuntimeHook {
229            event,
230            matcher: PatternMatcher::EventExpression(pattern.into()),
231            handler: RuntimeHookHandler::Vm {
232                handler_name: handler_name.into(),
233                closure,
234            },
235        });
236    });
237}
238
239pub fn clear_tool_hooks() {
240    RUNTIME_HOOKS.with(|hooks| {
241        hooks
242            .borrow_mut()
243            .retain(|hook| !matches!(hook.event, HookEvent::PreToolUse | HookEvent::PostToolUse));
244    });
245}
246
247pub fn clear_runtime_hooks() {
248    RUNTIME_HOOKS.with(|hooks| hooks.borrow_mut().clear());
249    super::clear_command_policies();
250}
251
252fn value_at_path<'a>(value: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
253    let mut current = value;
254    for segment in path.split('.') {
255        let serde_json::Value::Object(map) = current else {
256            return None;
257        };
258        current = map.get(segment)?;
259    }
260    Some(current)
261}
262
263fn value_truthy(value: &serde_json::Value) -> bool {
264    match value {
265        serde_json::Value::Null => false,
266        serde_json::Value::Bool(value) => *value,
267        serde_json::Value::Number(value) => value
268            .as_i64()
269            .map(|number| number != 0)
270            .or_else(|| value.as_u64().map(|number| number != 0))
271            .or_else(|| value.as_f64().map(|number| number != 0.0))
272            .unwrap_or(false),
273        serde_json::Value::String(value) => !value.is_empty(),
274        serde_json::Value::Array(values) => !values.is_empty(),
275        serde_json::Value::Object(values) => !values.is_empty(),
276    }
277}
278
279fn value_to_pattern_string(value: Option<&serde_json::Value>) -> String {
280    match value {
281        Some(serde_json::Value::String(text)) => text.clone(),
282        Some(other) => other.to_string(),
283        None => String::new(),
284    }
285}
286
/// Trim `value` and, if it is wrapped in a matching pair of double or single
/// quotes, return the text between them; otherwise return the trimmed text.
fn strip_quoted(value: &str) -> &str {
    let trimmed = value.trim();
    for quote in ['"', '\''] {
        let inner = trimmed
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote));
        if let Some(inner) = inner {
            return inner;
        }
    }
    trimmed
}
300
301fn expression_matches(pattern: &str, payload: &serde_json::Value) -> bool {
302    let pattern = pattern.trim();
303    if pattern.is_empty() || pattern == "*" {
304        return true;
305    }
306    if let Some(target) = value_at_path(payload, "target").and_then(serde_json::Value::as_str) {
307        if glob_match(pattern, target) {
308            return true;
309        }
310    }
311    if let Some((lhs, rhs)) = pattern.split_once("=~") {
312        let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
313        let regex = strip_quoted(rhs);
314        return Regex::new(regex).is_ok_and(|compiled| compiled.is_match(&value));
315    }
316    if let Some((lhs, rhs)) = pattern.split_once("==") {
317        let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
318        return value == strip_quoted(rhs);
319    }
320    if let Some((lhs, rhs)) = pattern.split_once("!=") {
321        let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
322        return value != strip_quoted(rhs);
323    }
324    if pattern.contains('.') {
325        return value_at_path(payload, pattern).is_some_and(value_truthy);
326    }
327    glob_match(
328        pattern,
329        &value_to_pattern_string(value_at_path(payload, "tool.name")),
330    )
331}
332
333fn hook_matches(hook: &RuntimeHook, tool_name: Option<&str>, payload: &serde_json::Value) -> bool {
334    match &hook.matcher {
335        PatternMatcher::ToolNameGlob(pattern) => {
336            tool_name.is_some_and(|candidate| glob_match(pattern, candidate))
337        }
338        PatternMatcher::EventExpression(pattern) => expression_matches(pattern, payload),
339    }
340}
341
342async fn invoke_vm_hook(
343    closure: &Rc<VmClosure>,
344    payload: &serde_json::Value,
345) -> Result<VmValue, VmError> {
346    let Some(mut vm) = crate::vm::clone_async_builtin_child_vm() else {
347        return Err(VmError::Runtime(
348            "runtime hook requires an async builtin VM context".to_string(),
349        ));
350    };
351    let arg = crate::stdlib::json_to_vm_value(payload);
352    vm.call_closure_pub(closure, &[arg]).await
353}
354
355fn parse_pre_tool_result(value: VmValue) -> Result<PreToolAction, VmError> {
356    match value {
357        VmValue::Nil => Ok(PreToolAction::Allow),
358        VmValue::Dict(map) => {
359            if let Some(reason) = map.get("deny") {
360                return Ok(PreToolAction::Deny(reason.display()));
361            }
362            if let Some(args) = map.get("args") {
363                return Ok(PreToolAction::Modify(crate::llm::vm_value_to_json(args)));
364            }
365            Ok(PreToolAction::Allow)
366        }
367        other => Err(VmError::Runtime(format!(
368            "PreToolUse hook must return nil or {{deny, args}}, got {}",
369            other.type_name()
370        ))),
371    }
372}
373
374fn parse_post_tool_result(value: VmValue) -> Result<PostToolAction, VmError> {
375    match value {
376        VmValue::Nil => Ok(PostToolAction::Pass),
377        VmValue::String(text) => Ok(PostToolAction::Modify(text.to_string())),
378        VmValue::Dict(map) => {
379            if let Some(result) = map.get("result") {
380                return Ok(PostToolAction::Modify(result.display()));
381            }
382            Ok(PostToolAction::Pass)
383        }
384        other => Err(VmError::Runtime(format!(
385            "PostToolUse hook must return nil, string, or {{result}}, got {}",
386            other.type_name()
387        ))),
388    }
389}
390
391/// Run all matching PreToolUse hooks. Returns the final action.
392pub async fn run_pre_tool_hooks(
393    tool_name: &str,
394    args: &serde_json::Value,
395) -> Result<PreToolAction, VmError> {
396    let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
397    let mut current_args = args.clone();
398    for hook in hooks
399        .iter()
400        .filter(|hook| hook.event == HookEvent::PreToolUse)
401    {
402        let payload = serde_json::json!({
403            "event": HookEvent::PreToolUse.as_str(),
404            "tool": {
405                "name": tool_name,
406                "args": current_args.clone(),
407            },
408        });
409        if !hook_matches(hook, Some(tool_name), &payload) {
410            continue;
411        }
412        let action = match &hook.handler {
413            RuntimeHookHandler::NativePreTool(pre) => pre(tool_name, &current_args),
414            RuntimeHookHandler::Vm { closure, .. } => {
415                parse_pre_tool_result(invoke_vm_hook(closure, &payload).await?)?
416            }
417            RuntimeHookHandler::NativePostTool(_) => continue,
418        };
419        match action {
420            PreToolAction::Allow => {}
421            PreToolAction::Deny(reason) => return Ok(PreToolAction::Deny(reason)),
422            PreToolAction::Modify(new_args) => {
423                current_args = new_args;
424            }
425        }
426    }
427    if current_args != *args {
428        Ok(PreToolAction::Modify(current_args))
429    } else {
430        Ok(PreToolAction::Allow)
431    }
432}
433
434/// Run all matching PostToolUse hooks. Returns the (possibly modified) result.
435pub async fn run_post_tool_hooks(
436    tool_name: &str,
437    args: &serde_json::Value,
438    result: &str,
439) -> Result<String, VmError> {
440    let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
441    let mut current = result.to_string();
442    for hook in hooks
443        .iter()
444        .filter(|hook| hook.event == HookEvent::PostToolUse)
445    {
446        let payload = serde_json::json!({
447            "event": HookEvent::PostToolUse.as_str(),
448            "tool": {
449                "name": tool_name,
450                "args": args,
451            },
452            "result": {
453                "text": current.clone(),
454            },
455        });
456        if !hook_matches(hook, Some(tool_name), &payload) {
457            continue;
458        }
459        let action = match &hook.handler {
460            RuntimeHookHandler::NativePostTool(post) => post(tool_name, &current),
461            RuntimeHookHandler::Vm { closure, .. } => {
462                parse_post_tool_result(invoke_vm_hook(closure, &payload).await?)?
463            }
464            RuntimeHookHandler::NativePreTool(_) => continue,
465        };
466        match action {
467            PostToolAction::Pass => {}
468            PostToolAction::Modify(new_result) => {
469                current = new_result;
470            }
471        }
472    }
473    Ok(current)
474}
475
476pub async fn run_lifecycle_hooks(
477    event: HookEvent,
478    payload: &serde_json::Value,
479) -> Result<(), VmError> {
480    let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
481    for hook in hooks.iter().filter(|hook| hook.event == event) {
482        if !hook_matches(hook, None, payload) {
483            continue;
484        }
485        if let RuntimeHookHandler::Vm { closure, .. } = &hook.handler {
486            let _ = invoke_vm_hook(closure, payload).await?;
487        }
488    }
489    Ok(())
490}
491
492pub fn matching_vm_lifecycle_hooks(
493    event: HookEvent,
494    payload: &serde_json::Value,
495) -> Vec<VmLifecycleHookInvocation> {
496    let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
497    hooks
498        .iter()
499        .filter(|hook| hook.event == event)
500        .filter(|hook| hook_matches(hook, None, payload))
501        .filter_map(|hook| match &hook.handler {
502            RuntimeHookHandler::Vm { closure, .. } => Some(VmLifecycleHookInvocation {
503                closure: closure.clone(),
504            }),
505            RuntimeHookHandler::NativePreTool(_) | RuntimeHookHandler::NativePostTool(_) => None,
506        })
507        .collect()
508}