1use std::cell::RefCell;
4use std::rc::Rc;
5
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8
9use crate::agent_events::WorkerEvent;
10use crate::value::{VmClosure, VmError, VmValue};
11
/// Points at which runtime hooks can be registered and fired.
///
/// Each variant (de)serializes as its own name (e.g. `"PreToolUse"`), which is
/// also the string returned by [`HookEvent::as_str`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum HookEvent {
    /// Before a tool call (see `run_pre_tool_hooks`); hooks may allow, deny, or
    /// rewrite the tool arguments.
    #[serde(rename = "PreToolUse")]
    PreToolUse,
    /// After a tool call (see `run_post_tool_hooks`); hooks may keep or rewrite
    /// the result text.
    #[serde(rename = "PostToolUse")]
    PostToolUse,
    /// Agent-turn event; not dispatched in this file — presumably fired by the
    /// agent loop through `run_lifecycle_hooks` (TODO confirm).
    #[serde(rename = "PreAgentTurn")]
    PreAgentTurn,
    /// Agent-turn event; not dispatched in this file — presumably fired by the
    /// agent loop through `run_lifecycle_hooks` (TODO confirm).
    #[serde(rename = "PostAgentTurn")]
    PostAgentTurn,
    /// Worker lifecycle event, mapped from `WorkerEvent::WorkerSpawned`.
    #[serde(rename = "WorkerSpawned")]
    WorkerSpawned,
    /// Worker lifecycle event, mapped from `WorkerEvent::WorkerCompleted`.
    #[serde(rename = "WorkerCompleted")]
    WorkerCompleted,
    /// Worker lifecycle event, mapped from `WorkerEvent::WorkerFailed`.
    #[serde(rename = "WorkerFailed")]
    WorkerFailed,
    /// Worker lifecycle event, mapped from `WorkerEvent::WorkerCancelled`.
    #[serde(rename = "WorkerCancelled")]
    WorkerCancelled,
}
32
33impl HookEvent {
34 pub fn as_str(self) -> &'static str {
35 match self {
36 Self::PreToolUse => "PreToolUse",
37 Self::PostToolUse => "PostToolUse",
38 Self::PreAgentTurn => "PreAgentTurn",
39 Self::PostAgentTurn => "PostAgentTurn",
40 Self::WorkerSpawned => "WorkerSpawned",
41 Self::WorkerCompleted => "WorkerCompleted",
42 Self::WorkerFailed => "WorkerFailed",
43 Self::WorkerCancelled => "WorkerCancelled",
44 }
45 }
46
47 pub fn from_worker_event(event: WorkerEvent) -> Self {
48 match event {
49 WorkerEvent::WorkerSpawned => Self::WorkerSpawned,
50 WorkerEvent::WorkerCompleted => Self::WorkerCompleted,
51 WorkerEvent::WorkerFailed => Self::WorkerFailed,
52 WorkerEvent::WorkerCancelled => Self::WorkerCancelled,
53 }
54 }
55}
56
/// Outcome of a `PreToolUse` hook.
#[derive(Clone, Debug)]
pub enum PreToolAction {
    /// Let the tool call proceed with its current arguments.
    Allow,
    /// Block the tool call; the string is the reason supplied by the hook.
    Deny(String),
    /// Replace the tool arguments with this JSON value.
    Modify(serde_json::Value),
}
67
/// Outcome of a `PostToolUse` hook.
#[derive(Clone, Debug)]
pub enum PostToolAction {
    /// Keep the tool result text unchanged.
    Pass,
    /// Replace the tool result text with this string.
    Modify(String),
}
76
/// Native pre-tool hook: called with the tool name and the current JSON arguments.
pub type PreToolHookFn = Rc<dyn Fn(&str, &serde_json::Value) -> PreToolAction>;
/// Native post-tool hook: called with the tool name and the current result text.
pub type PostToolHookFn = Rc<dyn Fn(&str, &str) -> PostToolAction>;
80
/// A native (Rust) tool hook, registered with `register_tool_hook`.
#[derive(Clone)]
pub struct ToolHook {
    /// Tool-name glob (see `glob_match`) selecting which tools the hook applies to.
    pub pattern: String,
    /// Optional handler run before matching tool calls (registered under `PreToolUse`).
    pub pre: Option<PreToolHookFn>,
    /// Optional handler run after matching tool calls (registered under `PostToolUse`).
    pub post: Option<PostToolHookFn>,
}
91
92impl std::fmt::Debug for ToolHook {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("ToolHook")
95 .field("pattern", &self.pattern)
96 .field("has_pre", &self.pre.is_some())
97 .field("has_post", &self.post.is_some())
98 .finish()
99 }
100}
101
/// How a registered runtime hook decides whether it applies to a firing.
///
/// `Debug` is derived and prints e.g. `ToolNameGlob("read_*")`.
#[derive(Clone, Debug)]
enum PatternMatcher {
    /// Glob over the tool name (native tool hooks; see `glob_match`).
    ToolNameGlob(String),
    /// Expression evaluated against the JSON event payload (VM hooks; see
    /// `expression_matches`).
    EventExpression(String),
}
118
/// Callable attached to a registered runtime hook.
#[derive(Clone)]
enum RuntimeHookHandler {
    /// Native `PreToolUse` handler (from `register_tool_hook`).
    NativePreTool(PreToolHookFn),
    /// Native `PostToolUse` handler (from `register_tool_hook`).
    NativePostTool(PostToolHookFn),
    /// VM closure registered with `register_vm_hook`; `handler_name` is shown in
    /// `Debug` output.
    Vm {
        handler_name: String,
        closure: Rc<VmClosure>,
    },
}
128
129impl std::fmt::Debug for RuntimeHookHandler {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 match self {
132 Self::NativePreTool(_) => f.write_str("NativePreTool(..)"),
133 Self::NativePostTool(_) => f.write_str("NativePostTool(..)"),
134 Self::Vm { handler_name, .. } => f
135 .debug_struct("Vm")
136 .field("handler_name", handler_name)
137 .finish(),
138 }
139 }
140}
141
/// One entry in the thread-local hook registry.
#[derive(Clone, Debug)]
struct RuntimeHook {
    /// Event this hook is registered for.
    event: HookEvent,
    /// Filter deciding whether the hook applies to a particular firing.
    matcher: PatternMatcher,
    /// What to invoke when the hook matches.
    handler: RuntimeHookHandler,
}
148
thread_local! {
    /// Per-thread hook registry in registration order; the `run_*_hooks`
    /// functions snapshot it and iterate front to back.
    static RUNTIME_HOOKS: RefCell<Vec<RuntimeHook>> = const { RefCell::new(Vec::new()) };
}
152
/// Matches `name` against the wildcard `pattern`.
///
/// `*` matches any (possibly empty) run of characters and may appear any number
/// of times anywhere in the pattern (`read_*`, `*_file`, `*read*`, `fs_*_v2`).
/// Every other character must match literally, so a pattern without `*` is an
/// exact comparison.
pub(crate) fn glob_match(pattern: &str, name: &str) -> bool {
    // Byte-wise matching is sound for UTF-8: `*` is ASCII and never occurs inside a
    // multi-byte sequence, and a literal run of valid UTF-8 can only line up with
    // `name` at character boundaries.
    let pattern = pattern.as_bytes();
    let name = name.as_bytes();
    let (mut p, mut n) = (0usize, 0usize);
    // Most recent `*` seen in `pattern`, and how far into `name` it currently absorbs.
    let mut star: Option<(usize, usize)> = None;
    while n < name.len() {
        match pattern.get(p).copied() {
            Some(b'*') => {
                star = Some((p, n));
                p += 1;
            }
            Some(byte) if byte == name[n] => {
                p += 1;
                n += 1;
            }
            _ => match star {
                // Let the last `*` swallow one more byte of `name`, then retry.
                Some((star_p, star_n)) => {
                    star = Some((star_p, star_n + 1));
                    p = star_p + 1;
                    n = star_n + 1;
                }
                None => return false,
            },
        }
    }
    // Whatever remains of the pattern must be `*`s, each matching the empty string.
    pattern[p..].iter().all(|&byte| byte == b'*')
}
165
166pub fn register_tool_hook(hook: ToolHook) {
167 if let Some(pre) = hook.pre {
168 RUNTIME_HOOKS.with(|hooks| {
169 hooks.borrow_mut().push(RuntimeHook {
170 event: HookEvent::PreToolUse,
171 matcher: PatternMatcher::ToolNameGlob(hook.pattern.clone()),
172 handler: RuntimeHookHandler::NativePreTool(pre),
173 });
174 });
175 }
176 if let Some(post) = hook.post {
177 RUNTIME_HOOKS.with(|hooks| {
178 hooks.borrow_mut().push(RuntimeHook {
179 event: HookEvent::PostToolUse,
180 matcher: PatternMatcher::ToolNameGlob(hook.pattern),
181 handler: RuntimeHookHandler::NativePostTool(post),
182 });
183 });
184 }
185}
186
187pub fn register_vm_hook(
188 event: HookEvent,
189 pattern: impl Into<String>,
190 handler_name: impl Into<String>,
191 closure: Rc<VmClosure>,
192) {
193 RUNTIME_HOOKS.with(|hooks| {
194 hooks.borrow_mut().push(RuntimeHook {
195 event,
196 matcher: PatternMatcher::EventExpression(pattern.into()),
197 handler: RuntimeHookHandler::Vm {
198 handler_name: handler_name.into(),
199 closure,
200 },
201 });
202 });
203}
204
205pub fn clear_tool_hooks() {
206 RUNTIME_HOOKS.with(|hooks| {
207 hooks
208 .borrow_mut()
209 .retain(|hook| !matches!(hook.event, HookEvent::PreToolUse | HookEvent::PostToolUse));
210 });
211}
212
213pub fn clear_runtime_hooks() {
214 RUNTIME_HOOKS.with(|hooks| hooks.borrow_mut().clear());
215}
216
217fn value_at_path<'a>(value: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
218 let mut current = value;
219 for segment in path.split('.') {
220 let serde_json::Value::Object(map) = current else {
221 return None;
222 };
223 current = map.get(segment)?;
224 }
225 Some(current)
226}
227
228fn value_truthy(value: &serde_json::Value) -> bool {
229 match value {
230 serde_json::Value::Null => false,
231 serde_json::Value::Bool(value) => *value,
232 serde_json::Value::Number(value) => value
233 .as_i64()
234 .map(|number| number != 0)
235 .or_else(|| value.as_u64().map(|number| number != 0))
236 .or_else(|| value.as_f64().map(|number| number != 0.0))
237 .unwrap_or(false),
238 serde_json::Value::String(value) => !value.is_empty(),
239 serde_json::Value::Array(values) => !values.is_empty(),
240 serde_json::Value::Object(values) => !values.is_empty(),
241 }
242}
243
244fn value_to_pattern_string(value: Option<&serde_json::Value>) -> String {
245 match value {
246 Some(serde_json::Value::String(text)) => text.clone(),
247 Some(other) => other.to_string(),
248 None => String::new(),
249 }
250}
251
/// Trims `value` and removes one matching pair of surrounding quotes.
///
/// Double quotes are tried first, then single quotes. The quoted content itself
/// is not trimmed. If neither pair encloses the value, it is returned trimmed only.
fn strip_quoted(value: &str) -> &str {
    let trimmed = value.trim();
    for quote in ['"', '\''] {
        if let Some(inner) = trimmed
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
        {
            return inner;
        }
    }
    trimmed
}
265
266fn expression_matches(pattern: &str, payload: &serde_json::Value) -> bool {
267 let pattern = pattern.trim();
268 if pattern.is_empty() || pattern == "*" {
269 return true;
270 }
271 if let Some((lhs, rhs)) = pattern.split_once("=~") {
272 let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
273 let regex = strip_quoted(rhs);
274 return Regex::new(regex).is_ok_and(|compiled| compiled.is_match(&value));
275 }
276 if let Some((lhs, rhs)) = pattern.split_once("==") {
277 let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
278 return value == strip_quoted(rhs);
279 }
280 if let Some((lhs, rhs)) = pattern.split_once("!=") {
281 let value = value_to_pattern_string(value_at_path(payload, lhs.trim()));
282 return value != strip_quoted(rhs);
283 }
284 if pattern.contains('.') {
285 return value_at_path(payload, pattern).is_some_and(value_truthy);
286 }
287 glob_match(
288 pattern,
289 &value_to_pattern_string(value_at_path(payload, "tool.name")),
290 )
291}
292
293fn hook_matches(hook: &RuntimeHook, tool_name: Option<&str>, payload: &serde_json::Value) -> bool {
294 match &hook.matcher {
295 PatternMatcher::ToolNameGlob(pattern) => {
296 tool_name.is_some_and(|candidate| glob_match(pattern, candidate))
297 }
298 PatternMatcher::EventExpression(pattern) => expression_matches(pattern, payload),
299 }
300}
301
302async fn invoke_vm_hook(
303 closure: &Rc<VmClosure>,
304 payload: &serde_json::Value,
305) -> Result<VmValue, VmError> {
306 let Some(mut vm) = crate::vm::clone_async_builtin_child_vm() else {
307 return Err(VmError::Runtime(
308 "runtime hook requires an async builtin VM context".to_string(),
309 ));
310 };
311 let arg = crate::stdlib::json_to_vm_value(payload);
312 vm.call_closure_pub(closure, &[arg]).await
313}
314
315fn parse_pre_tool_result(value: VmValue) -> Result<PreToolAction, VmError> {
316 match value {
317 VmValue::Nil => Ok(PreToolAction::Allow),
318 VmValue::Dict(map) => {
319 if let Some(reason) = map.get("deny") {
320 return Ok(PreToolAction::Deny(reason.display()));
321 }
322 if let Some(args) = map.get("args") {
323 return Ok(PreToolAction::Modify(crate::llm::vm_value_to_json(args)));
324 }
325 Ok(PreToolAction::Allow)
326 }
327 other => Err(VmError::Runtime(format!(
328 "PreToolUse hook must return nil or {{deny, args}}, got {}",
329 other.type_name()
330 ))),
331 }
332}
333
334fn parse_post_tool_result(value: VmValue) -> Result<PostToolAction, VmError> {
335 match value {
336 VmValue::Nil => Ok(PostToolAction::Pass),
337 VmValue::String(text) => Ok(PostToolAction::Modify(text.to_string())),
338 VmValue::Dict(map) => {
339 if let Some(result) = map.get("result") {
340 return Ok(PostToolAction::Modify(result.display()));
341 }
342 Ok(PostToolAction::Pass)
343 }
344 other => Err(VmError::Runtime(format!(
345 "PostToolUse hook must return nil, string, or {{result}}, got {}",
346 other.type_name()
347 ))),
348 }
349}
350
351pub async fn run_pre_tool_hooks(
353 tool_name: &str,
354 args: &serde_json::Value,
355) -> Result<PreToolAction, VmError> {
356 let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
357 let mut current_args = args.clone();
358 for hook in hooks
359 .iter()
360 .filter(|hook| hook.event == HookEvent::PreToolUse)
361 {
362 let payload = serde_json::json!({
363 "event": HookEvent::PreToolUse.as_str(),
364 "tool": {
365 "name": tool_name,
366 "args": current_args.clone(),
367 },
368 });
369 if !hook_matches(hook, Some(tool_name), &payload) {
370 continue;
371 }
372 let action = match &hook.handler {
373 RuntimeHookHandler::NativePreTool(pre) => pre(tool_name, ¤t_args),
374 RuntimeHookHandler::Vm { closure, .. } => {
375 parse_pre_tool_result(invoke_vm_hook(closure, &payload).await?)?
376 }
377 RuntimeHookHandler::NativePostTool(_) => continue,
378 };
379 match action {
380 PreToolAction::Allow => {}
381 PreToolAction::Deny(reason) => return Ok(PreToolAction::Deny(reason)),
382 PreToolAction::Modify(new_args) => {
383 current_args = new_args;
384 }
385 }
386 }
387 if current_args != *args {
388 Ok(PreToolAction::Modify(current_args))
389 } else {
390 Ok(PreToolAction::Allow)
391 }
392}
393
394pub async fn run_post_tool_hooks(
396 tool_name: &str,
397 args: &serde_json::Value,
398 result: &str,
399) -> Result<String, VmError> {
400 let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
401 let mut current = result.to_string();
402 for hook in hooks
403 .iter()
404 .filter(|hook| hook.event == HookEvent::PostToolUse)
405 {
406 let payload = serde_json::json!({
407 "event": HookEvent::PostToolUse.as_str(),
408 "tool": {
409 "name": tool_name,
410 "args": args,
411 },
412 "result": {
413 "text": current.clone(),
414 },
415 });
416 if !hook_matches(hook, Some(tool_name), &payload) {
417 continue;
418 }
419 let action = match &hook.handler {
420 RuntimeHookHandler::NativePostTool(post) => post(tool_name, ¤t),
421 RuntimeHookHandler::Vm { closure, .. } => {
422 parse_post_tool_result(invoke_vm_hook(closure, &payload).await?)?
423 }
424 RuntimeHookHandler::NativePreTool(_) => continue,
425 };
426 match action {
427 PostToolAction::Pass => {}
428 PostToolAction::Modify(new_result) => {
429 current = new_result;
430 }
431 }
432 }
433 Ok(current)
434}
435
436pub async fn run_lifecycle_hooks(
437 event: HookEvent,
438 payload: &serde_json::Value,
439) -> Result<(), VmError> {
440 let hooks = RUNTIME_HOOKS.with(|hooks| hooks.borrow().clone());
441 for hook in hooks.iter().filter(|hook| hook.event == event) {
442 if !hook_matches(hook, None, payload) {
443 continue;
444 }
445 if let RuntimeHookHandler::Vm { closure, .. } = &hook.handler {
446 let _ = invoke_vm_hook(closure, payload).await?;
447 }
448 }
449 Ok(())
450}