use crate::types::events::ThreadEvent;
use crate::types::items::ThreadItem;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
/// Category of thread event a hook can subscribe to.
///
/// Raw events are mapped onto these categories by [`classify_hook_event`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum HookEvent {
    /// A command-execution item has started.
    CommandStarted,
    /// A command-execution item completed with status `Completed`.
    CommandCompleted,
    /// A command-execution item completed with status `Failed`.
    CommandFailed,
    /// A file-change item has completed.
    FileChangeCompleted,
    /// An agent-message item has completed.
    AgentMessage,
    /// A turn has completed.
    TurnCompleted,
    /// A turn has failed.
    TurnFailed,
}
/// Payload handed to a hook callback, built by [`build_hook_input`].
///
/// The optional fields are convenience extractions from `raw_event`; which
/// of them are populated depends on the kind of event that triggered the hook.
#[derive(Clone, Debug)]
pub struct HookInput {
    /// Category the triggering event was classified as.
    pub hook_event: HookEvent,
    /// Command line, for command-execution events.
    pub command: Option<String>,
    /// Exit code reported by a completed command execution, if any.
    pub exit_code: Option<i32>,
    /// Agent message text, or the error message of a failed turn.
    pub message_text: Option<String>,
    /// Full copy of the event that triggered the hook.
    pub raw_event: ThreadEvent,
}
/// Thread-level state passed to every hook callback alongside its [`HookInput`].
#[derive(Clone, Debug)]
pub struct HookContext {
    /// Identifier of the thread the event belongs to, when the caller has one.
    pub thread_id: Option<String>,
    /// Turn counter supplied by the caller — presumably the number of turns
    /// run so far; confirm at the construction site.
    pub turn_count: u32,
}
/// Result produced by a hook callback (or synthesized by `dispatch_hook` on a
/// fail-closed timeout).
///
/// The default output is an `Allow` decision with no reason and no
/// replacement event.
#[derive(Clone, Debug)]
pub struct HookOutput {
    /// Verdict the hook reached for the event.
    pub decision: HookDecision,
    /// Optional human-readable explanation for the decision.
    pub reason: Option<String>,
    /// Optional event to use in place of the original — presumably consulted
    /// by callers for `HookDecision::Modify`; confirm at the call site.
    pub replacement_event: Option<ThreadEvent>,
}

impl Default for HookOutput {
    /// An `Allow` decision carrying neither a reason nor a replacement event.
    fn default() -> Self {
        let decision = HookDecision::Allow;
        Self {
            replacement_event: None,
            reason: None,
            decision,
        }
    }
}
/// What `dispatch_hook` does when a hook callback exceeds its timeout.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum HookTimeoutBehavior {
    /// Log a warning, ignore the timed-out hook, and try the next matching one.
    #[default]
    FailOpen,
    /// Log a warning and stop dispatch with a `HookDecision::Block` output.
    FailClosed,
}
/// Verdict a hook returns for an event.
///
/// This module only produces `Block` on its own (for fail-closed timeouts);
/// how each verdict is enforced is up to the code that calls `dispatch_hook`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HookDecision {
    /// Permit the event; this is the decision in `HookOutput::default()`.
    Allow,
    /// Deny the event; also what a fail-closed timeout produces.
    Block,
    /// Presumably: proceed using `HookOutput::replacement_event` instead of
    /// the original event — confirm with the caller.
    Modify,
    /// Presumably: stop the surrounding run entirely — confirm with the caller.
    Abort,
}
/// Asynchronous hook body.
///
/// Receives the event payload and context by value and returns a boxed,
/// `Send` future resolving to the hook's [`HookOutput`]. The callback itself
/// must be `Send + Sync` so it can be shared through the `Arc`.
pub type HookCallback = Arc<
    dyn Fn(HookInput, HookContext) -> Pin<Box<dyn Future<Output = HookOutput> + Send>>
        + Sync
        + Send,
>;
/// Registration of one hook: which events it handles and how it is run.
#[derive(Clone)]
pub struct HookMatcher {
    /// Event category this hook runs for.
    pub event: HookEvent,
    /// When set, the hook only runs if the event's command contains this
    /// substring; events that carry no command never match a filter.
    pub command_filter: Option<String>,
    /// The hook body.
    pub callback: HookCallback,
    /// Per-hook timeout; `None` falls back to the default given to `dispatch_hook`.
    pub timeout: Option<Duration>,
    /// What to do if `callback` exceeds its timeout.
    pub on_timeout: HookTimeoutBehavior,
}
impl std::fmt::Debug for HookMatcher {
    /// Formats every field except `callback`, which is a `dyn Fn` with no
    /// `Debug` impl; the trailing `..` makes the omission visible.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HookMatcher")
            .field("event", &self.event)
            .field("command_filter", &self.command_filter)
            .field("timeout", &self.timeout)
            .field("on_timeout", &self.on_timeout)
            .finish_non_exhaustive()
    }
}
/// Maps a raw [`ThreadEvent`] onto the [`HookEvent`] category hooks subscribe to.
///
/// Only these events have a category:
/// - a command-execution item starting → `CommandStarted`;
/// - a command-execution item completing with status `Completed` or `Failed`
///   → `CommandCompleted` / `CommandFailed` (a completion still reporting
///   `InProgress` has no category);
/// - a file-change item completing → `FileChangeCompleted`;
/// - an agent-message item completing → `AgentMessage`;
/// - a turn completing or failing → `TurnCompleted` / `TurnFailed`.
///
/// Every other event yields `None`.
pub fn classify_hook_event(event: &ThreadEvent) -> Option<HookEvent> {
    use crate::types::items::CommandExecutionStatus as Status;

    let category = match event {
        ThreadEvent::TurnCompleted { .. } => HookEvent::TurnCompleted,
        ThreadEvent::TurnFailed { .. } => HookEvent::TurnFailed,
        ThreadEvent::ItemStarted {
            item: ThreadItem::CommandExecution { .. },
        } => HookEvent::CommandStarted,
        ThreadEvent::ItemCompleted {
            item: ThreadItem::CommandExecution { status, .. },
        } => match status {
            Status::Completed => HookEvent::CommandCompleted,
            Status::Failed => HookEvent::CommandFailed,
            Status::InProgress => return None,
        },
        ThreadEvent::ItemCompleted {
            item: ThreadItem::FileChange { .. },
        } => HookEvent::FileChangeCompleted,
        ThreadEvent::ItemCompleted {
            item: ThreadItem::AgentMessage { .. },
        } => HookEvent::AgentMessage,
        _ => return None,
    };
    Some(category)
}
/// Builds the [`HookInput`] handed to hook callbacks for `event`.
///
/// Pulls out the fields hooks most commonly inspect:
/// - command executions: the command line for both started and completed
///   items, plus the exit code for completed items only;
/// - completed agent messages: the message text;
/// - failed turns: the error message.
///
/// Every other extracted field is `None`. The full event is always cloned
/// into `raw_event` so callbacks can inspect anything not surfaced here.
pub fn build_hook_input(hook_event: HookEvent, event: &ThreadEvent) -> HookInput {
    let (command, exit_code, message_text) = match event {
        // A started item's exit code is deliberately not reported.
        ThreadEvent::ItemStarted {
            item: ThreadItem::CommandExecution { command, .. },
        } => (Some(command.clone()), None, None),
        ThreadEvent::ItemCompleted {
            item:
                ThreadItem::CommandExecution {
                    command, exit_code, ..
                },
        } => (Some(command.clone()), *exit_code, None),
        ThreadEvent::ItemCompleted {
            item: ThreadItem::AgentMessage { text, .. },
        } => (None, None, Some(text.clone())),
        ThreadEvent::TurnFailed { error } => (None, None, Some(error.message.clone())),
        _ => (None, None, None),
    };
    HookInput {
        hook_event,
        command,
        exit_code,
        message_text,
        raw_event: event.clone(),
    }
}
pub async fn dispatch_hook(
event: &ThreadEvent,
hooks: &[HookMatcher],
context: &HookContext,
default_timeout: Duration,
) -> Option<HookOutput> {
let hook_event = classify_hook_event(event)?;
let input = build_hook_input(hook_event.clone(), event);
for hook in hooks {
if hook.event != hook_event {
continue;
}
if let Some(ref filter) = hook.command_filter {
match &input.command {
Some(cmd) if cmd.contains(filter.as_str()) => {}
_ => continue,
}
}
let timeout = hook.timeout.unwrap_or(default_timeout);
let future = (hook.callback)(input.clone(), context.clone());
match tokio::time::timeout(timeout, future).await {
Ok(output) => return Some(output),
Err(_) => {
tracing::warn!(
"Hook timed out after {:?} for {:?} — {:?}",
timeout,
hook.event,
hook.on_timeout,
);
match hook.on_timeout {
HookTimeoutBehavior::FailOpen => continue,
HookTimeoutBehavior::FailClosed => {
return Some(HookOutput {
decision: HookDecision::Block,
reason: Some(format!("hook timeout after {timeout:?} (fail-closed)")),
replacement_event: None,
});
}
}
}
}
}
None
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::events::Usage;
    use crate::types::items::CommandExecutionStatus;

    /// Timeout for tests whose hooks are expected to finish immediately.
    const GENEROUS_TIMEOUT: Duration = Duration::from_secs(5);
    /// Timeout for tests that deliberately let a stalled hook expire.
    const SHORT_TIMEOUT: Duration = Duration::from_millis(10);

    fn make_command_started(cmd: &str) -> ThreadEvent {
        ThreadEvent::ItemStarted {
            item: ThreadItem::CommandExecution {
                id: "cmd-1".into(),
                command: cmd.into(),
                aggregated_output: String::new(),
                exit_code: None,
                status: CommandExecutionStatus::InProgress,
            },
        }
    }

    /// Builds an `ItemCompleted` command-execution event with an explicit status.
    fn make_command_finished(
        cmd: &str,
        exit_code: Option<i32>,
        status: CommandExecutionStatus,
    ) -> ThreadEvent {
        ThreadEvent::ItemCompleted {
            item: ThreadItem::CommandExecution {
                id: "cmd-1".into(),
                command: cmd.into(),
                aggregated_output: "output".into(),
                exit_code,
                status,
            },
        }
    }

    fn make_command_completed(cmd: &str, code: i32) -> ThreadEvent {
        make_command_finished(cmd, Some(code), CommandExecutionStatus::Completed)
    }

    fn make_turn_completed() -> ThreadEvent {
        ThreadEvent::TurnCompleted {
            usage: Usage {
                input_tokens: 100,
                cached_input_tokens: 0,
                output_tokens: 50,
            },
        }
    }

    fn make_context() -> HookContext {
        HookContext {
            thread_id: Some("thread-1".into()),
            turn_count: 0,
        }
    }

    /// A hook that immediately returns `decision` with no reason or replacement.
    fn decision_hook(
        event: HookEvent,
        command_filter: Option<&str>,
        decision: HookDecision,
    ) -> HookMatcher {
        HookMatcher {
            event,
            command_filter: command_filter.map(str::to_owned),
            callback: Arc::new(move |_input, _ctx| {
                let decision = decision.clone();
                Box::pin(async move {
                    HookOutput {
                        decision,
                        reason: None,
                        replacement_event: None,
                    }
                })
            }),
            timeout: None,
            on_timeout: HookTimeoutBehavior::default(),
        }
    }

    /// A hook whose callback sleeps far longer than any timeout used here.
    fn stalled_hook(
        event: HookEvent,
        timeout: Option<Duration>,
        on_timeout: HookTimeoutBehavior,
    ) -> HookMatcher {
        HookMatcher {
            event,
            command_filter: None,
            callback: Arc::new(|_input, _ctx| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    HookOutput::default()
                })
            }),
            timeout,
            on_timeout,
        }
    }

    #[test]
    fn classify_command_started() {
        let event = make_command_started("ls -la");
        assert_eq!(classify_hook_event(&event), Some(HookEvent::CommandStarted));
    }

    #[test]
    fn classify_command_completed() {
        let event = make_command_completed("ls", 0);
        assert_eq!(
            classify_hook_event(&event),
            Some(HookEvent::CommandCompleted)
        );
    }

    #[test]
    fn classify_command_failed() {
        let event = make_command_finished("false", Some(1), CommandExecutionStatus::Failed);
        assert_eq!(classify_hook_event(&event), Some(HookEvent::CommandFailed));
    }

    #[test]
    fn classify_in_progress_completion_returns_none() {
        let event = make_command_finished("ls", None, CommandExecutionStatus::InProgress);
        assert_eq!(classify_hook_event(&event), None);
    }

    #[test]
    fn classify_turn_completed() {
        let event = make_turn_completed();
        assert_eq!(classify_hook_event(&event), Some(HookEvent::TurnCompleted));
    }

    #[test]
    fn classify_unmatched_returns_none() {
        let event = ThreadEvent::TurnStarted;
        assert_eq!(classify_hook_event(&event), None);
    }

    #[test]
    fn build_input_extracts_command() {
        let event = make_command_started("git status");
        let input = build_hook_input(HookEvent::CommandStarted, &event);
        assert_eq!(input.command, Some("git status".into()));
        assert_eq!(input.exit_code, None);
    }

    #[test]
    fn build_input_extracts_exit_code() {
        let event = make_command_completed("ls", 1);
        let input = build_hook_input(HookEvent::CommandCompleted, &event);
        assert_eq!(input.command, Some("ls".into()));
        assert_eq!(input.exit_code, Some(1));
    }

    #[test]
    fn build_input_for_turn_has_no_extracted_fields() {
        let input = build_hook_input(HookEvent::TurnCompleted, &make_turn_completed());
        assert_eq!(input.command, None);
        assert_eq!(input.exit_code, None);
        assert_eq!(input.message_text, None);
    }

    #[tokio::test]
    async fn dispatch_first_match() {
        let hook = decision_hook(HookEvent::CommandStarted, None, HookDecision::Block);
        let event = make_command_started("ls");
        let result = dispatch_hook(&event, &[hook], &make_context(), GENEROUS_TIMEOUT).await;
        assert_eq!(result.map(|output| output.decision), Some(HookDecision::Block));
    }

    #[tokio::test]
    async fn dispatch_first_matching_hook_wins() {
        let hooks = [
            decision_hook(HookEvent::CommandStarted, None, HookDecision::Allow),
            decision_hook(HookEvent::CommandStarted, None, HookDecision::Block),
        ];
        let event = make_command_started("ls");
        let result = dispatch_hook(&event, &hooks, &make_context(), GENEROUS_TIMEOUT).await;
        assert_eq!(result.map(|output| output.decision), Some(HookDecision::Allow));
    }

    #[tokio::test]
    async fn dispatch_skips_hooks_for_other_events() {
        let hooks = [
            decision_hook(HookEvent::TurnCompleted, None, HookDecision::Abort),
            decision_hook(HookEvent::CommandStarted, None, HookDecision::Block),
        ];
        let event = make_command_started("ls");
        let result = dispatch_hook(&event, &hooks, &make_context(), GENEROUS_TIMEOUT).await;
        assert_eq!(result.map(|output| output.decision), Some(HookDecision::Block));
    }

    #[tokio::test]
    async fn dispatch_command_filter() {
        let hook = decision_hook(HookEvent::CommandStarted, Some("rm"), HookDecision::Block);
        let ls_event = make_command_started("ls -la");
        let result = dispatch_hook(&ls_event, &[hook], &make_context(), GENEROUS_TIMEOUT).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_command_filter_matches() {
        let hook = decision_hook(HookEvent::CommandStarted, Some("rm"), HookDecision::Block);
        let rm_event = make_command_started("rm -rf /tmp/test");
        let result = dispatch_hook(&rm_event, &[hook], &make_context(), GENEROUS_TIMEOUT).await;
        assert_eq!(result.map(|output| output.decision), Some(HookDecision::Block));
    }

    #[tokio::test]
    async fn dispatch_filter_never_matches_event_without_command() {
        // Even an empty filter (which every string contains) requires a command.
        let hook = decision_hook(HookEvent::TurnCompleted, Some(""), HookDecision::Block);
        let event = make_turn_completed();
        let result = dispatch_hook(&event, &[hook], &make_context(), GENEROUS_TIMEOUT).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_no_match_returns_none() {
        let hook = decision_hook(HookEvent::TurnCompleted, None, HookDecision::Allow);
        let event = make_command_started("ls");
        let result = dispatch_hook(&event, &[hook], &make_context(), GENEROUS_TIMEOUT).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_timeout_fails_open() {
        let hook = stalled_hook(
            HookEvent::CommandStarted,
            Some(SHORT_TIMEOUT),
            HookTimeoutBehavior::FailOpen,
        );
        let event = make_command_started("ls");
        let result = dispatch_hook(&event, &[hook], &make_context(), GENEROUS_TIMEOUT).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dispatch_fail_open_falls_through_to_next_hook() {
        let hooks = [
            stalled_hook(
                HookEvent::CommandStarted,
                Some(SHORT_TIMEOUT),
                HookTimeoutBehavior::FailOpen,
            ),
            decision_hook(HookEvent::CommandStarted, None, HookDecision::Block),
        ];
        let event = make_command_started("ls");
        let result = dispatch_hook(&event, &hooks, &make_context(), GENEROUS_TIMEOUT).await;
        assert_eq!(result.map(|output| output.decision), Some(HookDecision::Block));
    }

    #[tokio::test]
    async fn dispatch_timeout_fail_closed_blocks() {
        let hook = stalled_hook(
            HookEvent::CommandStarted,
            Some(SHORT_TIMEOUT),
            HookTimeoutBehavior::FailClosed,
        );
        let event = make_command_started("dangerous-cmd");
        let result = dispatch_hook(&event, &[hook], &make_context(), GENEROUS_TIMEOUT).await;
        let output = result.expect("fail-closed timeout should produce an output");
        assert_eq!(output.decision, HookDecision::Block);
        assert!(output.reason.as_deref().unwrap_or("").contains("timeout"));
    }

    #[tokio::test]
    async fn dispatch_uses_default_timeout_when_hook_has_none() {
        let hook = stalled_hook(HookEvent::CommandStarted, None, HookTimeoutBehavior::FailClosed);
        let event = make_command_started("ls");
        let result = dispatch_hook(&event, &[hook], &make_context(), SHORT_TIMEOUT).await;
        let output = result.expect("default timeout should apply and fail closed");
        assert_eq!(output.decision, HookDecision::Block);
        assert!(output.reason.as_deref().unwrap_or("").contains("timeout"));
    }

    #[tokio::test]
    async fn dispatch_all_four_decisions() {
        for decision in [
            HookDecision::Allow,
            HookDecision::Block,
            HookDecision::Modify,
            HookDecision::Abort,
        ] {
            let hook = decision_hook(HookEvent::TurnCompleted, None, decision.clone());
            let event = make_turn_completed();
            let result = dispatch_hook(&event, &[hook], &make_context(), GENEROUS_TIMEOUT).await;
            assert_eq!(result.map(|output| output.decision), Some(decision));
        }
    }
}