use std::cell::RefCell;
use std::rc::Rc;

use regex::Regex;
use serde::{Deserialize, Serialize};

use crate::agent_events::WorkerEvent;
use crate::value::{VmClosure, VmError, VmValue};

12#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
14pub enum HookEvent {
15 #[serde(rename = "PreToolUse")]
16 PreToolUse,
17 #[serde(rename = "PostToolUse")]
18 PostToolUse,
19 #[serde(rename = "PreAgentTurn")]
20 PreAgentTurn,
21 #[serde(rename = "PostAgentTurn")]
22 PostAgentTurn,
23 #[serde(rename = "WorkerSpawned")]
24 WorkerSpawned,
25 #[serde(rename = "WorkerProgressed")]
26 WorkerProgressed,
27 #[serde(rename = "WorkerWaitingForInput")]
28 WorkerWaitingForInput,
29 #[serde(rename = "WorkerCompleted")]
30 WorkerCompleted,
31 #[serde(rename = "WorkerFailed")]
32 WorkerFailed,
33 #[serde(rename = "WorkerCancelled")]
34 WorkerCancelled,
35}
36
37impl HookEvent {
38 pub fn as_str(self) -> &'static str {
39 match self {
40 Self::PreToolUse => "PreToolUse",
41 Self::PostToolUse => "PostToolUse",
42 Self::PreAgentTurn => "PreAgentTurn",
43 Self::PostAgentTurn => "PostAgentTurn",
44 Self::WorkerSpawned => "WorkerSpawned",
45 Self::WorkerProgressed => "WorkerProgressed",
46 Self::WorkerWaitingForInput => "WorkerWaitingForInput",
47 Self::WorkerCompleted => "WorkerCompleted",
48 Self::WorkerFailed => "WorkerFailed",
49 Self::WorkerCancelled => "WorkerCancelled",
50 }
51 }
52
53 pub fn from_worker_event(event: WorkerEvent) -> Self {
54 match event {
55 WorkerEvent::WorkerSpawned => Self::WorkerSpawned,
56 WorkerEvent::WorkerProgressed => Self::WorkerProgressed,
57 WorkerEvent::WorkerWaitingForInput => Self::WorkerWaitingForInput,
58 WorkerEvent::WorkerCompleted => Self::WorkerCompleted,
59 WorkerEvent::WorkerFailed => Self::WorkerFailed,
60 WorkerEvent::WorkerCancelled => Self::WorkerCancelled,
61 }
62 }
63}
64
65#[derive(Clone, Debug)]
67pub enum PreToolAction {
68 Allow,
70 Deny(String),
72 Modify(serde_json::Value),
74}
75
/// Verdict produced by a `PostToolUse` hook.
#[derive(Debug, Clone)]
pub enum PostToolAction {
    /// Keep the tool result as it is.
    Pass,
    /// Replace the tool result text with this string.
    Modify(String),
}

85pub type PreToolHookFn = Rc<dyn Fn(&str, &serde_json::Value) -> PreToolAction>;
87pub type PostToolHookFn = Rc<dyn Fn(&str, &str) -> PostToolAction>;
88
89#[derive(Clone)]
91pub struct ToolHook {
92 pub pattern: String,
94 pub pre: Option<PreToolHookFn>,
96 pub post: Option<PostToolHookFn>,
98}
99
100impl std::fmt::Debug for ToolHook {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("ToolHook")
103 .field("pattern", &self.pattern)
104 .field("has_pre", &self.pre.is_some())
105 .field("has_post", &self.post.is_some())
106 .finish()
107 }
108}
109
/// Decides whether a runtime hook applies to a particular dispatch.
#[derive(Clone)]
enum PatternMatcher {
    /// Glob over the tool name; used for native hooks from `register_tool_hook`.
    ToolNameGlob(String),
    /// Expression over the JSON event payload; used for hooks from `register_vm_hook`.
    EventExpression(String),
}

116impl std::fmt::Debug for PatternMatcher {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 match self {
119 Self::ToolNameGlob(pattern) => f.debug_tuple("ToolNameGlob").field(pattern).finish(),
120 Self::EventExpression(pattern) => {
121 f.debug_tuple("EventExpression").field(pattern).finish()
122 }
123 }
124 }
125}
126
127#[derive(Clone)]
128enum RuntimeHookHandler {
129 NativePreTool(PreToolHookFn),
130 NativePostTool(PostToolHookFn),
131 Vm {
132 handler_name: String,
133 closure: Rc<VmClosure>,
134 },
135}
136
137impl std::fmt::Debug for RuntimeHookHandler {
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 match self {
140 Self::NativePreTool(_) => f.write_str("NativePreTool(..)"),
141 Self::NativePostTool(_) => f.write_str("NativePostTool(..)"),
142 Self::Vm { handler_name, .. } => f
143 .debug_struct("Vm")
144 .field("handler_name", handler_name)
145 .finish(),
146 }
147 }
148}
149
150#[derive(Clone, Debug)]
151struct RuntimeHook {
152 event: HookEvent,
153 matcher: PatternMatcher,
154 handler: RuntimeHookHandler,
155}
156
157thread_local! {
158 static RUNTIME_HOOKS: RefCell<Vec<RuntimeHook>> = const { RefCell::new(Vec::new()) };
159}
160
/// Matches `name` against a glob `pattern` in which `*` matches any (possibly empty)
/// sequence of characters and every other character matches itself literally.
///
/// `*` may appear anywhere and any number of times: `"*"`, `"read_*"`, `"*_file"`,
/// `"*file*"` and `"read_*_file"` all behave as expected. A pattern without `*` is an
/// exact comparison.
///
/// Uses the classic single-backtrack wildcard algorithm: no allocation, worst case
/// O(pattern.len() * name.len()).
pub(crate) fn glob_match(pattern: &str, name: &str) -> bool {
    // Byte-wise matching is sound for UTF-8: `*` is ASCII, and a literal UTF-8 run can only
    // match inside another valid UTF-8 string starting at a character boundary.
    let pattern = pattern.as_bytes();
    let name = name.as_bytes();
    let (mut p, mut n) = (0usize, 0usize);
    // (pattern index just after the most recent `*`, name index that `*` has consumed up to)
    let mut backtrack: Option<(usize, usize)> = None;
    while n < name.len() {
        match pattern.get(p) {
            Some(&b'*') => {
                backtrack = Some((p + 1, n));
                p += 1;
            }
            Some(&byte) if byte == name[n] => {
                p += 1;
                n += 1;
            }
            _ => match backtrack {
                // Let the last `*` swallow one more byte and retry from just after it.
                Some((after_star, consumed)) => {
                    backtrack = Some((after_star, consumed + 1));
                    p = after_star;
                    n = consumed + 1;
                }
                None => return false,
            },
        }
    }
    // Whatever pattern remains must be all `*`, which can match the empty remainder.
    pattern[p..].iter().all(|&byte| byte == b'*')
}

174pub fn register_tool_hook(hook: ToolHook) {
175 if let Some(pre) = hook.pre {
176 RUNTIME_HOOKS.with(|hooks| {
177 hooks.borrow_mut().push(RuntimeHook {
178 event: HookEvent::PreToolUse,
179 matcher: PatternMatcher::ToolNameGlob(hook.pattern.clone()),
180 handler: RuntimeHookHandler::NativePreTool(pre),
181 });
182 });
183 }
184 if let Some(post) = hook.post {
185 RUNTIME_HOOKS.with(|hooks| {
186 hooks.borrow_mut().push(RuntimeHook {
187 event: HookEvent::PostToolUse,
188 matcher: PatternMatcher::ToolNameGlob(hook.pattern),
189 handler: RuntimeHookHandler::NativePostTool(post),
190 });
191 });
192 }
193}
194
195pub fn register_vm_hook(
196 event: HookEvent,
197 pattern: impl Into<String>,
198 handler_name: impl Into<String>,
199 closure: Rc<VmClosure>,
200) {
201 RUNTIME_HOOKS.with(|hooks| {
202 hooks.borrow_mut().push(RuntimeHook {
203 event,
204 matcher: PatternMatcher::EventExpression(pattern.into()),
205 handler: RuntimeHookHandler::Vm {
206 handler_name: handler_name.into(),
207 closure,
208 },
209 });
210 });
211}
212
213pub fn clear_tool_hooks() {
214 RUNTIME_HOOKS.with(|hooks| {
215 hooks
216 .borrow_mut()
217 .retain(|hook| !matches!(hook.event, HookEvent::PreToolUse | HookEvent::PostToolUse));
218 });
219}
220
221pub fn clear_runtime_hooks() {
222 RUNTIME_HOOKS.with(|hooks| hooks.borrow_mut().clear());
223 super::clear_command_policies();
224}
225
226fn value_at_path<'a>(value: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
227 let mut current = value;
228 for segment in path.split('.') {
229 let serde_json::Value::Object(map) = current else {
230 return None;
231 };
232 current = map.get(segment)?;
233 }
234 Some(current)
235}
236
237fn value_truthy(value: &serde_json::Value) -> bool {
238 match value {
239 serde_json::Value::Null => false,
240 serde_json::Value::Bool(value) => *value,
241 serde_json::Value::Number(value) => value
242 .as_i64()
243 .map(|number| number != 0)
244 .or_else(|| value.as_u64().map(|number| number != 0))
245 .or_else(|| value.as_f64().map(|number| number != 0.0))
246 .unwrap_or(false),
247 serde_json::Value::String(value) => !value.is_empty(),
248 serde_json::Value::Array(values) => !values.is_empty(),
249 serde_json::Value::Object(values) => !values.is_empty(),
250 }
251}
252
253fn value_to_pattern_string(value: Option<&serde_json::Value>) -> String {
254 match value {
255 Some(serde_json::Value::String(text)) => text.clone(),
256 Some(other) => other.to_string(),
257 None => String::new(),
258 }
259}
260
/// Trims `value` and, if it is wrapped in a matching pair of double or single quotes,
/// removes that one pair. The text inside the quotes is returned untrimmed.
fn strip_quoted(value: &str) -> &str {
    let trimmed = value.trim();
    for quote in ['"', '\''] {
        if let Some(inner) = trimmed
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
        {
            return inner;
        }
    }
    trimmed
}

275fn expression_matches(pattern: &str, payload: &serde_json::Value) -> bool {
276 let pattern = pattern.trim();
277 if pattern.is_empty() || pattern == "*" {
278 return true;
279 }
280 if let Some((lhs, rhs)) = pattern.split_once("=~") {
281 let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
282 let regex = strip_quoted(rhs);
283 return Regex::new(regex).is_ok_and(|compiled| compiled.is_match(&value));
284 }
285 if let Some((lhs, rhs)) = pattern.split_once("==") {
286 let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
287 return value == strip_quoted(rhs);
288 }
289 if let Some((lhs, rhs)) = pattern.split_once("!=") {
290 let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
291 return value != strip_quoted(rhs);
292 }
293 if pattern.contains('.') {
294 return value_at_path(payload, pattern).is_some_and(value_truthy);
295 }
296 glob_match(
297 pattern,
298 &value_to_pattern_string(value_at_path(payload, "tool.name")),
299 )
300}
301
302fn hook_matches(hook: &RuntimeHook, tool_name: Option<&str>, payload: &serde_json::Value) -> bool {
303 match &hook.matcher {
304 PatternMatcher::ToolNameGlob(pattern) => {
305 tool_name.is_some_and(|candidate| glob_match(pattern, candidate))
306 }
307 PatternMatcher::EventExpression(pattern) => expression_matches(pattern, payload),
308 }
309}
310
311async fn invoke_vm_hook(
312 closure: &Rc<VmClosure>,
313 payload: &serde_json::Value,
314) -> Result<VmValue, VmError> {
315 let Some(mut vm) = crate::vm::clone_async_builtin_child_vm() else {
316 return Err(VmError::Runtime(
317 "runtime hook requires an async builtin VM context".to_string(),
318 ));
319 };
320 let arg = crate::stdlib::json_to_vm_value(payload);
321 vm.call_closure_pub(closure, &[arg]).await
322}
323
324fn parse_pre_tool_result(value: VmValue) -> Result<PreToolAction, VmError> {
325 match value {
326 VmValue::Nil => Ok(PreToolAction::Allow),
327 VmValue::Dict(map) => {
328 if let Some(reason) = map.get("deny") {
329 return Ok(PreToolAction::Deny(reason.display()));
330 }
331 if let Some(args) = map.get("args") {
332 return Ok(PreToolAction::Modify(crate::llm::vm_value_to_json(args)));
333 }
334 Ok(PreToolAction::Allow)
335 }
336 other => Err(VmError::Runtime(format!(
337 "PreToolUse hook must return nil or {{deny, args}}, got {}",
338 other.type_name()
339 ))),
340 }
341}
342
343fn parse_post_tool_result(value: VmValue) -> Result<PostToolAction, VmError> {
344 match value {
345 VmValue::Nil => Ok(PostToolAction::Pass),
346 VmValue::String(text) => Ok(PostToolAction::Modify(text.to_string())),
347 VmValue::Dict(map) => {
348 if let Some(result) = map.get("result") {
349 return Ok(PostToolAction::Modify(result.display()));
350 }
351 Ok(PostToolAction::Pass)
352 }
353 other => Err(VmError::Runtime(format!(
354 "PostToolUse hook must return nil, string, or {{result}}, got {}",
355 other.type_name()
356 ))),
357 }
358}
359
360pub async fn run_pre_tool_hooks(
362 tool_name: &str,
363 args: &serde_json::Value,
364) -> Result<PreToolAction, VmError> {
365 let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
366 let mut current_args = args.clone();
367 for hook in hooks
368 .iter()
369 .filter(|hook| hook.event == HookEvent::PreToolUse)
370 {
371 let payload = serde_json::json!({
372 "event": HookEvent::PreToolUse.as_str(),
373 "tool": {
374 "name": tool_name,
375 "args": current_args.clone(),
376 },
377 });
378 if !hook_matches(hook, Some(tool_name), &payload) {
379 continue;
380 }
381 let action = match &hook.handler {
382 RuntimeHookHandler::NativePreTool(pre) => pre(tool_name, ¤t_args),
383 RuntimeHookHandler::Vm { closure, .. } => {
384 parse_pre_tool_result(invoke_vm_hook(closure, &payload).await?)?
385 }
386 RuntimeHookHandler::NativePostTool(_) => continue,
387 };
388 match action {
389 PreToolAction::Allow => {}
390 PreToolAction::Deny(reason) => return Ok(PreToolAction::Deny(reason)),
391 PreToolAction::Modify(new_args) => {
392 current_args = new_args;
393 }
394 }
395 }
396 if current_args != *args {
397 Ok(PreToolAction::Modify(current_args))
398 } else {
399 Ok(PreToolAction::Allow)
400 }
401}
402
403pub async fn run_post_tool_hooks(
405 tool_name: &str,
406 args: &serde_json::Value,
407 result: &str,
408) -> Result<String, VmError> {
409 let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
410 let mut current = result.to_string();
411 for hook in hooks
412 .iter()
413 .filter(|hook| hook.event == HookEvent::PostToolUse)
414 {
415 let payload = serde_json::json!({
416 "event": HookEvent::PostToolUse.as_str(),
417 "tool": {
418 "name": tool_name,
419 "args": args,
420 },
421 "result": {
422 "text": current.clone(),
423 },
424 });
425 if !hook_matches(hook, Some(tool_name), &payload) {
426 continue;
427 }
428 let action = match &hook.handler {
429 RuntimeHookHandler::NativePostTool(post) => post(tool_name, ¤t),
430 RuntimeHookHandler::Vm { closure, .. } => {
431 parse_post_tool_result(invoke_vm_hook(closure, &payload).await?)?
432 }
433 RuntimeHookHandler::NativePreTool(_) => continue,
434 };
435 match action {
436 PostToolAction::Pass => {}
437 PostToolAction::Modify(new_result) => {
438 current = new_result;
439 }
440 }
441 }
442 Ok(current)
443}
444
445pub async fn run_lifecycle_hooks(
446 event: HookEvent,
447 payload: &serde_json::Value,
448) -> Result<(), VmError> {
449 let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
450 for hook in hooks.iter().filter(|hook| hook.event == event) {
451 if !hook_matches(hook, None, payload) {
452 continue;
453 }
454 if let RuntimeHookHandler::Vm { closure, .. } = &hook.handler {
455 let _ = invoke_vm_hook(closure, payload).await?;
456 }
457 }
458 Ok(())
459}