1use crate::types::events::ThreadEvent;
53use crate::types::items::ThreadItem;
54use std::future::Future;
55use std::pin::Pin;
56use std::sync::Arc;
57use std::time::Duration;
58
/// The kinds of [`ThreadEvent`] a hook can subscribe to.
///
/// Values are produced by [`classify_hook_event`] and compared against a
/// [`HookMatcher`]'s `event` field when dispatching.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum HookEvent {
    /// A command-execution item has started.
    CommandStarted,
    /// A command-execution item completed with `Completed` status.
    CommandCompleted,
    /// A command-execution item completed with `Failed` status.
    CommandFailed,
    /// A file-change item completed.
    FileChangeCompleted,
    /// An agent-message item completed.
    AgentMessage,
    /// The turn completed.
    TurnCompleted,
    /// The turn failed.
    TurnFailed,
}
79
80#[derive(Debug, Clone)]
84pub struct HookInput {
85 pub hook_event: HookEvent,
87 pub command: Option<String>,
89 pub exit_code: Option<i32>,
91 pub message_text: Option<String>,
93 pub raw_event: ThreadEvent,
95}
96
/// Caller-supplied context passed to every hook callback alongside its [`HookInput`].
#[derive(Clone, Debug)]
pub struct HookContext {
    /// Identifier of the thread the event belongs to, if one is available.
    pub thread_id: Option<String>,
    /// Turn counter maintained by the caller.
    /// NOTE(review): confirm whether this counts completed or started turns.
    pub turn_count: u32,
}
105
106#[derive(Debug, Clone)]
108pub struct HookOutput {
109 pub decision: HookDecision,
111 pub reason: Option<String>,
113 pub replacement_event: Option<ThreadEvent>,
115}
116
117impl Default for HookOutput {
118 fn default() -> Self {
119 Self {
120 decision: HookDecision::Allow,
121 reason: None,
122 replacement_event: None,
123 }
124 }
125}
126
/// What `dispatch_hook` does when a hook callback does not finish within its timeout.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum HookTimeoutBehavior {
    /// (Default) Log a warning, disregard the timed-out hook and continue with
    /// the next matching hook; if none is left, dispatch yields no output.
    #[default]
    FailOpen,

    /// Stop dispatching and return a `HookDecision::Block` output whose `reason`
    /// mentions the timeout.
    FailClosed,
}
143
/// The verdict a hook returns in [`HookOutput::decision`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HookDecision {
    /// No objection; this is the decision in `HookOutput::default()`.
    Allow,
    /// Block the event; also produced by `dispatch_hook` on a fail-closed timeout.
    Block,
    /// Presumably: substitute `HookOutput::replacement_event` for the event.
    /// NOTE(review): confirm against the consumer.
    Modify,
    /// Presumably: abort the surrounding run/turn.
    /// NOTE(review): confirm against the consumer.
    Abort,
}
156
157pub type HookCallback = Arc<
159 dyn Fn(HookInput, HookContext) -> Pin<Box<dyn Future<Output = HookOutput> + Send>>
160 + Send
161 + Sync,
162>;
163
164#[derive(Clone)]
168pub struct HookMatcher {
169 pub event: HookEvent,
171 pub command_filter: Option<String>,
173 pub callback: HookCallback,
175 pub timeout: Option<Duration>,
177 pub on_timeout: HookTimeoutBehavior,
182}
183
184impl std::fmt::Debug for HookMatcher {
185 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186 f.debug_struct("HookMatcher")
187 .field("event", &self.event)
188 .field("command_filter", &self.command_filter)
189 .field("timeout", &self.timeout)
190 .field("on_timeout", &self.on_timeout)
191 .finish()
192 }
193}
194
195pub fn classify_hook_event(event: &ThreadEvent) -> Option<HookEvent> {
202 use crate::types::items::CommandExecutionStatus;
203
204 match event {
205 ThreadEvent::ItemStarted {
206 item: ThreadItem::CommandExecution { .. },
207 } => Some(HookEvent::CommandStarted),
208
209 ThreadEvent::ItemCompleted {
210 item: ThreadItem::CommandExecution { status, .. },
211 } => match status {
212 CommandExecutionStatus::Completed => Some(HookEvent::CommandCompleted),
213 CommandExecutionStatus::Failed => Some(HookEvent::CommandFailed),
214 CommandExecutionStatus::InProgress => None,
215 },
216
217 ThreadEvent::ItemCompleted {
218 item: ThreadItem::FileChange { .. },
219 } => Some(HookEvent::FileChangeCompleted),
220
221 ThreadEvent::ItemCompleted {
222 item: ThreadItem::AgentMessage { .. },
223 } => Some(HookEvent::AgentMessage),
224
225 ThreadEvent::TurnCompleted { .. } => Some(HookEvent::TurnCompleted),
226 ThreadEvent::TurnFailed { .. } => Some(HookEvent::TurnFailed),
227
228 _ => None,
229 }
230}
231
232pub fn build_hook_input(hook_event: HookEvent, event: &ThreadEvent) -> HookInput {
234 let (command, exit_code, message_text) = match event {
235 ThreadEvent::ItemStarted {
236 item: ThreadItem::CommandExecution { command, .. },
237 }
238 | ThreadEvent::ItemCompleted {
239 item: ThreadItem::CommandExecution { command, .. },
240 } => {
241 let exit_code = match event {
242 ThreadEvent::ItemCompleted {
243 item: ThreadItem::CommandExecution { exit_code, .. },
244 } => *exit_code,
245 _ => None,
246 };
247 (Some(command.clone()), exit_code, None)
248 }
249
250 ThreadEvent::ItemCompleted {
251 item: ThreadItem::AgentMessage { text, .. },
252 } => (None, None, Some(text.clone())),
253
254 ThreadEvent::TurnFailed { error } => (None, None, Some(error.message.clone())),
255
256 _ => (None, None, None),
257 };
258
259 HookInput {
260 hook_event,
261 command,
262 exit_code,
263 message_text,
264 raw_event: event.clone(),
265 }
266}
267
268pub async fn dispatch_hook(
275 event: &ThreadEvent,
276 hooks: &[HookMatcher],
277 context: &HookContext,
278 default_timeout: Duration,
279) -> Option<HookOutput> {
280 let hook_event = classify_hook_event(event)?;
281 let input = build_hook_input(hook_event.clone(), event);
282
283 for hook in hooks {
284 if hook.event != hook_event {
285 continue;
286 }
287
288 if let Some(ref filter) = hook.command_filter {
290 match &input.command {
291 Some(cmd) if cmd.contains(filter.as_str()) => {}
292 _ => continue,
293 }
294 }
295
296 let timeout = hook.timeout.unwrap_or(default_timeout);
297 let future = (hook.callback)(input.clone(), context.clone());
298
299 match tokio::time::timeout(timeout, future).await {
300 Ok(output) => return Some(output),
301 Err(_) => {
302 tracing::warn!(
303 "Hook timed out after {:?} for {:?} — {:?}",
304 timeout,
305 hook.event,
306 hook.on_timeout,
307 );
308 match hook.on_timeout {
309 HookTimeoutBehavior::FailOpen => continue,
310 HookTimeoutBehavior::FailClosed => {
311 return Some(HookOutput {
312 decision: HookDecision::Block,
313 reason: Some(format!("hook timeout after {timeout:?} (fail-closed)")),
314 replacement_event: None,
315 });
316 }
317 }
318 }
319 }
320 }
321
322 None
323}
324
// Unit tests: event classification, hook-input extraction, and dispatch semantics
// (first match wins, command filters, fail-open / fail-closed timeouts).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::events::Usage;
    use crate::types::items::CommandExecutionStatus;

    fn make_command_started(cmd: &str) -> ThreadEvent {
        ThreadEvent::ItemStarted {
            item: ThreadItem::CommandExecution {
                id: "cmd-1".into(),
                command: cmd.into(),
                aggregated_output: String::new(),
                exit_code: None,
                status: CommandExecutionStatus::InProgress,
            },
        }
    }

    /// A completed command-execution item reporting `status` and exit code `code`.
    fn make_command_finished(cmd: &str, code: i32, status: CommandExecutionStatus) -> ThreadEvent {
        ThreadEvent::ItemCompleted {
            item: ThreadItem::CommandExecution {
                id: "cmd-1".into(),
                command: cmd.into(),
                aggregated_output: "output".into(),
                exit_code: Some(code),
                status,
            },
        }
    }

    fn make_command_completed(cmd: &str, code: i32) -> ThreadEvent {
        make_command_finished(cmd, code, CommandExecutionStatus::Completed)
    }

    fn make_turn_completed() -> ThreadEvent {
        ThreadEvent::TurnCompleted {
            usage: Usage {
                input_tokens: 100,
                cached_input_tokens: 0,
                output_tokens: 50,
            },
        }
    }

    fn make_context() -> HookContext {
        HookContext {
            thread_id: Some("thread-1".into()),
            turn_count: 0,
        }
    }

    /// A hook on `event` (optionally filtered by `filter`) that immediately returns `decision`.
    fn fixed_hook(event: HookEvent, filter: Option<&str>, decision: HookDecision) -> HookMatcher {
        HookMatcher {
            event,
            command_filter: filter.map(String::from),
            callback: Arc::new(move |_input, _ctx| {
                let decision = decision.clone();
                Box::pin(async move {
                    HookOutput {
                        decision,
                        reason: None,
                        replacement_event: None,
                    }
                })
            }),
            timeout: None,
            on_timeout: Default::default(),
        }
    }

    /// A hook on `event` whose callback sleeps far longer than its 10 ms timeout.
    fn slow_hook(event: HookEvent, on_timeout: HookTimeoutBehavior) -> HookMatcher {
        HookMatcher {
            event,
            command_filter: None,
            callback: Arc::new(|_input, _ctx| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    HookOutput::default()
                })
            }),
            timeout: Some(Duration::from_millis(10)),
            on_timeout,
        }
    }

    #[test]
    fn classify_command_started() {
        let event = make_command_started("ls -la");
        assert_eq!(classify_hook_event(&event), Some(HookEvent::CommandStarted));
    }

    #[test]
    fn classify_command_completed() {
        let event = make_command_completed("ls", 0);
        assert_eq!(
            classify_hook_event(&event),
            Some(HookEvent::CommandCompleted)
        );
    }

    #[test]
    fn classify_command_failed() {
        let event = make_command_finished("false", 1, CommandExecutionStatus::Failed);
        assert_eq!(classify_hook_event(&event), Some(HookEvent::CommandFailed));
    }

    #[test]
    fn classify_completed_item_still_in_progress_returns_none() {
        let event = make_command_finished("sleep 1", 0, CommandExecutionStatus::InProgress);
        assert_eq!(classify_hook_event(&event), None);
    }

    #[test]
    fn classify_turn_completed() {
        let event = make_turn_completed();
        assert_eq!(classify_hook_event(&event), Some(HookEvent::TurnCompleted));
    }

    #[test]
    fn classify_unmatched_returns_none() {
        let event = ThreadEvent::TurnStarted;
        assert_eq!(classify_hook_event(&event), None);
    }

    #[test]
    fn build_input_extracts_command() {
        let event = make_command_started("git status");
        let input = build_hook_input(HookEvent::CommandStarted, &event);
        assert_eq!(input.command, Some("git status".into()));
        assert_eq!(input.exit_code, None);
    }

    #[test]
    fn build_input_extracts_exit_code() {
        let event = make_command_completed("ls", 1);
        let input = build_hook_input(HookEvent::CommandCompleted, &event);
        assert_eq!(input.exit_code, Some(1));
    }

    #[test]
    fn build_input_completed_command_keeps_command() {
        let event = make_command_completed("cargo test", 0);
        let input = build_hook_input(HookEvent::CommandCompleted, &event);
        assert_eq!(input.command.as_deref(), Some("cargo test"));
        assert_eq!(input.exit_code, Some(0));
        assert_eq!(input.message_text, None);
    }

    #[test]
    fn build_input_turn_completed_has_no_fields() {
        let event = make_turn_completed();
        let input = build_hook_input(HookEvent::TurnCompleted, &event);
        assert_eq!(input.hook_event, HookEvent::TurnCompleted);
        assert_eq!(input.command, None);
        assert_eq!(input.exit_code, None);
        assert_eq!(input.message_text, None);
    }

    #[tokio::test]
    async fn dispatch_first_match() {
        let hook = HookMatcher {
            event: HookEvent::CommandStarted,
            command_filter: None,
            callback: Arc::new(|_input, _ctx| {
                Box::pin(async {
                    HookOutput {
                        decision: HookDecision::Block,
                        reason: Some("blocked".into()),
                        replacement_event: None,
                    }
                })
            }),
            timeout: None,
            on_timeout: Default::default(),
        };

        let event = make_command_started("ls");
        let ctx = make_context();
        let result = dispatch_hook(&event, &[hook], &ctx, Duration::from_secs(5)).await;

        assert!(result.is_some());
        let output = result.unwrap();
        assert_eq!(output.decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn dispatch_first_matching_hook_wins() {
        let hooks = [
            fixed_hook(HookEvent::TurnCompleted, None, HookDecision::Modify),
            fixed_hook(HookEvent::TurnCompleted, None, HookDecision::Block),
        ];
        let event = make_turn_completed();
        let ctx = make_context();
        let result = dispatch_hook(&event, &hooks, &ctx, Duration::from_secs(5)).await;
        assert_eq!(result.map(|output| output.decision), Some(HookDecision::Modify));
    }

    #[tokio::test]
    async fn dispatch_command_filter() {
        let hook = HookMatcher {
            event: HookEvent::CommandStarted,
            command_filter: Some("rm".into()),
            callback: Arc::new(|_input, _ctx| {
                Box::pin(async {
                    HookOutput {
                        decision: HookDecision::Block,
                        reason: None,
                        replacement_event: None,
                    }
                })
            }),
            timeout: None,
            on_timeout: Default::default(),
        };

        let ctx = make_context();

        let ls_event = make_command_started("ls -la");
        let result = dispatch_hook(&ls_event, &[hook], &ctx, Duration::from_secs(5)).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_command_filter_matches() {
        let hook = HookMatcher {
            event: HookEvent::CommandStarted,
            command_filter: Some("rm".into()),
            callback: Arc::new(|_input, _ctx| {
                Box::pin(async {
                    HookOutput {
                        decision: HookDecision::Block,
                        reason: None,
                        replacement_event: None,
                    }
                })
            }),
            timeout: None,
            on_timeout: Default::default(),
        };

        let ctx = make_context();

        let rm_event = make_command_started("rm -rf /tmp/test");
        let result = dispatch_hook(&rm_event, &[hook], &ctx, Duration::from_secs(5)).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn dispatch_filter_miss_falls_through_to_next_hook() {
        let hooks = [
            fixed_hook(HookEvent::CommandStarted, Some("rm"), HookDecision::Abort),
            fixed_hook(HookEvent::CommandStarted, None, HookDecision::Allow),
        ];
        let event = make_command_started("ls -la");
        let ctx = make_context();
        let result = dispatch_hook(&event, &hooks, &ctx, Duration::from_secs(5)).await;
        assert_eq!(result.map(|output| output.decision), Some(HookDecision::Allow));
    }

    #[tokio::test]
    async fn dispatch_filter_without_command_skips_hook() {
        // Turn events carry no command, so any filtered hook must be skipped.
        let hooks = [fixed_hook(HookEvent::TurnCompleted, Some("x"), HookDecision::Block)];
        let event = make_turn_completed();
        let ctx = make_context();
        let result = dispatch_hook(&event, &hooks, &ctx, Duration::from_secs(5)).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_no_match_returns_none() {
        let hook = HookMatcher {
            event: HookEvent::TurnCompleted,
            command_filter: None,
            callback: Arc::new(|_input, _ctx| Box::pin(async { HookOutput::default() })),
            timeout: None,
            on_timeout: Default::default(),
        };

        let event = make_command_started("ls");
        let ctx = make_context();
        let result = dispatch_hook(&event, &[hook], &ctx, Duration::from_secs(5)).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_unclassified_event_returns_none() {
        let hooks = [fixed_hook(HookEvent::TurnCompleted, None, HookDecision::Block)];
        let event = ThreadEvent::TurnStarted;
        let ctx = make_context();
        let result = dispatch_hook(&event, &hooks, &ctx, Duration::from_secs(5)).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_timeout_fails_open() {
        let hook = HookMatcher {
            event: HookEvent::CommandStarted,
            command_filter: None,
            callback: Arc::new(|_input, _ctx| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    HookOutput {
                        decision: HookDecision::Block,
                        reason: None,
                        replacement_event: None,
                    }
                })
            }),
            timeout: Some(Duration::from_millis(10)),
            on_timeout: HookTimeoutBehavior::FailOpen,
        };

        let event = make_command_started("ls");
        let ctx = make_context();
        let result = dispatch_hook(&event, &[hook], &ctx, Duration::from_secs(5)).await;

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_fail_open_timeout_falls_through_to_next_hook() {
        let hooks = [
            slow_hook(HookEvent::CommandStarted, HookTimeoutBehavior::FailOpen),
            fixed_hook(HookEvent::CommandStarted, None, HookDecision::Block),
        ];
        let event = make_command_started("ls");
        let ctx = make_context();
        let result = dispatch_hook(&event, &hooks, &ctx, Duration::from_secs(5)).await;
        assert_eq!(result.map(|output| output.decision), Some(HookDecision::Block));
    }

    #[tokio::test]
    async fn dispatch_timeout_fail_closed_blocks() {
        let hook = HookMatcher {
            event: HookEvent::CommandStarted,
            command_filter: None,
            callback: Arc::new(|_input, _ctx| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    HookOutput::default()
                })
            }),
            timeout: Some(Duration::from_millis(10)),
            on_timeout: HookTimeoutBehavior::FailClosed,
        };

        let event = make_command_started("dangerous-cmd");
        let ctx = make_context();
        let result = dispatch_hook(&event, &[hook], &ctx, Duration::from_secs(5)).await;

        assert!(result.is_some());
        let output = result.unwrap();
        assert_eq!(output.decision, HookDecision::Block);
        assert!(output.reason.as_deref().unwrap_or("").contains("timeout"));
    }

    #[tokio::test]
    async fn dispatch_all_four_decisions() {
        for decision in [
            HookDecision::Allow,
            HookDecision::Block,
            HookDecision::Modify,
            HookDecision::Abort,
        ] {
            let d = decision.clone();
            let hook = HookMatcher {
                event: HookEvent::TurnCompleted,
                command_filter: None,
                callback: Arc::new(move |_input, _ctx| {
                    let d = d.clone();
                    Box::pin(async move {
                        HookOutput {
                            decision: d,
                            reason: None,
                            replacement_event: None,
                        }
                    })
                }),
                timeout: None,
                on_timeout: Default::default(),
            };

            let event = make_turn_completed();
            let ctx = make_context();
            let result = dispatch_hook(&event, &[hook], &ctx, Duration::from_secs(5)).await;
            assert!(result.is_some());
            assert_eq!(result.unwrap().decision, decision);
        }
    }
}