1use std::cell::RefCell;
4use std::rc::Rc;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8
9use crate::agent_events::WorkerEvent;
10use crate::value::{VmClosure, VmError, VmValue};
11
/// Events that a runtime hook can be registered against.
///
/// Every variant carries an explicit serde name that is identical to its Rust
/// identifier, which is also the string `HookEvent::as_str` returns for it.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum HookEvent {
    // Tool-call events, dispatched by `run_pre_tool_hooks` / `run_post_tool_hooks`.
    #[serde(rename = "PreToolUse")]
    PreToolUse,
    #[serde(rename = "PostToolUse")]
    PostToolUse,
    #[serde(rename = "PreAgentTurn")]
    PreAgentTurn,
    #[serde(rename = "PostAgentTurn")]
    PostAgentTurn,
    // Worker events; `HookEvent::from_worker_event` maps `WorkerEvent` onto these.
    #[serde(rename = "WorkerSpawned")]
    WorkerSpawned,
    #[serde(rename = "WorkerProgressed")]
    WorkerProgressed,
    #[serde(rename = "WorkerWaitingForInput")]
    WorkerWaitingForInput,
    #[serde(rename = "WorkerCompleted")]
    WorkerCompleted,
    #[serde(rename = "WorkerFailed")]
    WorkerFailed,
    #[serde(rename = "WorkerCancelled")]
    WorkerCancelled,
    #[serde(rename = "PreStep")]
    PreStep,
    #[serde(rename = "PostStep")]
    PostStep,
    #[serde(rename = "OnBudgetThreshold")]
    OnBudgetThreshold,
    #[serde(rename = "OnApprovalRequested")]
    OnApprovalRequested,
    #[serde(rename = "OnHandoffEmitted")]
    OnHandoffEmitted,
    #[serde(rename = "OnPersonaPaused")]
    OnPersonaPaused,
    #[serde(rename = "OnPersonaResumed")]
    OnPersonaResumed,
}
50
51impl HookEvent {
52 pub fn as_str(self) -> &'static str {
53 match self {
54 Self::PreToolUse => "PreToolUse",
55 Self::PostToolUse => "PostToolUse",
56 Self::PreAgentTurn => "PreAgentTurn",
57 Self::PostAgentTurn => "PostAgentTurn",
58 Self::WorkerSpawned => "WorkerSpawned",
59 Self::WorkerProgressed => "WorkerProgressed",
60 Self::WorkerWaitingForInput => "WorkerWaitingForInput",
61 Self::WorkerCompleted => "WorkerCompleted",
62 Self::WorkerFailed => "WorkerFailed",
63 Self::WorkerCancelled => "WorkerCancelled",
64 Self::PreStep => "PreStep",
65 Self::PostStep => "PostStep",
66 Self::OnBudgetThreshold => "OnBudgetThreshold",
67 Self::OnApprovalRequested => "OnApprovalRequested",
68 Self::OnHandoffEmitted => "OnHandoffEmitted",
69 Self::OnPersonaPaused => "OnPersonaPaused",
70 Self::OnPersonaResumed => "OnPersonaResumed",
71 }
72 }
73
74 pub fn from_worker_event(event: WorkerEvent) -> Self {
75 match event {
76 WorkerEvent::WorkerSpawned => Self::WorkerSpawned,
77 WorkerEvent::WorkerProgressed => Self::WorkerProgressed,
78 WorkerEvent::WorkerWaitingForInput => Self::WorkerWaitingForInput,
79 WorkerEvent::WorkerCompleted => Self::WorkerCompleted,
80 WorkerEvent::WorkerFailed => Self::WorkerFailed,
81 WorkerEvent::WorkerCancelled => Self::WorkerCancelled,
82 }
83 }
84}
85
/// Decision produced by a `PreToolUse` hook.
#[derive(Clone, Debug)]
pub enum PreToolAction {
    /// Let the call proceed with the arguments unchanged.
    Allow,
    /// Reject the call; the string is the reason. `run_pre_tool_hooks` returns
    /// as soon as any hook yields this.
    Deny(String),
    /// Replace the tool arguments with this JSON value; later hooks in
    /// `run_pre_tool_hooks` see the replacement.
    Modify(serde_json::Value),
}
96
/// Decision produced by a `PostToolUse` hook.
#[derive(Clone, Debug)]
pub enum PostToolAction {
    /// Leave the tool result text as it is.
    Pass,
    /// Replace the tool result text with this string; later hooks in
    /// `run_post_tool_hooks` see the replacement.
    Modify(String),
}
105
/// Native pre-tool callback; `run_pre_tool_hooks` calls it with the tool name
/// and the current (possibly already modified) arguments.
pub type PreToolHookFn = Rc<dyn Fn(&str, &serde_json::Value) -> PreToolAction>;
/// Native post-tool callback; `run_post_tool_hooks` calls it with the tool name
/// and the current (possibly already modified) result text.
pub type PostToolHookFn = Rc<dyn Fn(&str, &str) -> PostToolAction>;
109
/// A native tool hook: a tool-name pattern plus optional pre/post callbacks.
#[derive(Clone)]
pub struct ToolHook {
    /// Tool-name glob; `register_tool_hook` stores it as a `ToolNameGlob`
    /// matcher, which is evaluated with `glob_match`.
    pub pattern: String,
    /// Callback registered for `HookEvent::PreToolUse`, if any.
    pub pre: Option<PreToolHookFn>,
    /// Callback registered for `HookEvent::PostToolUse`, if any.
    pub post: Option<PostToolHookFn>,
}
120
121impl std::fmt::Debug for ToolHook {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct("ToolHook")
124 .field("pattern", &self.pattern)
125 .field("has_pre", &self.pre.is_some())
126 .field("has_post", &self.post.is_some())
127 .finish()
128 }
129}
130
/// How a registered hook decides whether it applies to an event.
#[derive(Clone)]
enum PatternMatcher {
    /// Glob over the tool name only; created by `register_tool_hook`.
    ToolNameGlob(String),
    /// Expression evaluated against the JSON event payload; created by
    /// `compile_event_pattern`.
    EventExpression {
        /// The pattern text exactly as it was registered (untrimmed).
        source: String,
        /// The parsed form of `source`.
        expression: EventPatternExpression,
    },
}
139
/// Parsed form of an event pattern, produced by `compile_event_pattern` and
/// evaluated by `expression_matches`.
#[derive(Clone)]
enum EventPatternExpression {
    /// Empty pattern or `*`: matches every payload.
    MatchAll,
    /// Used when the right-hand side of `=~` is not a valid regex.
    NeverMatch,
    /// `path =~ regex`: the value at `path` (as a string) must match `regex`.
    Regex { path: String, regex: Regex },
    /// `path == value`: the value at `path` (as a string) must equal `value`.
    Equals { path: String, value: String },
    /// `path != value`: the value at `path` (as a string) must differ from `value`.
    NotEquals { path: String, value: String },
    /// A bare dotted path: the value at that path must exist and be truthy.
    PathTruthy(String),
    /// Any other text: a glob matched against the payload's `tool.name`.
    ToolNameGlob(String),
}
150
151impl std::fmt::Debug for PatternMatcher {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 match self {
154 Self::ToolNameGlob(pattern) => f.debug_tuple("ToolNameGlob").field(pattern).finish(),
155 Self::EventExpression { source, expression } => f
156 .debug_struct("EventExpression")
157 .field("source", source)
158 .field("expression", expression)
159 .finish(),
160 }
161 }
162}
163
164impl std::fmt::Debug for EventPatternExpression {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 match self {
167 Self::MatchAll => f.write_str("MatchAll"),
168 Self::NeverMatch => f.write_str("NeverMatch"),
169 Self::Regex { path, regex } => f
170 .debug_struct("Regex")
171 .field("path", path)
172 .field("regex", ®ex.as_str())
173 .finish(),
174 Self::Equals { path, value } => f
175 .debug_struct("Equals")
176 .field("path", path)
177 .field("value", value)
178 .finish(),
179 Self::NotEquals { path, value } => f
180 .debug_struct("NotEquals")
181 .field("path", path)
182 .field("value", value)
183 .finish(),
184 Self::PathTruthy(path) => f.debug_tuple("PathTruthy").field(path).finish(),
185 Self::ToolNameGlob(pattern) => f.debug_tuple("ToolNameGlob").field(pattern).finish(),
186 }
187 }
188}
189
/// The callable attached to a registered hook.
#[derive(Clone)]
enum RuntimeHookHandler {
    /// Native Rust callback for `PreToolUse`, registered via `register_tool_hook`.
    NativePreTool(PreToolHookFn),
    /// Native Rust callback for `PostToolUse`, registered via `register_tool_hook`.
    NativePostTool(PostToolHookFn),
    /// VM closure registered via `register_vm_hook`.
    Vm {
        /// Name supplied at registration; only used by the `Debug` output here.
        handler_name: String,
        /// The closure that is called with the event payload.
        closure: Rc<VmClosure>,
    },
}
199
200impl std::fmt::Debug for RuntimeHookHandler {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 match self {
203 Self::NativePreTool(_) => f.write_str("NativePreTool(..)"),
204 Self::NativePostTool(_) => f.write_str("NativePostTool(..)"),
205 Self::Vm { handler_name, .. } => f
206 .debug_struct("Vm")
207 .field("handler_name", handler_name)
208 .finish(),
209 }
210 }
211}
212
/// One entry in the `RUNTIME_HOOKS` registry.
#[derive(Clone, Debug)]
struct RuntimeHook {
    /// Event this hook is registered for.
    event: HookEvent,
    /// Decides whether the hook applies to a particular occurrence of `event`.
    matcher: PatternMatcher,
    /// What to call when the hook applies.
    handler: RuntimeHookHandler,
}
219
/// A VM closure selected by `matching_vm_lifecycle_hooks` for an event.
#[derive(Clone, Debug)]
pub struct VmLifecycleHookInvocation {
    /// The closure of the matching hook.
    pub closure: Rc<VmClosure>,
}
224
thread_local! {
    /// Per-thread list of registered hooks; registration functions push onto
    /// the end, so iteration order is registration order.
    static RUNTIME_HOOKS: RefCell<Vec<RuntimeHook>> = const { RefCell::new(Vec::new()) };
}
228
/// Matches `name` against a glob `pattern` in which `*` stands for any
/// (possibly empty) run of characters.
///
/// Any number of `*` wildcards may appear anywhere in the pattern, so
/// `"read_*"`, `"*_file"`, `"*file*"` and `"read_*_sync"` all behave as
/// expected. A pattern without `*` must equal `name` exactly. No other glob
/// metacharacters (`?`, `[...]`) are interpreted.
pub(crate) fn glob_match(pattern: &str, name: &str) -> bool {
    let mut literals = pattern.split('*');

    // `split` always yields at least one piece; the first is an anchored prefix.
    let prefix = literals.next().unwrap_or("");
    let Some(mut remaining) = name.strip_prefix(prefix) else {
        return false;
    };

    let mut literals: Vec<&str> = literals.collect();
    // No further pieces means the pattern had no wildcard: require an exact match.
    let Some(suffix) = literals.pop() else {
        return remaining.is_empty();
    };

    // Interior literals must appear in order; taking the leftmost occurrence
    // leaves the most text available for the pieces that follow.
    for literal in literals {
        match remaining.find(literal) {
            Some(index) => remaining = &remaining[index + literal.len()..],
            None => return false,
        }
    }

    // The final piece is anchored to the end of whatever is left.
    remaining.ends_with(suffix)
}
241
242pub fn register_tool_hook(hook: ToolHook) {
243 if let Some(pre) = hook.pre {
244 RUNTIME_HOOKS.with(|hooks| {
245 hooks.borrow_mut().push(RuntimeHook {
246 event: HookEvent::PreToolUse,
247 matcher: PatternMatcher::ToolNameGlob(hook.pattern.clone()),
248 handler: RuntimeHookHandler::NativePreTool(pre),
249 });
250 });
251 }
252 if let Some(post) = hook.post {
253 RUNTIME_HOOKS.with(|hooks| {
254 hooks.borrow_mut().push(RuntimeHook {
255 event: HookEvent::PostToolUse,
256 matcher: PatternMatcher::ToolNameGlob(hook.pattern),
257 handler: RuntimeHookHandler::NativePostTool(post),
258 });
259 });
260 }
261}
262
263pub fn register_vm_hook(
264 event: HookEvent,
265 pattern: impl Into<String>,
266 handler_name: impl Into<String>,
267 closure: Rc<VmClosure>,
268) {
269 RUNTIME_HOOKS.with(|hooks| {
270 hooks.borrow_mut().push(RuntimeHook {
271 event,
272 matcher: compile_event_pattern(pattern.into()),
273 handler: RuntimeHookHandler::Vm {
274 handler_name: handler_name.into(),
275 closure,
276 },
277 });
278 });
279}
280
281pub fn clear_tool_hooks() {
282 RUNTIME_HOOKS.with(|hooks| {
283 hooks
284 .borrow_mut()
285 .retain(|hook| !matches!(hook.event, HookEvent::PreToolUse | HookEvent::PostToolUse));
286 });
287}
288
289pub fn clear_runtime_hooks() {
290 RUNTIME_HOOKS.with(|hooks| hooks.borrow_mut().clear());
291 super::clear_command_policies();
292}
293
294fn value_at_path<'a>(value: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
295 let mut current = value;
296 for segment in path.split('.') {
297 let serde_json::Value::Object(map) = current else {
298 return None;
299 };
300 current = map.get(segment)?;
301 }
302 Some(current)
303}
304
305fn value_truthy(value: &serde_json::Value) -> bool {
306 match value {
307 serde_json::Value::Null => false,
308 serde_json::Value::Bool(value) => *value,
309 serde_json::Value::Number(value) => value
310 .as_i64()
311 .map(|number| number != 0)
312 .or_else(|| value.as_u64().map(|number| number != 0))
313 .or_else(|| value.as_f64().map(|number| number != 0.0))
314 .unwrap_or(false),
315 serde_json::Value::String(value) => !value.is_empty(),
316 serde_json::Value::Array(values) => !values.is_empty(),
317 serde_json::Value::Object(values) => !values.is_empty(),
318 }
319}
320
321fn value_to_pattern_string(value: Option<&serde_json::Value>) -> String {
322 match value {
323 Some(serde_json::Value::String(text)) => text.clone(),
324 Some(other) => other.to_string(),
325 None => String::new(),
326 }
327}
328
/// Trims whitespace and removes one pair of matching surrounding quotes.
///
/// Double quotes are tried before single quotes. If the trimmed text is not
/// wrapped in a matching pair, it is returned as-is (still trimmed).
fn strip_quoted(value: &str) -> &str {
    let trimmed = value.trim();
    for quote in ['"', '\''] {
        let unwrapped = trimmed
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote));
        if let Some(inner) = unwrapped {
            return inner;
        }
    }
    trimmed
}
342
343fn compile_event_pattern(pattern: String) -> PatternMatcher {
344 let trimmed = pattern.trim();
345 let expression = if trimmed.is_empty() || trimmed == "*" {
346 EventPatternExpression::MatchAll
347 } else if let Some((lhs, rhs)) = trimmed.split_once("=~") {
348 match Regex::new(strip_quoted(rhs)) {
349 Ok(regex) => EventPatternExpression::Regex {
350 path: lhs.trim().to_string(),
351 regex,
352 },
353 Err(_) => EventPatternExpression::NeverMatch,
354 }
355 } else if let Some((lhs, rhs)) = trimmed.split_once("==") {
356 EventPatternExpression::Equals {
357 path: lhs.trim().to_string(),
358 value: strip_quoted(rhs).to_string(),
359 }
360 } else if let Some((lhs, rhs)) = trimmed.split_once("!=") {
361 EventPatternExpression::NotEquals {
362 path: lhs.trim().to_string(),
363 value: strip_quoted(rhs).to_string(),
364 }
365 } else if trimmed.contains('.') {
366 EventPatternExpression::PathTruthy(trimmed.to_string())
367 } else {
368 EventPatternExpression::ToolNameGlob(trimmed.to_string())
369 };
370 PatternMatcher::EventExpression {
371 source: pattern,
372 expression,
373 }
374}
375
376fn expression_matches(
377 source: &str,
378 expression: &EventPatternExpression,
379 payload: &serde_json::Value,
380) -> bool {
381 let pattern = source.trim();
382 if pattern.is_empty() || pattern == "*" {
383 return true;
384 }
385 if let Some(target) = value_at_path(payload, "target").and_then(serde_json::Value::as_str) {
386 if glob_match(pattern, target) {
387 return true;
388 }
389 }
390 match expression {
391 EventPatternExpression::MatchAll => true,
392 EventPatternExpression::NeverMatch => false,
393 EventPatternExpression::Regex { path, regex } => {
394 let value = value_to_pattern_string(value_at_path(payload, path));
395 regex.is_match(&value)
396 }
397 EventPatternExpression::Equals { path, value } => {
398 value_to_pattern_string(value_at_path(payload, path)) == *value
399 }
400 EventPatternExpression::NotEquals { path, value } => {
401 value_to_pattern_string(value_at_path(payload, path)) != *value
402 }
403 EventPatternExpression::PathTruthy(path) => {
404 value_at_path(payload, path).is_some_and(value_truthy)
405 }
406 EventPatternExpression::ToolNameGlob(pattern) => glob_match(
407 pattern,
408 &value_to_pattern_string(value_at_path(payload, "tool.name")),
409 ),
410 }
411}
412
413fn hook_matches(hook: &RuntimeHook, tool_name: Option<&str>, payload: &serde_json::Value) -> bool {
414 match &hook.matcher {
415 PatternMatcher::ToolNameGlob(pattern) => {
416 tool_name.is_some_and(|candidate| glob_match(pattern, candidate))
417 }
418 PatternMatcher::EventExpression { source, expression } => {
419 expression_matches(source, expression, payload)
420 }
421 }
422}
423
424fn runtime_hooks_for_event(event: HookEvent) -> Vec<RuntimeHook> {
425 RUNTIME_HOOKS.with(|hooks| {
426 hooks
427 .borrow()
428 .iter()
429 .filter(|hook| hook.event == event)
430 .cloned()
431 .collect()
432 })
433}
434
435async fn invoke_vm_hook(
436 closure: &Rc<VmClosure>,
437 payload: &serde_json::Value,
438) -> Result<VmValue, VmError> {
439 let Some(mut vm) = crate::vm::clone_async_builtin_child_vm() else {
440 return Err(VmError::Runtime(
441 "runtime hook requires an async builtin VM context".to_string(),
442 ));
443 };
444 let arg = crate::stdlib::json_to_vm_value(payload);
445 vm.call_closure_pub(closure, &[arg]).await
446}
447
448async fn invoke_vm_lifecycle_hooks(
449 closures: Vec<Rc<VmClosure>>,
450 payload: &serde_json::Value,
451) -> Result<(), VmError> {
452 let Some(mut vm) = crate::vm::clone_async_builtin_child_vm() else {
453 return Err(VmError::Runtime(
454 "runtime hook requires an async builtin VM context".to_string(),
455 ));
456 };
457 let arg = crate::stdlib::json_to_vm_value(payload);
458 for closure in closures {
459 let _ = vm.call_closure_pub(&closure, &[arg.clone()]).await?;
460 }
461 Ok(())
462}
463
464fn parse_pre_tool_result(value: VmValue) -> Result<PreToolAction, VmError> {
465 match value {
466 VmValue::Nil => Ok(PreToolAction::Allow),
467 VmValue::Dict(map) => {
468 if let Some(reason) = map.get("deny") {
469 return Ok(PreToolAction::Deny(reason.display()));
470 }
471 if let Some(args) = map.get("args") {
472 return Ok(PreToolAction::Modify(crate::llm::vm_value_to_json(args)));
473 }
474 Ok(PreToolAction::Allow)
475 }
476 other => Err(VmError::Runtime(format!(
477 "PreToolUse hook must return nil or {{deny, args}}, got {}",
478 other.type_name()
479 ))),
480 }
481}
482
483fn parse_post_tool_result(value: VmValue) -> Result<PostToolAction, VmError> {
484 match value {
485 VmValue::Nil => Ok(PostToolAction::Pass),
486 VmValue::String(text) => Ok(PostToolAction::Modify(text.to_string())),
487 VmValue::Dict(map) => {
488 if let Some(result) = map.get("result") {
489 return Ok(PostToolAction::Modify(result.display()));
490 }
491 Ok(PostToolAction::Pass)
492 }
493 other => Err(VmError::Runtime(format!(
494 "PostToolUse hook must return nil, string, or {{result}}, got {}",
495 other.type_name()
496 ))),
497 }
498}
499
500pub async fn run_pre_tool_hooks(
502 tool_name: &str,
503 args: &serde_json::Value,
504) -> Result<PreToolAction, VmError> {
505 let hooks = runtime_hooks_for_event(HookEvent::PreToolUse);
506 let mut current_args = args.clone();
507 for hook in &hooks {
508 let payload = if matches!(hook.matcher, PatternMatcher::EventExpression { .. }) {
509 Some(serde_json::json!({
510 "event": HookEvent::PreToolUse.as_str(),
511 "tool": {
512 "name": tool_name,
513 "args": current_args.clone(),
514 },
515 }))
516 } else {
517 None
518 };
519 if !hook_matches(
520 hook,
521 Some(tool_name),
522 payload.as_ref().unwrap_or(&serde_json::Value::Null),
523 ) {
524 continue;
525 }
526 let action = match &hook.handler {
527 RuntimeHookHandler::NativePreTool(pre) => pre(tool_name, ¤t_args),
528 RuntimeHookHandler::Vm { closure, .. } => {
529 let payload = payload.as_ref().ok_or_else(|| {
530 VmError::Runtime("VM PreToolUse hook requires an event payload".to_string())
531 })?;
532 parse_pre_tool_result(invoke_vm_hook(closure, payload).await?)?
533 }
534 RuntimeHookHandler::NativePostTool(_) => continue,
535 };
536 match action {
537 PreToolAction::Allow => {}
538 PreToolAction::Deny(reason) => return Ok(PreToolAction::Deny(reason)),
539 PreToolAction::Modify(new_args) => {
540 current_args = new_args;
541 }
542 }
543 }
544 if current_args != *args {
545 Ok(PreToolAction::Modify(current_args))
546 } else {
547 Ok(PreToolAction::Allow)
548 }
549}
550
551pub async fn run_post_tool_hooks(
553 tool_name: &str,
554 args: &serde_json::Value,
555 result: &str,
556) -> Result<String, VmError> {
557 let hooks = runtime_hooks_for_event(HookEvent::PostToolUse);
558 let mut current = result.to_string();
559 for hook in &hooks {
560 let payload = if matches!(hook.matcher, PatternMatcher::EventExpression { .. }) {
561 Some(serde_json::json!({
562 "event": HookEvent::PostToolUse.as_str(),
563 "tool": {
564 "name": tool_name,
565 "args": args,
566 },
567 "result": {
568 "text": current.clone(),
569 },
570 }))
571 } else {
572 None
573 };
574 if !hook_matches(
575 hook,
576 Some(tool_name),
577 payload.as_ref().unwrap_or(&serde_json::Value::Null),
578 ) {
579 continue;
580 }
581 let action = match &hook.handler {
582 RuntimeHookHandler::NativePostTool(post) => post(tool_name, ¤t),
583 RuntimeHookHandler::Vm { closure, .. } => {
584 let payload = payload.as_ref().ok_or_else(|| {
585 VmError::Runtime("VM PostToolUse hook requires an event payload".to_string())
586 })?;
587 parse_post_tool_result(invoke_vm_hook(closure, payload).await?)?
588 }
589 RuntimeHookHandler::NativePreTool(_) => continue,
590 };
591 match action {
592 PostToolAction::Pass => {}
593 PostToolAction::Modify(new_result) => {
594 current = new_result;
595 }
596 }
597 }
598 Ok(current)
599}
600
601pub async fn run_lifecycle_hooks(
602 event: HookEvent,
603 payload: &serde_json::Value,
604) -> Result<(), VmError> {
605 let closures = matching_vm_lifecycle_closures(event, payload);
606 if closures.is_empty() {
607 return Ok(());
608 }
609 invoke_vm_lifecycle_hooks(closures, payload).await
610}
611
612pub fn matching_vm_lifecycle_hooks(
613 event: HookEvent,
614 payload: &serde_json::Value,
615) -> Vec<VmLifecycleHookInvocation> {
616 matching_vm_lifecycle_closures(event, payload)
617 .into_iter()
618 .map(|closure| VmLifecycleHookInvocation { closure })
619 .collect()
620}
621
622fn matching_vm_lifecycle_closures(
623 event: HookEvent,
624 payload: &serde_json::Value,
625) -> Vec<Rc<VmClosure>> {
626 RUNTIME_HOOKS.with(|hooks| {
627 hooks
628 .borrow()
629 .iter()
630 .filter(|hook| hook.event == event)
631 .filter(|hook| hook_matches(hook, None, payload))
632 .filter_map(|hook| match &hook.handler {
633 RuntimeHookHandler::Vm { closure, .. } => Some(Rc::clone(closure)),
634 RuntimeHookHandler::NativePreTool(_) | RuntimeHookHandler::NativePostTool(_) => {
635 None
636 }
637 })
638 .collect()
639 })
640}