use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Event kinds that a `HookMatcher` can be registered for.
///
/// Serialized in `snake_case` (e.g. `PreToolUse` <-> `"pre_tool_use"`).
/// Not every variant is supported; see [`HookEvent::is_supported`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookEvent {
    PreToolUse,
    PostToolUse,
    PostToolUseFailure,
    UserPromptSubmit,
    Stop,
    SubagentStop,
    PreCompact,
    Notification,
}
impl HookEvent {
    /// Reports whether this event kind is supported.
    ///
    /// `true` for `PreToolUse`, `PostToolUse`, `PostToolUseFailure`,
    /// `UserPromptSubmit` and `Stop`; `false` for `SubagentStop`,
    /// `PreCompact` and `Notification`.
    pub fn is_supported(&self) -> bool {
        // Deliberately exhaustive (no `_` arm): adding a variant must force a
        // decision here rather than silently defaulting to unsupported.
        match *self {
            HookEvent::SubagentStop | HookEvent::PreCompact | HookEvent::Notification => false,
            HookEvent::PreToolUse
            | HookEvent::PostToolUse
            | HookEvent::PostToolUseFailure
            | HookEvent::UserPromptSubmit
            | HookEvent::Stop => true,
        }
    }
}
/// A hook registration: the callback plus the filters deciding when it runs.
///
/// `execute_hooks` runs a matcher only when `event` equals the input's event
/// and, if `tool_name` is set, the input carries a tool name matching it.
#[derive(Clone)]
pub struct HookMatcher {
    /// Event this hook is registered for.
    pub event: HookEvent,
    /// Optional tool-name filter; a trailing `*` turns it into a prefix match.
    pub tool_name: Option<String>,
    /// Per-hook time limit; `None` falls back to the caller-supplied default.
    pub timeout: Option<Duration>,
    /// Async function invoked with the hook input and context.
    pub callback: HookCallback,
}
/// Payload handed (by value) to every hook callback.
///
/// Because of `#[serde(flatten)]`, keys not matching a named field are
/// collected into `extra` on deserialization and written back inline on
/// serialization.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct HookInput {
    /// Event being dispatched; hooks registered for other events are skipped.
    pub event: HookEvent,
    /// Tool involved, if any; compared against a matcher's `tool_name` pattern.
    pub tool_name: Option<String>,
    /// Optional JSON payload for the tool's input.
    pub tool_input: Option<Value>,
    /// Optional JSON payload for the tool's output.
    pub tool_output: Option<Value>,
    /// Optional prompt text.
    pub prompt: Option<String>,
    /// Identifier of the session this input belongs to.
    pub session_id: String,
    /// Remaining fields, flattened into the same JSON object.
    /// NOTE: when serializing, this must be a JSON object or `null`; serde
    /// cannot flatten other JSON values and returns an error.
    #[serde(flatten)]
    pub extra: Value,
}
/// Ambient information passed to each hook alongside its [`HookInput`].
#[derive(Clone, Debug)]
pub struct HookContext {
    /// Identifier of the current session.
    pub session_id: String,
    /// Current working directory, stored as a plain string.
    pub cwd: String,
}
/// Result produced by a hook callback.
///
/// When deserializing, only `decision` is required; the optional fields
/// default to `None` if absent.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct HookOutput {
    /// The hook's verdict.
    pub decision: HookDecision,
    /// Optional JSON value — presumably a replacement for the tool input;
    /// confirm against consumers of `execute_hooks`.
    #[serde(default)]
    pub updated_input: Option<Value>,
    /// Optional message attached by the hook.
    #[serde(default)]
    pub message: Option<String>,
}
/// Verdict carried by a [`HookOutput`].
///
/// `execute_hooks` returns the first output whose decision is not `Skip`;
/// a `Skip` lets the next matching hook run instead.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookDecision {
    /// Serialized as `"continue"`; also the decision of `HookOutput::default()`.
    Continue,
    /// Serialized as `"block"`.
    Block,
    /// Serialized as `"skip"`; defers to the next matching hook.
    Skip,
}
impl Default for HookOutput {
    /// An output that says `Continue` and carries neither an updated input
    /// nor a message.
    fn default() -> Self {
        let decision = HookDecision::Continue;
        Self {
            message: None,
            updated_input: None,
            decision,
        }
    }
}
/// Shared, thread-safe (`Send + Sync`) async hook function.
///
/// Invoked with an owned [`HookInput`] and [`HookContext`]; returns a boxed
/// `Send` future that resolves to the hook's [`HookOutput`].
pub type HookCallback = Arc<
    dyn Fn(HookInput, HookContext) -> Pin<Box<dyn Future<Output = HookOutput> + Send>>
        + Sync
        + Send,
>;
/// Runs the hooks that apply to `input`, in slice order, and returns the first
/// output whose decision is not [`HookDecision::Skip`].
///
/// A hook applies when its `event` equals `input.event` and, if it carries a
/// `tool_name` pattern, `input.tool_name` is present and matches that pattern
/// (see `tool_name_matches`). Each call is bounded by the hook's own `timeout`,
/// or by `default_timeout` when unset. A timed-out hook is logged and treated
/// like `Skip`. If no hook yields a non-`Skip` decision, `HookOutput::default()`
/// (`Continue`) is returned.
pub(crate) async fn execute_hooks(
    hooks: &[HookMatcher],
    input: HookInput,
    context: &HookContext,
    default_timeout: Duration,
) -> HookOutput {
    let applies = |hook: &HookMatcher| -> bool {
        if hook.event != input.event {
            return false;
        }
        match (&hook.tool_name, &input.tool_name) {
            (None, _) => true,
            (Some(pattern), Some(name)) => tool_name_matches(name, pattern),
            // A tool-name filter can never match an input without a tool.
            (Some(_), None) => false,
        }
    };

    for hook in hooks.iter().filter(|&hook| applies(hook)) {
        let budget = hook.timeout.unwrap_or(default_timeout);
        // Each callback takes owned values, so every hook gets fresh copies.
        let pending = (hook.callback)(input.clone(), context.clone());
        match tokio::time::timeout(budget, pending).await {
            Ok(output) if output.decision != HookDecision::Skip => return output,
            Ok(_skipped) => {}
            Err(_elapsed) => {
                tracing::warn!("Hook timed out for event {:?}", input.event);
            }
        }
    }
    HookOutput::default()
}
/// Returns `true` when tool `name` satisfies `pattern`.
///
/// A pattern ending in `*` is a prefix match on everything before that final
/// `*` (so a bare `"*"` matches any name, including `""`). Any other pattern
/// must equal `name` exactly. A `*` anywhere else is compared literally.
fn tool_name_matches(name: &str, pattern: &str) -> bool {
    match pattern.strip_suffix('*') {
        Some(prefix) => name.starts_with(prefix),
        None => name == pattern,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Boxed future type produced by the hook callbacks built in these tests.
    type BoxedOutput = Pin<Box<dyn Future<Output = HookOutput> + Send>>;

    fn make_input(event: HookEvent) -> HookInput {
        HookInput {
            event,
            tool_name: None,
            tool_input: None,
            tool_output: None,
            prompt: None,
            session_id: "test-session".to_string(),
            extra: serde_json::Value::Null,
        }
    }

    fn make_context() -> HookContext {
        HookContext {
            session_id: "test-session".to_string(),
            cwd: "/tmp".to_string(),
        }
    }

    /// Builds a matcher for `event` (optionally filtered by `tool_name`) whose
    /// callback immediately resolves to `decision`.
    fn decision_hook(
        event: HookEvent,
        tool_name: Option<&str>,
        decision: HookDecision,
    ) -> HookMatcher {
        HookMatcher {
            event,
            tool_name: tool_name.map(str::to_string),
            callback: Arc::new(move |_input: HookInput, _ctx: HookContext| -> BoxedOutput {
                let decision = decision.clone();
                Box::pin(async move {
                    HookOutput {
                        decision,
                        updated_input: None,
                        message: None,
                    }
                })
            }),
            timeout: None,
        }
    }

    /// Builds a matcher for `event` whose callback sleeps 10s, far beyond `timeout`.
    fn sleeping_hook(event: HookEvent, timeout: Duration) -> HookMatcher {
        HookMatcher {
            event,
            tool_name: None,
            callback: Arc::new(|_input: HookInput, _ctx: HookContext| -> BoxedOutput {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    HookOutput::default()
                })
            }),
            timeout: Some(timeout),
        }
    }

    /// Runs `hooks` against `input` with a fresh context and a 5s default timeout.
    async fn run(hooks: &[HookMatcher], input: HookInput) -> HookOutput {
        execute_hooks(hooks, input, &make_context(), Duration::from_secs(5)).await
    }

    #[test]
    fn test_hook_event_is_supported() {
        assert!(HookEvent::PreToolUse.is_supported());
        assert!(HookEvent::PostToolUse.is_supported());
        assert!(HookEvent::PostToolUseFailure.is_supported());
        assert!(HookEvent::UserPromptSubmit.is_supported());
        assert!(HookEvent::Stop.is_supported());
        assert!(!HookEvent::SubagentStop.is_supported());
        assert!(!HookEvent::PreCompact.is_supported());
        assert!(!HookEvent::Notification.is_supported());
    }

    #[test]
    fn test_tool_name_exact_match() {
        assert!(tool_name_matches("EditFile", "EditFile"));
        assert!(!tool_name_matches("EditFile", "ReadFile"));
    }

    #[test]
    fn test_tool_name_glob_match() {
        assert!(tool_name_matches("EditFile", "Edit*"));
        assert!(tool_name_matches("EditBlock", "Edit*"));
        assert!(!tool_name_matches("ReadFile", "Edit*"));
    }

    #[test]
    fn test_tool_name_bare_star_matches_everything() {
        assert!(tool_name_matches("EditFile", "*"));
        assert!(tool_name_matches("", "*"));
    }

    #[tokio::test]
    async fn test_execute_hooks_no_match() {
        let hooks: Vec<HookMatcher> = Vec::new();
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Continue);
    }

    #[tokio::test]
    async fn test_execute_hooks_matching() {
        let hooks = vec![HookMatcher {
            event: HookEvent::PreToolUse,
            tool_name: None,
            callback: Arc::new(|_input: HookInput, _ctx: HookContext| -> BoxedOutput {
                Box::pin(async {
                    HookOutput {
                        decision: HookDecision::Block,
                        updated_input: None,
                        message: Some("blocked".to_string()),
                    }
                })
            }),
            timeout: None,
        }];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Block);
        // The winning hook's output is returned unchanged.
        assert_eq!(output.message.as_deref(), Some("blocked"));
    }

    #[tokio::test]
    async fn test_execute_hooks_wrong_event() {
        let hooks = vec![decision_hook(HookEvent::PostToolUse, None, HookDecision::Block)];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Continue);
    }

    #[tokio::test]
    async fn test_execute_hooks_tool_name_filter() {
        let hooks = vec![decision_hook(
            HookEvent::PreToolUse,
            Some("EditFile"),
            HookDecision::Block,
        )];

        let mut input = make_input(HookEvent::PreToolUse);
        input.tool_name = Some("EditFile".to_string());
        let output = run(&hooks, input).await;
        assert_eq!(output.decision, HookDecision::Block);

        let mut input2 = make_input(HookEvent::PreToolUse);
        input2.tool_name = Some("ReadFile".to_string());
        let output2 = run(&hooks, input2).await;
        assert_eq!(output2.decision, HookDecision::Continue);
    }

    #[tokio::test]
    async fn test_execute_hooks_tool_filter_requires_tool_name() {
        // A hook with any tool-name pattern (even "*") never fires for an
        // input that carries no tool name.
        let hooks = vec![decision_hook(HookEvent::PreToolUse, Some("*"), HookDecision::Block)];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Continue);
    }

    #[tokio::test]
    async fn test_execute_hooks_glob_filter() {
        let hooks = vec![decision_hook(
            HookEvent::PreToolUse,
            Some("Edit*"),
            HookDecision::Block,
        )];
        let mut input = make_input(HookEvent::PreToolUse);
        input.tool_name = Some("EditBlock".to_string());
        let output = run(&hooks, input).await;
        assert_eq!(output.decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn test_execute_hooks_timeout() {
        let hooks = vec![sleeping_hook(HookEvent::PreToolUse, Duration::from_millis(10))];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Continue);
    }

    #[tokio::test]
    async fn test_execute_hooks_timeout_falls_through() {
        // A timed-out hook is treated like `Skip`: later hooks still run.
        let hooks = vec![
            sleeping_hook(HookEvent::PreToolUse, Duration::from_millis(10)),
            decision_hook(HookEvent::PreToolUse, None, HookDecision::Block),
        ];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn test_execute_hooks_skip_advances() {
        let hooks = vec![
            decision_hook(HookEvent::PreToolUse, None, HookDecision::Skip),
            decision_hook(HookEvent::PreToolUse, None, HookDecision::Block),
        ];
        let output = run(&hooks, make_input(HookEvent::PreToolUse)).await;
        assert_eq!(output.decision, HookDecision::Block);
    }
}