1use std::cell::RefCell;
4use std::rc::Rc;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8
9use crate::agent_events::WorkerEvent;
10use crate::value::{VmClosure, VmError, VmValue};
11
12#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
14pub enum HookEvent {
15 #[serde(rename = "PreToolUse")]
16 PreToolUse,
17 #[serde(rename = "PostToolUse")]
18 PostToolUse,
19 #[serde(rename = "PreAgentTurn")]
20 PreAgentTurn,
21 #[serde(rename = "PostAgentTurn")]
22 PostAgentTurn,
23 #[serde(rename = "WorkerSpawned")]
24 WorkerSpawned,
25 #[serde(rename = "WorkerProgressed")]
26 WorkerProgressed,
27 #[serde(rename = "WorkerWaitingForInput")]
28 WorkerWaitingForInput,
29 #[serde(rename = "WorkerCompleted")]
30 WorkerCompleted,
31 #[serde(rename = "WorkerFailed")]
32 WorkerFailed,
33 #[serde(rename = "WorkerCancelled")]
34 WorkerCancelled,
35 #[serde(rename = "PreStep")]
36 PreStep,
37 #[serde(rename = "PostStep")]
38 PostStep,
39 #[serde(rename = "OnBudgetThreshold")]
40 OnBudgetThreshold,
41 #[serde(rename = "OnApprovalRequested")]
42 OnApprovalRequested,
43 #[serde(rename = "OnHandoffEmitted")]
44 OnHandoffEmitted,
45 #[serde(rename = "OnPersonaPaused")]
46 OnPersonaPaused,
47 #[serde(rename = "OnPersonaResumed")]
48 OnPersonaResumed,
49}
50
51impl HookEvent {
52 pub fn as_str(self) -> &'static str {
53 match self {
54 Self::PreToolUse => "PreToolUse",
55 Self::PostToolUse => "PostToolUse",
56 Self::PreAgentTurn => "PreAgentTurn",
57 Self::PostAgentTurn => "PostAgentTurn",
58 Self::WorkerSpawned => "WorkerSpawned",
59 Self::WorkerProgressed => "WorkerProgressed",
60 Self::WorkerWaitingForInput => "WorkerWaitingForInput",
61 Self::WorkerCompleted => "WorkerCompleted",
62 Self::WorkerFailed => "WorkerFailed",
63 Self::WorkerCancelled => "WorkerCancelled",
64 Self::PreStep => "PreStep",
65 Self::PostStep => "PostStep",
66 Self::OnBudgetThreshold => "OnBudgetThreshold",
67 Self::OnApprovalRequested => "OnApprovalRequested",
68 Self::OnHandoffEmitted => "OnHandoffEmitted",
69 Self::OnPersonaPaused => "OnPersonaPaused",
70 Self::OnPersonaResumed => "OnPersonaResumed",
71 }
72 }
73
74 pub fn from_worker_event(event: WorkerEvent) -> Self {
75 match event {
76 WorkerEvent::WorkerSpawned => Self::WorkerSpawned,
77 WorkerEvent::WorkerProgressed => Self::WorkerProgressed,
78 WorkerEvent::WorkerWaitingForInput => Self::WorkerWaitingForInput,
79 WorkerEvent::WorkerCompleted => Self::WorkerCompleted,
80 WorkerEvent::WorkerFailed => Self::WorkerFailed,
81 WorkerEvent::WorkerCancelled => Self::WorkerCancelled,
82 }
83 }
84}
85
86#[derive(Clone, Debug)]
88pub enum PreToolAction {
89 Allow,
91 Deny(String),
93 Modify(serde_json::Value),
95}
96
97#[derive(Clone, Debug)]
99pub enum PostToolAction {
100 Pass,
102 Modify(String),
104}
105
106pub type PreToolHookFn = Rc<dyn Fn(&str, &serde_json::Value) -> PreToolAction>;
108pub type PostToolHookFn = Rc<dyn Fn(&str, &str) -> PostToolAction>;
109
110#[derive(Clone)]
112pub struct ToolHook {
113 pub pattern: String,
115 pub pre: Option<PreToolHookFn>,
117 pub post: Option<PostToolHookFn>,
119}
120
121impl std::fmt::Debug for ToolHook {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct("ToolHook")
124 .field("pattern", &self.pattern)
125 .field("has_pre", &self.pre.is_some())
126 .field("has_post", &self.post.is_some())
127 .finish()
128 }
129}
130
131#[derive(Clone)]
132enum PatternMatcher {
133 ToolNameGlob(String),
134 EventExpression(String),
135}
136
137impl std::fmt::Debug for PatternMatcher {
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 match self {
140 Self::ToolNameGlob(pattern) => f.debug_tuple("ToolNameGlob").field(pattern).finish(),
141 Self::EventExpression(pattern) => {
142 f.debug_tuple("EventExpression").field(pattern).finish()
143 }
144 }
145 }
146}
147
148#[derive(Clone)]
149enum RuntimeHookHandler {
150 NativePreTool(PreToolHookFn),
151 NativePostTool(PostToolHookFn),
152 Vm {
153 handler_name: String,
154 closure: Rc<VmClosure>,
155 },
156}
157
158impl std::fmt::Debug for RuntimeHookHandler {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 match self {
161 Self::NativePreTool(_) => f.write_str("NativePreTool(..)"),
162 Self::NativePostTool(_) => f.write_str("NativePostTool(..)"),
163 Self::Vm { handler_name, .. } => f
164 .debug_struct("Vm")
165 .field("handler_name", handler_name)
166 .finish(),
167 }
168 }
169}
170
171#[derive(Clone, Debug)]
172struct RuntimeHook {
173 event: HookEvent,
174 matcher: PatternMatcher,
175 handler: RuntimeHookHandler,
176}
177
178#[derive(Clone, Debug)]
179pub struct VmLifecycleHookInvocation {
180 pub closure: Rc<VmClosure>,
181}
182
183thread_local! {
184 static RUNTIME_HOOKS: RefCell<Vec<RuntimeHook>> = const { RefCell::new(Vec::new()) };
185}
186
187pub(crate) fn glob_match(pattern: &str, name: &str) -> bool {
188 if pattern == "*" {
189 return true;
190 }
191 if let Some(prefix) = pattern.strip_suffix('*') {
192 return name.starts_with(prefix);
193 }
194 if let Some(suffix) = pattern.strip_prefix('*') {
195 return name.ends_with(suffix);
196 }
197 pattern == name
198}
199
200pub fn register_tool_hook(hook: ToolHook) {
201 if let Some(pre) = hook.pre {
202 RUNTIME_HOOKS.with(|hooks| {
203 hooks.borrow_mut().push(RuntimeHook {
204 event: HookEvent::PreToolUse,
205 matcher: PatternMatcher::ToolNameGlob(hook.pattern.clone()),
206 handler: RuntimeHookHandler::NativePreTool(pre),
207 });
208 });
209 }
210 if let Some(post) = hook.post {
211 RUNTIME_HOOKS.with(|hooks| {
212 hooks.borrow_mut().push(RuntimeHook {
213 event: HookEvent::PostToolUse,
214 matcher: PatternMatcher::ToolNameGlob(hook.pattern),
215 handler: RuntimeHookHandler::NativePostTool(post),
216 });
217 });
218 }
219}
220
221pub fn register_vm_hook(
222 event: HookEvent,
223 pattern: impl Into<String>,
224 handler_name: impl Into<String>,
225 closure: Rc<VmClosure>,
226) {
227 RUNTIME_HOOKS.with(|hooks| {
228 hooks.borrow_mut().push(RuntimeHook {
229 event,
230 matcher: PatternMatcher::EventExpression(pattern.into()),
231 handler: RuntimeHookHandler::Vm {
232 handler_name: handler_name.into(),
233 closure,
234 },
235 });
236 });
237}
238
239pub fn clear_tool_hooks() {
240 RUNTIME_HOOKS.with(|hooks| {
241 hooks
242 .borrow_mut()
243 .retain(|hook| !matches!(hook.event, HookEvent::PreToolUse | HookEvent::PostToolUse));
244 });
245}
246
247pub fn clear_runtime_hooks() {
248 RUNTIME_HOOKS.with(|hooks| hooks.borrow_mut().clear());
249 super::clear_command_policies();
250}
251
252fn value_at_path<'a>(value: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
253 let mut current = value;
254 for segment in path.split('.') {
255 let serde_json::Value::Object(map) = current else {
256 return None;
257 };
258 current = map.get(segment)?;
259 }
260 Some(current)
261}
262
263fn value_truthy(value: &serde_json::Value) -> bool {
264 match value {
265 serde_json::Value::Null => false,
266 serde_json::Value::Bool(value) => *value,
267 serde_json::Value::Number(value) => value
268 .as_i64()
269 .map(|number| number != 0)
270 .or_else(|| value.as_u64().map(|number| number != 0))
271 .or_else(|| value.as_f64().map(|number| number != 0.0))
272 .unwrap_or(false),
273 serde_json::Value::String(value) => !value.is_empty(),
274 serde_json::Value::Array(values) => !values.is_empty(),
275 serde_json::Value::Object(values) => !values.is_empty(),
276 }
277}
278
279fn value_to_pattern_string(value: Option<&serde_json::Value>) -> String {
280 match value {
281 Some(serde_json::Value::String(text)) => text.clone(),
282 Some(other) => other.to_string(),
283 None => String::new(),
284 }
285}
286
287fn strip_quoted(value: &str) -> &str {
288 value
289 .trim()
290 .strip_prefix('"')
291 .and_then(|text| text.strip_suffix('"'))
292 .or_else(|| {
293 value
294 .trim()
295 .strip_prefix('\'')
296 .and_then(|text| text.strip_suffix('\''))
297 })
298 .unwrap_or(value.trim())
299}
300
301fn expression_matches(pattern: &str, payload: &serde_json::Value) -> bool {
302 let pattern = pattern.trim();
303 if pattern.is_empty() || pattern == "*" {
304 return true;
305 }
306 if let Some(target) = value_at_path(payload, "target").and_then(serde_json::Value::as_str) {
307 if glob_match(pattern, target) {
308 return true;
309 }
310 }
311 if let Some((lhs, rhs)) = pattern.split_once("=~") {
312 let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
313 let regex = strip_quoted(rhs);
314 return Regex::new(regex).is_ok_and(|compiled| compiled.is_match(&value));
315 }
316 if let Some((lhs, rhs)) = pattern.split_once("==") {
317 let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
318 return value == strip_quoted(rhs);
319 }
320 if let Some((lhs, rhs)) = pattern.split_once("!=") {
321 let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
322 return value != strip_quoted(rhs);
323 }
324 if pattern.contains('.') {
325 return value_at_path(payload, pattern).is_some_and(value_truthy);
326 }
327 glob_match(
328 pattern,
329 &value_to_pattern_string(value_at_path(payload, "tool.name")),
330 )
331}
332
333fn hook_matches(hook: &RuntimeHook, tool_name: Option<&str>, payload: &serde_json::Value) -> bool {
334 match &hook.matcher {
335 PatternMatcher::ToolNameGlob(pattern) => {
336 tool_name.is_some_and(|candidate| glob_match(pattern, candidate))
337 }
338 PatternMatcher::EventExpression(pattern) => expression_matches(pattern, payload),
339 }
340}
341
342async fn invoke_vm_hook(
343 closure: &Rc<VmClosure>,
344 payload: &serde_json::Value,
345) -> Result<VmValue, VmError> {
346 let Some(mut vm) = crate::vm::clone_async_builtin_child_vm() else {
347 return Err(VmError::Runtime(
348 "runtime hook requires an async builtin VM context".to_string(),
349 ));
350 };
351 let arg = crate::stdlib::json_to_vm_value(payload);
352 vm.call_closure_pub(closure, &[arg]).await
353}
354
355fn parse_pre_tool_result(value: VmValue) -> Result<PreToolAction, VmError> {
356 match value {
357 VmValue::Nil => Ok(PreToolAction::Allow),
358 VmValue::Dict(map) => {
359 if let Some(reason) = map.get("deny") {
360 return Ok(PreToolAction::Deny(reason.display()));
361 }
362 if let Some(args) = map.get("args") {
363 return Ok(PreToolAction::Modify(crate::llm::vm_value_to_json(args)));
364 }
365 Ok(PreToolAction::Allow)
366 }
367 other => Err(VmError::Runtime(format!(
368 "PreToolUse hook must return nil or {{deny, args}}, got {}",
369 other.type_name()
370 ))),
371 }
372}
373
374fn parse_post_tool_result(value: VmValue) -> Result<PostToolAction, VmError> {
375 match value {
376 VmValue::Nil => Ok(PostToolAction::Pass),
377 VmValue::String(text) => Ok(PostToolAction::Modify(text.to_string())),
378 VmValue::Dict(map) => {
379 if let Some(result) = map.get("result") {
380 return Ok(PostToolAction::Modify(result.display()));
381 }
382 Ok(PostToolAction::Pass)
383 }
384 other => Err(VmError::Runtime(format!(
385 "PostToolUse hook must return nil, string, or {{result}}, got {}",
386 other.type_name()
387 ))),
388 }
389}
390
391pub async fn run_pre_tool_hooks(
393 tool_name: &str,
394 args: &serde_json::Value,
395) -> Result<PreToolAction, VmError> {
396 let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
397 let mut current_args = args.clone();
398 for hook in hooks
399 .iter()
400 .filter(|hook| hook.event == HookEvent::PreToolUse)
401 {
402 let payload = serde_json::json!({
403 "event": HookEvent::PreToolUse.as_str(),
404 "tool": {
405 "name": tool_name,
406 "args": current_args.clone(),
407 },
408 });
409 if !hook_matches(hook, Some(tool_name), &payload) {
410 continue;
411 }
412 let action = match &hook.handler {
413 RuntimeHookHandler::NativePreTool(pre) => pre(tool_name, ¤t_args),
414 RuntimeHookHandler::Vm { closure, .. } => {
415 parse_pre_tool_result(invoke_vm_hook(closure, &payload).await?)?
416 }
417 RuntimeHookHandler::NativePostTool(_) => continue,
418 };
419 match action {
420 PreToolAction::Allow => {}
421 PreToolAction::Deny(reason) => return Ok(PreToolAction::Deny(reason)),
422 PreToolAction::Modify(new_args) => {
423 current_args = new_args;
424 }
425 }
426 }
427 if current_args != *args {
428 Ok(PreToolAction::Modify(current_args))
429 } else {
430 Ok(PreToolAction::Allow)
431 }
432}
433
434pub async fn run_post_tool_hooks(
436 tool_name: &str,
437 args: &serde_json::Value,
438 result: &str,
439) -> Result<String, VmError> {
440 let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
441 let mut current = result.to_string();
442 for hook in hooks
443 .iter()
444 .filter(|hook| hook.event == HookEvent::PostToolUse)
445 {
446 let payload = serde_json::json!({
447 "event": HookEvent::PostToolUse.as_str(),
448 "tool": {
449 "name": tool_name,
450 "args": args,
451 },
452 "result": {
453 "text": current.clone(),
454 },
455 });
456 if !hook_matches(hook, Some(tool_name), &payload) {
457 continue;
458 }
459 let action = match &hook.handler {
460 RuntimeHookHandler::NativePostTool(post) => post(tool_name, ¤t),
461 RuntimeHookHandler::Vm { closure, .. } => {
462 parse_post_tool_result(invoke_vm_hook(closure, &payload).await?)?
463 }
464 RuntimeHookHandler::NativePreTool(_) => continue,
465 };
466 match action {
467 PostToolAction::Pass => {}
468 PostToolAction::Modify(new_result) => {
469 current = new_result;
470 }
471 }
472 }
473 Ok(current)
474}
475
476pub async fn run_lifecycle_hooks(
477 event: HookEvent,
478 payload: &serde_json::Value,
479) -> Result<(), VmError> {
480 let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
481 for hook in hooks.iter().filter(|hook| hook.event == event) {
482 if !hook_matches(hook, None, payload) {
483 continue;
484 }
485 if let RuntimeHookHandler::Vm { closure, .. } = &hook.handler {
486 let _ = invoke_vm_hook(closure, payload).await?;
487 }
488 }
489 Ok(())
490}
491
492pub fn matching_vm_lifecycle_hooks(
493 event: HookEvent,
494 payload: &serde_json::Value,
495) -> Vec<VmLifecycleHookInvocation> {
496 let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
497 hooks
498 .iter()
499 .filter(|hook| hook.event == event)
500 .filter(|hook| hook_matches(hook, None, payload))
501 .filter_map(|hook| match &hook.handler {
502 RuntimeHookHandler::Vm { closure, .. } => Some(VmLifecycleHookInvocation {
503 closure: closure.clone(),
504 }),
505 RuntimeHookHandler::NativePreTool(_) | RuntimeHookHandler::NativePostTool(_) => None,
506 })
507 .collect()
508}