use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
/// Lifecycle point at which a hook can fire.
///
/// Serialized via `#[serde(rename_all = "snake_case")]`, so e.g.
/// `PreToolUse` <-> `"pre_tool_use"` and `PostToolUseFailure` <->
/// `"post_tool_use_failure"`.
// NOTE(review): the exact trigger for each variant is decided by whatever
// produces `HookRequest`s, which is not visible here; the variant names are the
// only guide. Variant order also fixes the discriminant, so avoid reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookEvent {
PreToolUse,
PostToolUse,
PostToolUseFailure,
UserPromptSubmit,
Stop,
SubagentStop,
PreCompact,
Notification,
}
/// Pairs a [`HookCallback`] with the event (and optionally the tool) it runs for.
///
/// Build one with [`HookMatcher::new`], then optionally narrow it with
/// [`HookMatcher::for_tool`] and/or override the dispatcher's default timeout
/// with [`HookMatcher::with_timeout`].
pub struct HookMatcher {
    /// Event this matcher responds to.
    pub event: HookEvent,
    /// When `Some`, only invocations whose tool name is exactly this value match.
    pub tool_name: Option<String>,
    /// Callback executed when the matcher applies.
    pub callback: HookCallback,
    /// Per-matcher timeout; `None` means the dispatcher's default is used.
    pub timeout: Option<Duration>,
}

// Hand-written because `HookCallback` wraps a `dyn Fn`, which has no `Debug`
// impl. `finish_non_exhaustive` renders a trailing `..` so the output makes it
// visible that a field (the callback) was intentionally left out.
impl std::fmt::Debug for HookMatcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HookMatcher")
            .field("event", &self.event)
            .field("tool_name", &self.tool_name)
            .field("timeout", &self.timeout)
            .finish_non_exhaustive()
    }
}

impl HookMatcher {
    /// Creates a matcher for `event` that applies to every tool and uses the
    /// dispatcher's default timeout.
    #[must_use]
    pub fn new(event: HookEvent, callback: HookCallback) -> Self {
        Self {
            event,
            tool_name: None,
            callback,
            timeout: None,
        }
    }

    /// Restricts this matcher to invocations whose tool name equals `name`
    /// exactly (case-sensitive string equality).
    #[must_use]
    pub fn for_tool(mut self, name: impl Into<String>) -> Self {
        self.tool_name = Some(name.into());
        self
    }

    /// Overrides the dispatcher's default timeout for this matcher's callback.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Returns `true` when this matcher applies to `event` and `tool_name`.
    ///
    /// The event must be equal. A matcher without a tool filter accepts any
    /// tool name, including `None`; a matcher with a filter requires a present
    /// tool name that is exactly equal to it.
    #[must_use]
    pub fn matches(&self, event: HookEvent, tool_name: Option<&str>) -> bool {
        if self.event != event {
            return false;
        }
        match (self.tool_name.as_deref(), tool_name) {
            (None, _) => true,
            (Some(filter), name) => name == Some(filter),
        }
    }
}
use crate::util::BoxFuture;
/// Shared (`Arc`), `Send + Sync` async hook callback.
///
/// `dispatch_hook` calls it as `callback(input, session_id, ctx)`, where the
/// `Option<String>` argument is the same session id that is also stored in
/// `ctx.session_id`. The returned `BoxFuture` (from `crate::util`) resolves to
/// the hook's `HookOutput`; the dispatcher wraps it in a timeout.
pub type HookCallback =
Arc<dyn Fn(HookInput, Option<String>, HookContext) -> BoxFuture<HookOutput> + Send + Sync>;
/// Data handed to a hook callback describing the hooked event.
///
/// Every optional field uses `#[serde(default)]` (a missing key deserializes to
/// `None`) and is omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookInput {
/// Event that triggered this hook invocation.
pub hook_event: HookEvent,
/// Name of the tool involved, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>,
/// Tool input payload, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_input: Option<serde_json::Value>,
/// Tool result payload, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_result: Option<serde_json::Value>,
/// Identifier of the tool use, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_use_id: Option<String>,
/// Additional free-form data; `HookRequest::to_hook_input` always leaves it `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra: Option<serde_json::Value>,
}
/// Per-invocation context passed to a hook callback.
#[derive(Debug, Clone)]
pub struct HookContext {
/// Session id supplied to `dispatch_hook`, if any.
pub session_id: Option<String>,
}
/// Result produced by a hook callback.
///
/// Optional fields default to `None` when absent in JSON and are omitted from
/// serialized output when `None`. See the `allow`/`block`/`modify`/`abort`
/// constructors for the field combinations they produce.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HookOutput {
/// What the hook decided.
pub decision: HookDecision,
/// Explanation; set by `block` and `abort`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
/// Replacement input; set by `modify`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub updated_input: Option<serde_json::Value>,
/// Additional free-form data; none of the constructors set it.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra: Option<serde_json::Value>,
}
/// Decision reported by a hook, serialized as `"allow"`, `"block"`,
/// `"modify"` or `"abort"`.
// NOTE(review): how consumers treat `Block` vs `Abort` is not visible in this
// file — confirm with the code that acts on `HookOutput::decision`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookDecision {
Allow,
Block,
Modify,
Abort,
}
impl HookOutput {
    /// Output with decision `Allow` and every optional field `None`.
    ///
    /// This is also what `dispatch_hook` falls back to when no hook matches or
    /// a hook times out.
    #[must_use]
    pub fn allow() -> Self {
        Self {
            decision: HookDecision::Allow,
            reason: None,
            updated_input: None,
            extra: None,
        }
    }

    /// Output with decision `Block` and the given `reason`.
    #[must_use]
    pub fn block(reason: impl Into<String>) -> Self {
        Self {
            decision: HookDecision::Block,
            reason: Some(reason.into()),
            ..Self::allow()
        }
    }

    /// Output with decision `Modify` carrying `updated_input` as the replacement input.
    #[must_use]
    pub fn modify(updated_input: serde_json::Value) -> Self {
        Self {
            decision: HookDecision::Modify,
            updated_input: Some(updated_input),
            ..Self::allow()
        }
    }

    /// Output with decision `Abort` and the given `reason`.
    #[must_use]
    pub fn abort(reason: impl Into<String>) -> Self {
        Self {
            decision: HookDecision::Abort,
            reason: Some(reason.into()),
            ..Self::allow()
        }
    }
}
/// Crate-internal wire representation of a request to run hooks.
///
/// Optional fields default to `None` when absent and are omitted from
/// serialized output when `None`. Converted to a callback-facing `HookInput`
/// via `to_hook_input` (which drops `request_id`).
// NOTE(review): `request_id` is presumably echoed back in
// `HookResponse::request_id` by the caller of `dispatch_hook` — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub(crate) struct HookRequest {
pub request_id: String,
pub hook_event: HookEvent,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_input: Option<serde_json::Value>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_result: Option<serde_json::Value>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_use_id: Option<String>,
}
impl HookRequest {
    /// Test-only: consumes the request and moves its tool fields into a
    /// [`HookInput`]. `request_id` is discarded and `extra` is `None`.
    #[cfg(test)]
    pub fn into_hook_input(self) -> HookInput {
        let Self {
            hook_event,
            tool_name,
            tool_input,
            tool_result,
            tool_use_id,
            ..
        } = self;
        HookInput {
            hook_event,
            tool_name,
            tool_input,
            tool_result,
            tool_use_id,
            extra: None,
        }
    }

    /// Builds a [`HookInput`] from a borrowed request by cloning its tool
    /// fields. `request_id` is not carried over and `extra` is `None`.
    pub(crate) fn to_hook_input(&self) -> HookInput {
        HookInput {
            extra: None,
            tool_use_id: self.tool_use_id.clone(),
            tool_result: self.tool_result.clone(),
            tool_input: self.tool_input.clone(),
            tool_name: self.tool_name.clone(),
            hook_event: self.hook_event,
        }
    }
}
/// Crate-internal wire representation of a hook result sent back for a request.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub(crate) struct HookResponse {
/// Message tag; `from_output` sets it to `"hook_response"`.
pub kind: String,
/// Id of the request this response answers.
pub request_id: String,
/// The hook's output.
pub result: HookOutput,
}
impl HookResponse {
    /// Wraps `output` in a response tagged `"hook_response"` for the request
    /// identified by `request_id`.
    pub fn from_output(request_id: String, output: HookOutput) -> Self {
        let kind = String::from("hook_response");
        Self {
            kind,
            request_id,
            result: output,
        }
    }
}
/// Runs the first hook in `hooks` that matches `req` and returns its output.
///
/// Matchers are checked in slice order with [`HookMatcher::matches`]; only the
/// first one that applies is invoked, and later matching hooks are never run.
/// The callback receives the request as a [`HookInput`], `session_id`, and a
/// [`HookContext`] holding the same session id. It is bounded by the matcher's
/// own timeout, or `default_hook_timeout` when the matcher has none.
///
/// Returns [`HookOutput::allow`] when no matcher applies, and also when the
/// callback exceeds its timeout (fail-open, logged as a warning).
pub(crate) async fn dispatch_hook(
    req: &HookRequest,
    hooks: &[HookMatcher],
    default_hook_timeout: Duration,
    session_id: Option<String>,
) -> HookOutput {
    let matcher = match hooks
        .iter()
        .find(|candidate| candidate.matches(req.hook_event, req.tool_name.as_deref()))
    {
        Some(found) => found,
        None => return HookOutput::allow(),
    };

    let limit = matcher.timeout.unwrap_or(default_hook_timeout);
    let ctx = HookContext {
        session_id: session_id.clone(),
    };
    let pending = (matcher.callback)(req.to_hook_input(), session_id, ctx);

    tokio::time::timeout(limit, pending)
        .await
        .unwrap_or_else(|_elapsed| {
            tracing::warn!(
                event = ?req.hook_event,
                tool = ?req.tool_name,
                timeout_secs = limit.as_secs_f64(),
                "hook callback timed out, defaulting to allow (fail-open)"
            );
            HookOutput::allow()
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PreToolUse` request with the given tool name and no payloads.
    fn pre_tool_request(tool_name: Option<&str>) -> HookRequest {
        HookRequest {
            request_id: "r1".into(),
            hook_event: HookEvent::PreToolUse,
            tool_name: tool_name.map(str::to_owned),
            tool_input: None,
            tool_result: None,
            tool_use_id: None,
        }
    }

    #[test]
    fn hook_event_round_trip() {
        let events = [
            HookEvent::PreToolUse,
            HookEvent::PostToolUse,
            HookEvent::PostToolUseFailure,
            HookEvent::UserPromptSubmit,
            HookEvent::Stop,
            HookEvent::SubagentStop,
            HookEvent::PreCompact,
            HookEvent::Notification,
        ];
        for event in events {
            let json = serde_json::to_string(&event).unwrap();
            let decoded: HookEvent = serde_json::from_str(&json).unwrap();
            assert_eq!(event, decoded, "round-trip failed for {event:?}");
        }
    }

    #[test]
    fn hook_matcher_matches_any_tool() {
        let cb: HookCallback = Arc::new(|_, _, _| Box::pin(async { HookOutput::allow() }));
        let matcher = HookMatcher::new(HookEvent::PreToolUse, cb);
        assert!(matcher.matches(HookEvent::PreToolUse, Some("bash")));
        assert!(matcher.matches(HookEvent::PreToolUse, Some("read_file")));
        assert!(matcher.matches(HookEvent::PreToolUse, None));
        assert!(!matcher.matches(HookEvent::PostToolUse, Some("bash")));
    }

    #[test]
    fn hook_matcher_matches_specific_tool() {
        let cb: HookCallback = Arc::new(|_, _, _| Box::pin(async { HookOutput::allow() }));
        let matcher = HookMatcher::new(HookEvent::PreToolUse, cb).for_tool("bash");
        assert!(matcher.matches(HookEvent::PreToolUse, Some("bash")));
        assert!(!matcher.matches(HookEvent::PreToolUse, Some("read_file")));
        assert!(!matcher.matches(HookEvent::PreToolUse, None));
        // The tool filter does not bypass the event check.
        assert!(!matcher.matches(HookEvent::PostToolUse, Some("bash")));
    }

    #[test]
    fn hook_matcher_with_timeout() {
        let cb: HookCallback = Arc::new(|_, _, _| Box::pin(async { HookOutput::allow() }));
        let matcher = HookMatcher::new(HookEvent::Stop, cb).with_timeout(Duration::from_secs(5));
        assert_eq!(matcher.timeout, Some(Duration::from_secs(5)));
    }

    #[test]
    fn hook_output_allow() {
        let output = HookOutput::allow();
        assert_eq!(output.decision, HookDecision::Allow);
        assert!(output.reason.is_none());
    }

    #[test]
    fn hook_output_block() {
        let output = HookOutput::block("dangerous command");
        assert_eq!(output.decision, HookDecision::Block);
        assert_eq!(output.reason.as_deref(), Some("dangerous command"));
    }

    #[test]
    fn hook_output_modify() {
        let output = HookOutput::modify(serde_json::json!({"safe": true}));
        assert_eq!(output.decision, HookDecision::Modify);
        assert!(output.updated_input.is_some());
    }

    #[test]
    fn hook_output_abort() {
        let output = HookOutput::abort("critical failure");
        assert_eq!(output.decision, HookDecision::Abort);
        assert_eq!(output.reason.as_deref(), Some("critical failure"));
    }

    #[test]
    fn hook_output_round_trip() {
        let output = HookOutput {
            decision: HookDecision::Modify,
            reason: Some("safety".into()),
            updated_input: Some(serde_json::json!({"command": "ls"})),
            extra: None,
        };
        let json = serde_json::to_string(&output).unwrap();
        let decoded: HookOutput = serde_json::from_str(&json).unwrap();
        // Whole-value comparison also covers `extra`.
        assert_eq!(output, decoded);
    }

    #[test]
    fn hook_request_round_trip() {
        let req = HookRequest {
            request_id: "hr-1".into(),
            hook_event: HookEvent::PreToolUse,
            tool_name: Some("bash".into()),
            tool_input: Some(serde_json::json!({"command": "echo hello"})),
            tool_result: None,
            tool_use_id: Some("tu-1".into()),
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: HookRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, decoded);
    }

    #[test]
    fn hook_request_into_hook_input() {
        let req = HookRequest {
            request_id: "hr-1".into(),
            hook_event: HookEvent::PostToolUse,
            tool_name: Some("bash".into()),
            tool_input: None,
            tool_result: Some(serde_json::json!("output")),
            tool_use_id: Some("tu-1".into()),
        };
        let input = req.into_hook_input();
        assert_eq!(input.hook_event, HookEvent::PostToolUse);
        assert_eq!(input.tool_name.as_deref(), Some("bash"));
        assert!(input.tool_result.is_some());
    }

    #[test]
    fn hook_response_from_output() {
        let output = HookOutput::allow();
        let resp = HookResponse::from_output("req-1".into(), output);
        assert_eq!(resp.kind, "hook_response");
        assert_eq!(resp.request_id, "req-1");
        assert_eq!(resp.result.decision, HookDecision::Allow);
    }

    #[test]
    fn hook_response_round_trip() {
        let resp = HookResponse {
            kind: "hook_response".into(),
            request_id: "hr-1".into(),
            result: HookOutput::block("no"),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: HookResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, decoded);
    }

    #[test]
    fn hook_decision_serde() {
        let decisions = [
            (HookDecision::Allow, r#""allow""#),
            (HookDecision::Block, r#""block""#),
            (HookDecision::Modify, r#""modify""#),
            (HookDecision::Abort, r#""abort""#),
        ];
        for (decision, expected_json) in decisions {
            let json = serde_json::to_string(&decision).unwrap();
            assert_eq!(json, expected_json);
            let decoded: HookDecision = serde_json::from_str(&json).unwrap();
            assert_eq!(decision, decoded);
        }
    }

    #[test]
    fn hook_input_optional_fields() {
        let json = r#"{"hook_event":"stop"}"#;
        let input: HookInput = serde_json::from_str(json).unwrap();
        assert_eq!(input.hook_event, HookEvent::Stop);
        assert!(input.tool_name.is_none());
        assert!(input.tool_input.is_none());
        assert!(input.tool_result.is_none());
    }

    #[tokio::test]
    async fn hook_timeout_defaults_to_config_value() {
        // A callback that takes a little while still completes under a generous
        // default timeout when the matcher sets none of its own.
        let cb: HookCallback = Arc::new(|_, _, _| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_millis(20)).await;
                HookOutput::block("should arrive")
            })
        });
        let matchers = vec![HookMatcher::new(HookEvent::PreToolUse, cb)];
        let req = pre_tool_request(Some("Bash"));
        let output = dispatch_hook(&req, &matchers, Duration::from_secs(30), None).await;
        assert_eq!(output.decision, HookDecision::Block);
    }

    #[tokio::test]
    async fn hook_timeout_override() {
        // The callback outlives the 1 ms default but finishes well within the
        // 60 s per-matcher override. `tokio::time::timeout` polls the inner
        // future first, so an immediately-ready callback would pass even if the
        // override were ignored; the sleep makes this test actually depend on it.
        let cb: HookCallback = Arc::new(|_, _, _| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_millis(50)).await;
                HookOutput::block("custom timeout")
            })
        });
        let matchers =
            vec![HookMatcher::new(HookEvent::PreToolUse, cb).with_timeout(Duration::from_secs(60))];
        let req = pre_tool_request(None);
        let output = dispatch_hook(&req, &matchers, Duration::from_millis(1), None).await;
        assert_eq!(output.decision, HookDecision::Block);
        assert_eq!(output.reason.as_deref(), Some("custom timeout"));
    }

    #[tokio::test]
    async fn hook_timeout_fires_returns_allow() {
        let cb: HookCallback = Arc::new(|_, _, _| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                HookOutput::block("never reached")
            })
        });
        let matchers = vec![HookMatcher::new(HookEvent::PreToolUse, cb)];
        let req = pre_tool_request(Some("Bash"));
        let output = dispatch_hook(&req, &matchers, Duration::from_millis(10), None).await;
        assert_eq!(output.decision, HookDecision::Allow);
    }

    #[tokio::test]
    async fn dispatch_without_matching_hook_allows() {
        let cb: HookCallback =
            Arc::new(|_, _, _| Box::pin(async { HookOutput::block("wrong event") }));
        let matchers = vec![HookMatcher::new(HookEvent::PostToolUse, cb)];
        let req = pre_tool_request(Some("bash"));
        let output = dispatch_hook(&req, &matchers, Duration::from_secs(30), None).await;
        assert_eq!(output, HookOutput::allow());
    }

    #[tokio::test]
    async fn dispatch_runs_only_first_matching_hook() {
        let wrong_event: HookCallback =
            Arc::new(|_, _, _| Box::pin(async { HookOutput::block("wrong event") }));
        let wrong_tool: HookCallback =
            Arc::new(|_, _, _| Box::pin(async { HookOutput::block("wrong tool") }));
        let first: HookCallback =
            Arc::new(|_, _, _| Box::pin(async { HookOutput::block("first match") }));
        let shadowed: HookCallback =
            Arc::new(|_, _, _| Box::pin(async { HookOutput::abort("shadowed") }));
        let matchers = vec![
            HookMatcher::new(HookEvent::PostToolUse, wrong_event),
            HookMatcher::new(HookEvent::PreToolUse, wrong_tool).for_tool("read_file"),
            HookMatcher::new(HookEvent::PreToolUse, first),
            HookMatcher::new(HookEvent::PreToolUse, shadowed),
        ];
        let req = pre_tool_request(Some("bash"));
        let output = dispatch_hook(&req, &matchers, Duration::from_secs(30), None).await;
        assert_eq!(output.decision, HookDecision::Block);
        assert_eq!(output.reason.as_deref(), Some("first match"));
    }

    #[tokio::test]
    async fn dispatch_forwards_input_and_session() {
        let cb: HookCallback = Arc::new(
            |input: HookInput, session: Option<String>, ctx: HookContext| {
                Box::pin(async move {
                    let forwarded = input.tool_use_id.as_deref() == Some("tu-9")
                        && input.extra.is_none()
                        && session.as_deref() == Some("sess-1")
                        && ctx.session_id.as_deref() == Some("sess-1");
                    if forwarded {
                        HookOutput::block("forwarded")
                    } else {
                        HookOutput::allow()
                    }
                })
            },
        );
        let matchers = vec![HookMatcher::new(HookEvent::PreToolUse, cb)];
        let req = HookRequest {
            tool_use_id: Some("tu-9".into()),
            ..pre_tool_request(Some("bash"))
        };
        let output =
            dispatch_hook(&req, &matchers, Duration::from_secs(30), Some("sess-1".into())).await;
        assert_eq!(output.decision, HookDecision::Block);
        assert_eq!(output.reason.as_deref(), Some("forwarded"));
    }
}