Skip to main content

harn_vm/orchestration/
hooks.rs

1//! Runtime lifecycle hooks — tool, agent-turn, and worker interception.
2
3use std::cell::RefCell;
4use std::rc::Rc;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8
9use crate::agent_events::WorkerEvent;
10use crate::value::{VmClosure, VmError, VmValue};
11
12/// Manifest / runtime hook event names.
13#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
14pub enum HookEvent {
15    #[serde(rename = "PreToolUse")]
16    PreToolUse,
17    #[serde(rename = "PostToolUse")]
18    PostToolUse,
19    #[serde(rename = "PreAgentTurn")]
20    PreAgentTurn,
21    #[serde(rename = "PostAgentTurn")]
22    PostAgentTurn,
23    #[serde(rename = "WorkerSpawned")]
24    WorkerSpawned,
25    #[serde(rename = "WorkerProgressed")]
26    WorkerProgressed,
27    #[serde(rename = "WorkerWaitingForInput")]
28    WorkerWaitingForInput,
29    #[serde(rename = "WorkerCompleted")]
30    WorkerCompleted,
31    #[serde(rename = "WorkerFailed")]
32    WorkerFailed,
33    #[serde(rename = "WorkerCancelled")]
34    WorkerCancelled,
35    #[serde(rename = "PreStep")]
36    PreStep,
37    #[serde(rename = "PostStep")]
38    PostStep,
39    #[serde(rename = "OnBudgetThreshold")]
40    OnBudgetThreshold,
41    #[serde(rename = "OnApprovalRequested")]
42    OnApprovalRequested,
43    #[serde(rename = "OnHandoffEmitted")]
44    OnHandoffEmitted,
45    #[serde(rename = "OnPersonaPaused")]
46    OnPersonaPaused,
47    #[serde(rename = "OnPersonaResumed")]
48    OnPersonaResumed,
49}
50
51impl HookEvent {
52    pub fn as_str(self) -> &'static str {
53        match self {
54            Self::PreToolUse => "PreToolUse",
55            Self::PostToolUse => "PostToolUse",
56            Self::PreAgentTurn => "PreAgentTurn",
57            Self::PostAgentTurn => "PostAgentTurn",
58            Self::WorkerSpawned => "WorkerSpawned",
59            Self::WorkerProgressed => "WorkerProgressed",
60            Self::WorkerWaitingForInput => "WorkerWaitingForInput",
61            Self::WorkerCompleted => "WorkerCompleted",
62            Self::WorkerFailed => "WorkerFailed",
63            Self::WorkerCancelled => "WorkerCancelled",
64            Self::PreStep => "PreStep",
65            Self::PostStep => "PostStep",
66            Self::OnBudgetThreshold => "OnBudgetThreshold",
67            Self::OnApprovalRequested => "OnApprovalRequested",
68            Self::OnHandoffEmitted => "OnHandoffEmitted",
69            Self::OnPersonaPaused => "OnPersonaPaused",
70            Self::OnPersonaResumed => "OnPersonaResumed",
71        }
72    }
73
74    pub fn from_worker_event(event: WorkerEvent) -> Self {
75        match event {
76            WorkerEvent::WorkerSpawned => Self::WorkerSpawned,
77            WorkerEvent::WorkerProgressed => Self::WorkerProgressed,
78            WorkerEvent::WorkerWaitingForInput => Self::WorkerWaitingForInput,
79            WorkerEvent::WorkerCompleted => Self::WorkerCompleted,
80            WorkerEvent::WorkerFailed => Self::WorkerFailed,
81            WorkerEvent::WorkerCancelled => Self::WorkerCancelled,
82        }
83    }
84}
85
86/// Action returned by a PreToolUse hook.
87#[derive(Clone, Debug)]
88pub enum PreToolAction {
89    /// Allow the tool call to proceed unchanged.
90    Allow,
91    /// Deny the tool call with an explanation.
92    Deny(String),
93    /// Allow but replace the arguments.
94    Modify(serde_json::Value),
95}
96
/// Decision produced by a PostToolUse hook.
#[derive(Debug, Clone)]
pub enum PostToolAction {
    /// Keep the tool result exactly as it is.
    Pass,
    /// Substitute this text for the tool result.
    Modify(String),
}
105
106/// Callback types for legacy tool lifecycle hooks.
107pub type PreToolHookFn = Rc<dyn Fn(&str, &serde_json::Value) -> PreToolAction>;
108pub type PostToolHookFn = Rc<dyn Fn(&str, &str) -> PostToolAction>;
109
110/// A registered tool hook with a name pattern and callbacks.
111#[derive(Clone)]
112pub struct ToolHook {
113    /// Glob-style pattern matched against tool names (e.g. `"*"`, `"exec*"`, `"read_file"`).
114    pub pattern: String,
115    /// Called before tool execution. Return `Deny` to reject, `Modify` to rewrite args.
116    pub pre: Option<PreToolHookFn>,
117    /// Called after tool execution with the result text. Return `Modify` to rewrite.
118    pub post: Option<PostToolHookFn>,
119}
120
121impl std::fmt::Debug for ToolHook {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        f.debug_struct("ToolHook")
124            .field("pattern", &self.pattern)
125            .field("has_pre", &self.pre.is_some())
126            .field("has_post", &self.post.is_some())
127            .finish()
128    }
129}
130
131#[derive(Clone)]
132enum PatternMatcher {
133    ToolNameGlob(String),
134    EventExpression {
135        source: String,
136        expression: EventPatternExpression,
137    },
138}
139
140#[derive(Clone)]
141enum EventPatternExpression {
142    MatchAll,
143    NeverMatch,
144    Regex { path: String, regex: Regex },
145    Equals { path: String, value: String },
146    NotEquals { path: String, value: String },
147    PathTruthy(String),
148    ToolNameGlob(String),
149}
150
151impl std::fmt::Debug for PatternMatcher {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        match self {
154            Self::ToolNameGlob(pattern) => f.debug_tuple("ToolNameGlob").field(pattern).finish(),
155            Self::EventExpression { source, expression } => f
156                .debug_struct("EventExpression")
157                .field("source", source)
158                .field("expression", expression)
159                .finish(),
160        }
161    }
162}
163
164impl std::fmt::Debug for EventPatternExpression {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        match self {
167            Self::MatchAll => f.write_str("MatchAll"),
168            Self::NeverMatch => f.write_str("NeverMatch"),
169            Self::Regex { path, regex } => f
170                .debug_struct("Regex")
171                .field("path", path)
172                .field("regex", &regex.as_str())
173                .finish(),
174            Self::Equals { path, value } => f
175                .debug_struct("Equals")
176                .field("path", path)
177                .field("value", value)
178                .finish(),
179            Self::NotEquals { path, value } => f
180                .debug_struct("NotEquals")
181                .field("path", path)
182                .field("value", value)
183                .finish(),
184            Self::PathTruthy(path) => f.debug_tuple("PathTruthy").field(path).finish(),
185            Self::ToolNameGlob(pattern) => f.debug_tuple("ToolNameGlob").field(pattern).finish(),
186        }
187    }
188}
189
190#[derive(Clone)]
191enum RuntimeHookHandler {
192    NativePreTool(PreToolHookFn),
193    NativePostTool(PostToolHookFn),
194    Vm {
195        handler_name: String,
196        closure: Rc<VmClosure>,
197    },
198}
199
200impl std::fmt::Debug for RuntimeHookHandler {
201    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202        match self {
203            Self::NativePreTool(_) => f.write_str("NativePreTool(..)"),
204            Self::NativePostTool(_) => f.write_str("NativePostTool(..)"),
205            Self::Vm { handler_name, .. } => f
206                .debug_struct("Vm")
207                .field("handler_name", handler_name)
208                .finish(),
209        }
210    }
211}
212
213#[derive(Clone, Debug)]
214struct RuntimeHook {
215    event: HookEvent,
216    matcher: PatternMatcher,
217    handler: RuntimeHookHandler,
218}
219
220#[derive(Clone, Debug)]
221pub struct VmLifecycleHookInvocation {
222    pub closure: Rc<VmClosure>,
223}
224
225thread_local! {
226    static RUNTIME_HOOKS: RefCell<Vec<RuntimeHook>> = const { RefCell::new(Vec::new()) };
227}
228
/// Returns `true` when `name` matches the glob `pattern`.
///
/// The only metacharacter is `*`, which matches any run of characters
/// (including none). It may appear any number of times and anywhere in the
/// pattern, so `"*"`, `"exec*"`, `"*_file"`, `"read*file"` and `"*file*"` are
/// all supported. A pattern without `*` must equal `name` exactly.
///
/// Matching is case-sensitive and anchored at both ends: the text before the
/// first `*` must be a prefix of `name`, and the text after the last `*` must
/// be a suffix of what remains.
pub(crate) fn glob_match(pattern: &str, name: &str) -> bool {
    // No wildcard at all: plain equality.
    let Some((head, rest)) = pattern.split_once('*') else {
        return pattern == name;
    };
    // Literal text before the first `*` is anchored at the start.
    let Some(after_head) = name.strip_prefix(head) else {
        return false;
    };
    // Literal text after the last `*` is anchored at the end. It is stripped
    // from what is left after the head so the two anchors cannot overlap
    // (e.g. "ab*ba" must not match "aba").
    let (middle, tail) = rest.rsplit_once('*').unwrap_or(("", rest));
    let Some(mut remaining) = after_head.strip_suffix(tail) else {
        return false;
    };
    // Segments between wildcards must appear in order; taking the leftmost
    // occurrence each time leaves the most room for later segments.
    for segment in middle.split('*').filter(|segment| !segment.is_empty()) {
        match remaining.find(segment) {
            Some(index) => remaining = &remaining[index + segment.len()..],
            None => return false,
        }
    }
    true
}
241
242pub fn register_tool_hook(hook: ToolHook) {
243    if let Some(pre) = hook.pre {
244        RUNTIME_HOOKS.with(|hooks| {
245            hooks.borrow_mut().push(RuntimeHook {
246                event: HookEvent::PreToolUse,
247                matcher: PatternMatcher::ToolNameGlob(hook.pattern.clone()),
248                handler: RuntimeHookHandler::NativePreTool(pre),
249            });
250        });
251    }
252    if let Some(post) = hook.post {
253        RUNTIME_HOOKS.with(|hooks| {
254            hooks.borrow_mut().push(RuntimeHook {
255                event: HookEvent::PostToolUse,
256                matcher: PatternMatcher::ToolNameGlob(hook.pattern),
257                handler: RuntimeHookHandler::NativePostTool(post),
258            });
259        });
260    }
261}
262
263pub fn register_vm_hook(
264    event: HookEvent,
265    pattern: impl Into<String>,
266    handler_name: impl Into<String>,
267    closure: Rc<VmClosure>,
268) {
269    RUNTIME_HOOKS.with(|hooks| {
270        hooks.borrow_mut().push(RuntimeHook {
271            event,
272            matcher: compile_event_pattern(pattern.into()),
273            handler: RuntimeHookHandler::Vm {
274                handler_name: handler_name.into(),
275                closure,
276            },
277        });
278    });
279}
280
281pub fn clear_tool_hooks() {
282    RUNTIME_HOOKS.with(|hooks| {
283        hooks
284            .borrow_mut()
285            .retain(|hook| !matches!(hook.event, HookEvent::PreToolUse | HookEvent::PostToolUse));
286    });
287}
288
289pub fn clear_runtime_hooks() {
290    RUNTIME_HOOKS.with(|hooks| hooks.borrow_mut().clear());
291    super::clear_command_policies();
292}
293
294fn value_at_path<'a>(value: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
295    let mut current = value;
296    for segment in path.split('.') {
297        let serde_json::Value::Object(map) = current else {
298            return None;
299        };
300        current = map.get(segment)?;
301    }
302    Some(current)
303}
304
305fn value_truthy(value: &serde_json::Value) -> bool {
306    match value {
307        serde_json::Value::Null => false,
308        serde_json::Value::Bool(value) => *value,
309        serde_json::Value::Number(value) => value
310            .as_i64()
311            .map(|number| number != 0)
312            .or_else(|| value.as_u64().map(|number| number != 0))
313            .or_else(|| value.as_f64().map(|number| number != 0.0))
314            .unwrap_or(false),
315        serde_json::Value::String(value) => !value.is_empty(),
316        serde_json::Value::Array(values) => !values.is_empty(),
317        serde_json::Value::Object(values) => !values.is_empty(),
318    }
319}
320
321fn value_to_pattern_string(value: Option<&serde_json::Value>) -> String {
322    match value {
323        Some(serde_json::Value::String(text)) => text.clone(),
324        Some(other) => other.to_string(),
325        None => String::new(),
326    }
327}
328
/// Trims surrounding whitespace and then removes one pair of matching quotes.
///
/// Double quotes are tried before single quotes. If the trimmed text is not
/// wrapped in a matching pair, it is returned as-is.
fn strip_quoted(value: &str) -> &str {
    let trimmed = value.trim();
    for quote in ['"', '\''] {
        let unwrapped = trimmed
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote));
        if let Some(inner) = unwrapped {
            return inner;
        }
    }
    trimmed
}
342
343fn compile_event_pattern(pattern: String) -> PatternMatcher {
344    let trimmed = pattern.trim();
345    let expression = if trimmed.is_empty() || trimmed == "*" {
346        EventPatternExpression::MatchAll
347    } else if let Some((lhs, rhs)) = trimmed.split_once("=~") {
348        match Regex::new(strip_quoted(rhs)) {
349            Ok(regex) => EventPatternExpression::Regex {
350                path: lhs.trim().to_string(),
351                regex,
352            },
353            Err(_) => EventPatternExpression::NeverMatch,
354        }
355    } else if let Some((lhs, rhs)) = trimmed.split_once("==") {
356        EventPatternExpression::Equals {
357            path: lhs.trim().to_string(),
358            value: strip_quoted(rhs).to_string(),
359        }
360    } else if let Some((lhs, rhs)) = trimmed.split_once("!=") {
361        EventPatternExpression::NotEquals {
362            path: lhs.trim().to_string(),
363            value: strip_quoted(rhs).to_string(),
364        }
365    } else if trimmed.contains('.') {
366        EventPatternExpression::PathTruthy(trimmed.to_string())
367    } else {
368        EventPatternExpression::ToolNameGlob(trimmed.to_string())
369    };
370    PatternMatcher::EventExpression {
371        source: pattern,
372        expression,
373    }
374}
375
376fn expression_matches(
377    source: &str,
378    expression: &EventPatternExpression,
379    payload: &serde_json::Value,
380) -> bool {
381    let pattern = source.trim();
382    if pattern.is_empty() || pattern == "*" {
383        return true;
384    }
385    if let Some(target) = value_at_path(payload, "target").and_then(serde_json::Value::as_str) {
386        if glob_match(pattern, target) {
387            return true;
388        }
389    }
390    match expression {
391        EventPatternExpression::MatchAll => true,
392        EventPatternExpression::NeverMatch => false,
393        EventPatternExpression::Regex { path, regex } => {
394            let value = value_to_pattern_string(value_at_path(payload, path));
395            regex.is_match(&value)
396        }
397        EventPatternExpression::Equals { path, value } => {
398            value_to_pattern_string(value_at_path(payload, path)) == *value
399        }
400        EventPatternExpression::NotEquals { path, value } => {
401            value_to_pattern_string(value_at_path(payload, path)) != *value
402        }
403        EventPatternExpression::PathTruthy(path) => {
404            value_at_path(payload, path).is_some_and(value_truthy)
405        }
406        EventPatternExpression::ToolNameGlob(pattern) => glob_match(
407            pattern,
408            &value_to_pattern_string(value_at_path(payload, "tool.name")),
409        ),
410    }
411}
412
413fn hook_matches(hook: &RuntimeHook, tool_name: Option<&str>, payload: &serde_json::Value) -> bool {
414    match &hook.matcher {
415        PatternMatcher::ToolNameGlob(pattern) => {
416            tool_name.is_some_and(|candidate| glob_match(pattern, candidate))
417        }
418        PatternMatcher::EventExpression { source, expression } => {
419            expression_matches(source, expression, payload)
420        }
421    }
422}
423
424fn runtime_hooks_for_event(event: HookEvent) -> Vec<RuntimeHook> {
425    RUNTIME_HOOKS.with(|hooks| {
426        hooks
427            .borrow()
428            .iter()
429            .filter(|hook| hook.event == event)
430            .cloned()
431            .collect()
432    })
433}
434
435async fn invoke_vm_hook(
436    closure: &Rc<VmClosure>,
437    payload: &serde_json::Value,
438) -> Result<VmValue, VmError> {
439    let Some(mut vm) = crate::vm::clone_async_builtin_child_vm() else {
440        return Err(VmError::Runtime(
441            "runtime hook requires an async builtin VM context".to_string(),
442        ));
443    };
444    let arg = crate::stdlib::json_to_vm_value(payload);
445    vm.call_closure_pub(closure, &[arg]).await
446}
447
448async fn invoke_vm_lifecycle_hooks(
449    closures: Vec<Rc<VmClosure>>,
450    payload: &serde_json::Value,
451) -> Result<(), VmError> {
452    let Some(mut vm) = crate::vm::clone_async_builtin_child_vm() else {
453        return Err(VmError::Runtime(
454            "runtime hook requires an async builtin VM context".to_string(),
455        ));
456    };
457    let arg = crate::stdlib::json_to_vm_value(payload);
458    for closure in closures {
459        let _ = vm.call_closure_pub(&closure, &[arg.clone()]).await?;
460    }
461    Ok(())
462}
463
464fn parse_pre_tool_result(value: VmValue) -> Result<PreToolAction, VmError> {
465    match value {
466        VmValue::Nil => Ok(PreToolAction::Allow),
467        VmValue::Dict(map) => {
468            if let Some(reason) = map.get("deny") {
469                return Ok(PreToolAction::Deny(reason.display()));
470            }
471            if let Some(args) = map.get("args") {
472                return Ok(PreToolAction::Modify(crate::llm::vm_value_to_json(args)));
473            }
474            Ok(PreToolAction::Allow)
475        }
476        other => Err(VmError::Runtime(format!(
477            "PreToolUse hook must return nil or {{deny, args}}, got {}",
478            other.type_name()
479        ))),
480    }
481}
482
483fn parse_post_tool_result(value: VmValue) -> Result<PostToolAction, VmError> {
484    match value {
485        VmValue::Nil => Ok(PostToolAction::Pass),
486        VmValue::String(text) => Ok(PostToolAction::Modify(text.to_string())),
487        VmValue::Dict(map) => {
488            if let Some(result) = map.get("result") {
489                return Ok(PostToolAction::Modify(result.display()));
490            }
491            Ok(PostToolAction::Pass)
492        }
493        other => Err(VmError::Runtime(format!(
494            "PostToolUse hook must return nil, string, or {{result}}, got {}",
495            other.type_name()
496        ))),
497    }
498}
499
500/// Run all matching PreToolUse hooks. Returns the final action.
501pub async fn run_pre_tool_hooks(
502    tool_name: &str,
503    args: &serde_json::Value,
504) -> Result<PreToolAction, VmError> {
505    let hooks = runtime_hooks_for_event(HookEvent::PreToolUse);
506    let mut current_args = args.clone();
507    for hook in &hooks {
508        let payload = if matches!(hook.matcher, PatternMatcher::EventExpression { .. }) {
509            Some(serde_json::json!({
510                "event": HookEvent::PreToolUse.as_str(),
511                "tool": {
512                    "name": tool_name,
513                    "args": current_args.clone(),
514                },
515            }))
516        } else {
517            None
518        };
519        if !hook_matches(
520            hook,
521            Some(tool_name),
522            payload.as_ref().unwrap_or(&serde_json::Value::Null),
523        ) {
524            continue;
525        }
526        let action = match &hook.handler {
527            RuntimeHookHandler::NativePreTool(pre) => pre(tool_name, &current_args),
528            RuntimeHookHandler::Vm { closure, .. } => {
529                let payload = payload.as_ref().ok_or_else(|| {
530                    VmError::Runtime("VM PreToolUse hook requires an event payload".to_string())
531                })?;
532                parse_pre_tool_result(invoke_vm_hook(closure, payload).await?)?
533            }
534            RuntimeHookHandler::NativePostTool(_) => continue,
535        };
536        match action {
537            PreToolAction::Allow => {}
538            PreToolAction::Deny(reason) => return Ok(PreToolAction::Deny(reason)),
539            PreToolAction::Modify(new_args) => {
540                current_args = new_args;
541            }
542        }
543    }
544    if current_args != *args {
545        Ok(PreToolAction::Modify(current_args))
546    } else {
547        Ok(PreToolAction::Allow)
548    }
549}
550
551/// Run all matching PostToolUse hooks. Returns the (possibly modified) result.
552pub async fn run_post_tool_hooks(
553    tool_name: &str,
554    args: &serde_json::Value,
555    result: &str,
556) -> Result<String, VmError> {
557    let hooks = runtime_hooks_for_event(HookEvent::PostToolUse);
558    let mut current = result.to_string();
559    for hook in &hooks {
560        let payload = if matches!(hook.matcher, PatternMatcher::EventExpression { .. }) {
561            Some(serde_json::json!({
562                "event": HookEvent::PostToolUse.as_str(),
563                "tool": {
564                    "name": tool_name,
565                    "args": args,
566                },
567                "result": {
568                    "text": current.clone(),
569                },
570            }))
571        } else {
572            None
573        };
574        if !hook_matches(
575            hook,
576            Some(tool_name),
577            payload.as_ref().unwrap_or(&serde_json::Value::Null),
578        ) {
579            continue;
580        }
581        let action = match &hook.handler {
582            RuntimeHookHandler::NativePostTool(post) => post(tool_name, &current),
583            RuntimeHookHandler::Vm { closure, .. } => {
584                let payload = payload.as_ref().ok_or_else(|| {
585                    VmError::Runtime("VM PostToolUse hook requires an event payload".to_string())
586                })?;
587                parse_post_tool_result(invoke_vm_hook(closure, payload).await?)?
588            }
589            RuntimeHookHandler::NativePreTool(_) => continue,
590        };
591        match action {
592            PostToolAction::Pass => {}
593            PostToolAction::Modify(new_result) => {
594                current = new_result;
595            }
596        }
597    }
598    Ok(current)
599}
600
601pub async fn run_lifecycle_hooks(
602    event: HookEvent,
603    payload: &serde_json::Value,
604) -> Result<(), VmError> {
605    let closures = matching_vm_lifecycle_closures(event, payload);
606    if closures.is_empty() {
607        return Ok(());
608    }
609    invoke_vm_lifecycle_hooks(closures, payload).await
610}
611
612pub fn matching_vm_lifecycle_hooks(
613    event: HookEvent,
614    payload: &serde_json::Value,
615) -> Vec<VmLifecycleHookInvocation> {
616    matching_vm_lifecycle_closures(event, payload)
617        .into_iter()
618        .map(|closure| VmLifecycleHookInvocation { closure })
619        .collect()
620}
621
622fn matching_vm_lifecycle_closures(
623    event: HookEvent,
624    payload: &serde_json::Value,
625) -> Vec<Rc<VmClosure>> {
626    RUNTIME_HOOKS.with(|hooks| {
627        hooks
628            .borrow()
629            .iter()
630            .filter(|hook| hook.event == event)
631            .filter(|hook| hook_matches(hook, None, payload))
632            .filter_map(|hook| match &hook.handler {
633                RuntimeHookHandler::Vm { closure, .. } => Some(Rc::clone(closure)),
634                RuntimeHookHandler::NativePreTool(_) | RuntimeHookHandler::NativePostTool(_) => {
635                    None
636                }
637            })
638            .collect()
639    })
640}