use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use typed_builder::TypedBuilder;
/// Names the hook events known to this module.
///
/// No serde renaming is applied, so each variant serializes as its bare name
/// (for example `"PreToolUse"`). The enum is `Copy`, `Eq` and `Hash`, which
/// lets it be used directly as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum HookEvent {
PreToolUse,
PostToolUse,
PostToolUseFailure,
UserPromptSubmit,
Stop,
SubagentStop,
PreCompact,
Notification,
SubagentStart,
PermissionRequest,
}
/// Groups hook callbacks with an optional matcher string and timeout.
///
/// Every field has a builder default, so `HookMatcher::builder().build()`
/// produces a value with no matcher, no hooks and no timeout.
#[derive(Clone, TypedBuilder)]
#[builder(doc)]
pub struct HookMatcher {
/// Optional matcher pattern; the builder setter accepts any `Into<String>`
/// and wraps it in `Some`.
#[builder(default, setter(into, strip_option))]
pub matcher: Option<String>,
/// Callbacks attached to this matcher; defaults to an empty list.
#[builder(default)]
pub hooks: Vec<HookCallback>,
/// Optional timeout. NOTE(review): the unit is not established anywhere in
/// this file — confirm against the consumer.
#[builder(default, setter(strip_option))]
pub timeout: Option<f64>,
}
/// Shared, thread-safe hook callback.
///
/// The callable receives the hook input, an optional string (named
/// `tool_use_id` at the call sites in this file) and a [`HookContext`], and
/// returns a boxed `'static` future that resolves to a [`HookJsonOutput`].
pub type HookCallback = Arc<
dyn Fn(HookInput, Option<String>, HookContext) -> BoxFuture<'static, HookJsonOutput>
+ Send
+ Sync,
>;
/// Plain function-pointer form of a hook: takes the hook input, an optional
/// string and a [`HookContext`], and returns a boxed `'static` future that
/// resolves to a [`HookJsonOutput`]. Being a `fn` pointer, it captures no state.
pub type HookFn = fn(HookInput, Option<String>, HookContext) -> BoxFuture<'static, HookJsonOutput>;
/// Input handed to a hook, with one variant per event payload.
///
/// Uses serde's internally tagged representation: the variant name is stored
/// in a `hook_event_name` field alongside the payload's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "hook_event_name", rename_all = "PascalCase")]
pub enum HookInput {
PreToolUse(PreToolUseHookInput),
PostToolUse(PostToolUseHookInput),
PostToolUseFailure(PostToolUseFailureHookInput),
UserPromptSubmit(UserPromptSubmitHookInput),
Stop(StopHookInput),
SubagentStop(SubagentStopHookInput),
PreCompact(PreCompactHookInput),
Notification(NotificationHookInput),
SubagentStart(SubagentStartHookInput),
PermissionRequest(PermissionRequestHookInput),
}
/// Payload carried by [`HookInput::PreToolUse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreToolUseHookInput {
pub session_id: String,
pub transcript_path: String,
pub cwd: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_mode: Option<String>,
pub tool_name: String,
/// Arbitrary JSON value; this type imposes no schema on it.
pub tool_input: serde_json::Value,
pub tool_use_id: String,
}
/// Payload carried by [`HookInput::PostToolUse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostToolUseHookInput {
pub session_id: String,
pub transcript_path: String,
pub cwd: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_mode: Option<String>,
pub tool_name: String,
/// Arbitrary JSON value; this type imposes no schema on it.
pub tool_input: serde_json::Value,
/// Arbitrary JSON value; this type imposes no schema on it.
pub tool_response: serde_json::Value,
pub tool_use_id: String,
}
/// Payload carried by [`HookInput::UserPromptSubmit`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserPromptSubmitHookInput {
pub session_id: String,
pub transcript_path: String,
pub cwd: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_mode: Option<String>,
pub prompt: String,
}
/// Payload carried by [`HookInput::Stop`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StopHookInput {
pub session_id: String,
pub transcript_path: String,
pub cwd: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_mode: Option<String>,
/// Required boolean. NOTE(review): its exact meaning is not established in
/// this file — confirm against the producer of these events.
pub stop_hook_active: bool,
}
/// Payload carried by [`HookInput::SubagentStop`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubagentStopHookInput {
pub session_id: String,
pub transcript_path: String,
pub cwd: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_mode: Option<String>,
/// Required boolean. NOTE(review): its exact meaning is not established in
/// this file — confirm against the producer of these events.
pub stop_hook_active: bool,
pub agent_id: String,
pub agent_transcript_path: String,
pub agent_type: String,
}
/// Payload carried by [`HookInput::PreCompact`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreCompactHookInput {
pub session_id: String,
pub transcript_path: String,
pub cwd: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_mode: Option<String>,
/// Free-form string; the tests in this file use the value `"manual"`.
pub trigger: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub custom_instructions: Option<String>,
}
/// Payload carried by [`HookInput::PostToolUseFailure`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostToolUseFailureHookInput {
pub session_id: String,
pub transcript_path: String,
pub cwd: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_mode: Option<String>,
pub tool_name: String,
/// Arbitrary JSON value; this type imposes no schema on it.
pub tool_input: serde_json::Value,
pub tool_use_id: String,
pub error: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub is_interrupt: Option<bool>,
}
/// Payload carried by [`HookInput::Notification`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NotificationHookInput {
pub session_id: String,
pub transcript_path: String,
pub cwd: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_mode: Option<String>,
pub message: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
/// Required free-form string; the tests in this file use the value `"info"`.
pub notification_type: String,
}
/// Payload carried by [`HookInput::SubagentStart`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubagentStartHookInput {
pub session_id: String,
pub transcript_path: String,
pub cwd: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_mode: Option<String>,
pub agent_id: String,
pub agent_type: String,
}
/// Payload carried by [`HookInput::PermissionRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PermissionRequestHookInput {
pub session_id: String,
pub transcript_path: String,
pub cwd: String,
/// Left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_mode: Option<String>,
pub tool_name: String,
/// Arbitrary JSON value; this type imposes no schema on it.
pub tool_input: serde_json::Value,
/// Optional list of arbitrary JSON values; left out of the serialized form
/// when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_suggestions: Option<Vec<serde_json::Value>>,
}
/// Context value passed as the third argument of every hook callback.
///
/// NOTE(review): `signal` is typed `Option<()>`, so it cannot carry any data;
/// it looks like a placeholder — confirm the intended use.
#[derive(Debug, Clone, Default)]
pub struct HookContext {
pub signal: Option<()>,
}
/// Value returned by a hook, in either its async or its sync form.
///
/// The enum is `untagged`, so no discriminator is written on serialization.
/// On deserialization serde tries the variants in declaration order, which
/// means `Async` is attempted before `Sync`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum HookJsonOutput {
Async(AsyncHookJsonOutput),
Sync(SyncHookJsonOutput),
}
/// The async form of a hook result, carried by [`HookJsonOutput::Async`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AsyncHookJsonOutput {
/// Serialized under the key `async`.
#[serde(rename = "async")]
pub async_: bool,
/// Serialized as `asyncTimeout`; left out when `None`.
/// NOTE(review): the unit is not stated in this file — confirm.
#[serde(skip_serializing_if = "Option::is_none", rename = "asyncTimeout")]
pub async_timeout: Option<u64>,
}
/// The default async output has the `async` flag set and carries no timeout.
///
/// This impl is written by hand because a derived `Default` would set
/// `async_` to `false`.
impl Default for AsyncHookJsonOutput {
    fn default() -> Self {
        let async_timeout = None;
        Self {
            async_: true,
            async_timeout,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(doc)]
pub struct SyncHookJsonOutput {
#[serde(skip_serializing_if = "Option::is_none", rename = "continue")]
#[builder(default, setter(strip_option))]
pub continue_: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none", rename = "suppressOutput")]
#[builder(default, setter(strip_option))]
pub suppress_output: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none", rename = "stopReason")]
#[builder(default, setter(into, strip_option))]
pub stop_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(into, strip_option))]
pub decision: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename = "systemMessage")]
#[builder(default, setter(into, strip_option))]
pub system_message: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(into, strip_option))]
pub reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename = "hookSpecificOutput")]
#[builder(default, setter(strip_option))]
pub hook_specific_output: Option<HookSpecificOutput>,
}
impl Default for SyncHookJsonOutput {
fn default() -> Self {
Self::builder().build()
}
}
/// Event-specific part of a hook result.
///
/// Uses serde's internally tagged representation: the variant name is stored
/// in a `hookEventName` field alongside the payload's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "hookEventName")]
pub enum HookSpecificOutput {
PreToolUse(PreToolUseHookSpecificOutput),
PostToolUse(PostToolUseHookSpecificOutput),
PostToolUseFailure(PostToolUseFailureHookSpecificOutput),
UserPromptSubmit(UserPromptSubmitHookSpecificOutput),
Notification(NotificationHookSpecificOutput),
SubagentStart(SubagentStartHookSpecificOutput),
PermissionRequest(PermissionRequestHookSpecificOutput),
SessionStart(SessionStartHookSpecificOutput),
}
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(doc)]
pub struct PreToolUseHookSpecificOutput {
#[serde(skip_serializing_if = "Option::is_none", rename = "permissionDecision")]
#[builder(default, setter(into, strip_option))]
pub permission_decision: Option<String>,
#[serde(
skip_serializing_if = "Option::is_none",
rename = "permissionDecisionReason"
)]
#[builder(default, setter(into, strip_option))]
pub permission_decision_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename = "updatedInput")]
#[builder(default, setter(strip_option))]
pub updated_input: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none", rename = "additionalContext")]
#[builder(default, setter(into, strip_option))]
pub additional_context: Option<String>,
}
impl Default for PreToolUseHookSpecificOutput {
fn default() -> Self {
Self::builder().build()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(doc)]
pub struct PostToolUseHookSpecificOutput {
#[serde(skip_serializing_if = "Option::is_none", rename = "additionalContext")]
#[builder(default, setter(into, strip_option))]
pub additional_context: Option<String>,
#[serde(
skip_serializing_if = "Option::is_none",
rename = "updatedMCPToolOutput"
)]
#[builder(default, setter(strip_option))]
pub updated_mcp_tool_output: Option<serde_json::Value>,
}
impl Default for PostToolUseHookSpecificOutput {
fn default() -> Self {
Self::builder().build()
}
}
/// Payload carried by [`HookSpecificOutput::UserPromptSubmit`].
///
/// `Default` is derived: the single field starts as `None`, matching what the
/// builder produces when no setter is called.
#[derive(Debug, Clone, Default, Serialize, Deserialize, TypedBuilder)]
#[builder(doc)]
pub struct UserPromptSubmitHookSpecificOutput {
    /// Serialized as `additionalContext`; left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "additionalContext")]
    #[builder(default, setter(into, strip_option))]
    pub additional_context: Option<String>,
}
/// Payload carried by [`HookSpecificOutput::PostToolUseFailure`].
///
/// `Default` is derived: the single field starts as `None`, matching what the
/// builder produces when no setter is called.
#[derive(Debug, Clone, Default, Serialize, Deserialize, TypedBuilder)]
#[builder(doc)]
pub struct PostToolUseFailureHookSpecificOutput {
    /// Serialized as `additionalContext`; left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "additionalContext")]
    #[builder(default, setter(into, strip_option))]
    pub additional_context: Option<String>,
}
/// Payload carried by [`HookSpecificOutput::Notification`].
///
/// `Default` is derived: the single field starts as `None`, matching what the
/// builder produces when no setter is called.
#[derive(Debug, Clone, Default, Serialize, Deserialize, TypedBuilder)]
#[builder(doc)]
pub struct NotificationHookSpecificOutput {
    /// Serialized as `additionalContext`; left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "additionalContext")]
    #[builder(default, setter(into, strip_option))]
    pub additional_context: Option<String>,
}
/// Payload carried by [`HookSpecificOutput::SubagentStart`].
///
/// `Default` is derived: the single field starts as `None`, matching what the
/// builder produces when no setter is called.
#[derive(Debug, Clone, Default, Serialize, Deserialize, TypedBuilder)]
#[builder(doc)]
pub struct SubagentStartHookSpecificOutput {
    /// Serialized as `additionalContext`; left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "additionalContext")]
    #[builder(default, setter(into, strip_option))]
    pub additional_context: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(doc)]
pub struct PermissionRequestHookSpecificOutput {
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub decision: Option<serde_json::Value>,
}
impl Default for PermissionRequestHookSpecificOutput {
fn default() -> Self {
Self::builder().build()
}
}
/// Payload carried by [`HookSpecificOutput::SessionStart`].
///
/// `Default` is derived: the single field starts as `None`, matching what the
/// builder produces when no setter is called.
#[derive(Debug, Clone, Default, Serialize, Deserialize, TypedBuilder)]
#[builder(doc)]
pub struct SessionStartHookSpecificOutput {
    /// Serialized as `additionalContext`; left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "additionalContext")]
    #[builder(default, setter(into, strip_option))]
    pub additional_context: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_hook_event_serialization() {
    // Each event must serialize to its bare variant name as a JSON string.
    let cases = [
        (HookEvent::PreToolUse, "\"PreToolUse\""),
        (HookEvent::PostToolUse, "\"PostToolUse\""),
        (HookEvent::UserPromptSubmit, "\"UserPromptSubmit\""),
        (HookEvent::Stop, "\"Stop\""),
        (HookEvent::SubagentStop, "\"SubagentStop\""),
        (HookEvent::PreCompact, "\"PreCompact\""),
    ];
    for &(event, expected) in &cases {
        assert_eq!(serde_json::to_string(&event).unwrap(), expected);
    }
}
#[test]
fn test_pretooluse_hook_input_deserialization() {
    // The `hook_event_name` tag selects the variant; the other keys fill the payload.
    let raw = r#"{
"hook_event_name": "PreToolUse",
"session_id": "test-session",
"transcript_path": "/path/to/transcript",
"cwd": "/working/dir",
"permission_mode": "default",
"tool_name": "Bash",
"tool_input": {"command": "echo hello"},
"tool_use_id": "tool_123"
}"#;
    let parsed: HookInput = serde_json::from_str(raw).unwrap();
    let payload = match parsed {
        HookInput::PreToolUse(inner) => inner,
        _ => panic!("Expected PreToolUse variant"),
    };
    assert_eq!(payload.session_id, "test-session");
    assert_eq!(payload.tool_name, "Bash");
    assert_eq!(payload.tool_input["command"], "echo hello");
    assert_eq!(payload.tool_use_id, "tool_123");
}
#[test]
fn test_posttooluse_hook_input_deserialization() {
    // `permission_mode` is absent here, so this also covers the optional field.
    let raw = r#"{
"hook_event_name": "PostToolUse",
"session_id": "test-session",
"transcript_path": "/path/to/transcript",
"cwd": "/working/dir",
"tool_name": "Bash",
"tool_input": {"command": "echo hello"},
"tool_response": "hello\n",
"tool_use_id": "tool_456"
}"#;
    let parsed: HookInput = serde_json::from_str(raw).unwrap();
    let payload = match parsed {
        HookInput::PostToolUse(inner) => inner,
        _ => panic!("Expected PostToolUse variant"),
    };
    assert_eq!(payload.session_id, "test-session");
    assert_eq!(payload.tool_name, "Bash");
    assert_eq!(payload.tool_response, "hello\n");
    assert_eq!(payload.tool_use_id, "tool_456");
}
#[test]
fn test_stop_hook_input_deserialization() {
    // A `Stop` tag must yield the `Stop` variant with its boolean flag intact.
    let raw = r#"{
"hook_event_name": "Stop",
"session_id": "test-session",
"transcript_path": "/path/to/transcript",
"cwd": "/working/dir",
"stop_hook_active": true
}"#;
    let parsed: HookInput = serde_json::from_str(raw).unwrap();
    let payload = match parsed {
        HookInput::Stop(inner) => inner,
        _ => panic!("Expected Stop variant"),
    };
    assert_eq!(payload.session_id, "test-session");
    assert!(payload.stop_hook_active);
}
#[test]
fn test_subagent_stop_hook_input_deserialization() {
    // Checks every agent-related field of the `SubagentStop` payload.
    let raw = r#"{
"hook_event_name": "SubagentStop",
"session_id": "test-session",
"transcript_path": "/path/to/transcript",
"cwd": "/working/dir",
"stop_hook_active": false,
"agent_id": "agent-1",
"agent_transcript_path": "/path/to/agent/transcript",
"agent_type": "code"
}"#;
    let parsed: HookInput = serde_json::from_str(raw).unwrap();
    let payload = match parsed {
        HookInput::SubagentStop(inner) => inner,
        _ => panic!("Expected SubagentStop variant"),
    };
    assert_eq!(payload.session_id, "test-session");
    assert!(!payload.stop_hook_active);
    assert_eq!(payload.agent_id, "agent-1");
    assert_eq!(payload.agent_transcript_path, "/path/to/agent/transcript");
    assert_eq!(payload.agent_type, "code");
}
#[test]
fn test_precompact_hook_input_deserialization() {
    // The optional `custom_instructions` key is present and must be kept.
    let raw = r#"{
"hook_event_name": "PreCompact",
"session_id": "test-session",
"transcript_path": "/path/to/transcript",
"cwd": "/working/dir",
"trigger": "manual",
"custom_instructions": "Keep important details"
}"#;
    let parsed: HookInput = serde_json::from_str(raw).unwrap();
    let payload = match parsed {
        HookInput::PreCompact(inner) => inner,
        _ => panic!("Expected PreCompact variant"),
    };
    assert_eq!(payload.session_id, "test-session");
    assert_eq!(payload.trigger, "manual");
    assert_eq!(
        payload.custom_instructions.as_deref(),
        Some("Keep important details")
    );
}
#[test]
fn test_sync_hook_output_serialization() {
    // Renamed fields must appear under their JSON names.
    let output = SyncHookJsonOutput::builder()
        .continue_(false)
        .stop_reason("Test stop")
        .build();
    let serialized = serde_json::to_value(&output).unwrap();
    assert_eq!(serialized["continue"], false);
    assert_eq!(serialized["stopReason"], "Test stop");
}
#[test]
fn test_hook_specific_output_pretooluse_serialization() {
    // The variant name goes into `hookEventName`, next to the payload fields.
    let payload = PreToolUseHookSpecificOutput::builder()
        .permission_decision("deny")
        .permission_decision_reason("Security policy")
        .build();
    let serialized = serde_json::to_value(&HookSpecificOutput::PreToolUse(payload)).unwrap();
    assert_eq!(serialized["hookEventName"], "PreToolUse");
    assert_eq!(serialized["permissionDecision"], "deny");
    assert_eq!(serialized["permissionDecisionReason"], "Security policy");
}
#[test]
fn test_hook_specific_output_posttooluse_serialization() {
    // `additional_context` must be written as `additionalContext`.
    let payload = PostToolUseHookSpecificOutput::builder()
        .additional_context("Error occurred")
        .build();
    let serialized = serde_json::to_value(&HookSpecificOutput::PostToolUse(payload)).unwrap();
    assert_eq!(serialized["hookEventName"], "PostToolUse");
    assert_eq!(serialized["additionalContext"], "Error occurred");
}
#[test]
fn test_hook_specific_output_userpromptsubmit_serialization() {
    // Same shape as the other variants: tag plus `additionalContext`.
    let payload = UserPromptSubmitHookSpecificOutput::builder()
        .additional_context("Custom context")
        .build();
    let serialized =
        serde_json::to_value(&HookSpecificOutput::UserPromptSubmit(payload)).unwrap();
    assert_eq!(serialized["hookEventName"], "UserPromptSubmit");
    assert_eq!(serialized["additionalContext"], "Custom context");
}
#[test]
fn test_complete_hook_output_with_pretooluse() {
    // A sync output nesting a hook-specific payload serializes both levels.
    let specific = PreToolUseHookSpecificOutput::builder()
        .permission_decision("allow")
        .permission_decision_reason("Approved")
        .updated_input(json!({"modified": true}))
        .build();
    let output = SyncHookJsonOutput::builder()
        .continue_(true)
        .hook_specific_output(HookSpecificOutput::PreToolUse(specific))
        .build();
    let serialized = serde_json::to_value(&output).unwrap();
    assert_eq!(serialized["continue"], true);
    let nested = &serialized["hookSpecificOutput"];
    assert_eq!(nested["hookEventName"], "PreToolUse");
    assert_eq!(nested["permissionDecision"], "allow");
}
#[test]
fn test_optional_fields_omitted() {
    // With every field `None`, nothing may be emitted: the result is `{}`.
    let serialized = serde_json::to_value(&SyncHookJsonOutput::default()).unwrap();
    assert_eq!(serialized, json!({}));
}
#[test]
fn test_async_hook_output_serialization() {
    // The default async output sets `async` and omits the timeout key entirely.
    let serialized = serde_json::to_value(&AsyncHookJsonOutput::default()).unwrap();
    assert_eq!(serialized["async"], true);
    assert_eq!(serialized.get("asyncTimeout"), None);
}
#[test]
fn test_async_hook_output_with_timeout() {
    // A populated timeout must be written under `asyncTimeout`.
    let with_timeout = AsyncHookJsonOutput {
        async_timeout: Some(5000),
        async_: true,
    };
    let serialized = serde_json::to_value(&with_timeout).unwrap();
    assert_eq!(serialized["async"], true);
    assert_eq!(serialized["asyncTimeout"], 5000);
}
#[test]
fn test_hooks_builder_new() {
    // A fresh builder with nothing registered builds to an empty collection.
    let registry = Hooks::new();
    let by_event = registry.build();
    assert!(by_event.is_empty());
}
#[test]
fn test_hooks_builder_add_pre_tool_use() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_pre_tool_use(noop_hook);
    let by_event = registry.build();

    assert_eq!(by_event.len(), 1);
    assert!(by_event.contains_key(&HookEvent::PreToolUse));
    let entries = &by_event[&HookEvent::PreToolUse];
    assert_eq!(entries.len(), 1);
    assert!(entries[0].matcher.is_none());
    assert_eq!(entries[0].hooks.len(), 1);
}
#[test]
fn test_hooks_builder_add_pre_tool_use_with_matcher() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_pre_tool_use_with_matcher("Bash", noop_hook);
    let by_event = registry.build();

    assert_eq!(by_event.len(), 1);
    assert!(by_event.contains_key(&HookEvent::PreToolUse));
    let entries = &by_event[&HookEvent::PreToolUse];
    assert_eq!(entries.len(), 1);
    assert_eq!(entries[0].matcher.as_deref(), Some("Bash"));
    assert_eq!(entries[0].hooks.len(), 1);
}
#[test]
fn test_hooks_builder_multiple_hooks_same_event() {
    // Two distinct hooks for one event: each gets its own entry, in call order.
    async fn first_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }
    async fn second_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_pre_tool_use(first_hook);
    registry.add_pre_tool_use_with_matcher("Bash", second_hook);
    let by_event = registry.build();

    assert_eq!(by_event.len(), 1);
    assert!(by_event.contains_key(&HookEvent::PreToolUse));
    let entries = &by_event[&HookEvent::PreToolUse];
    assert_eq!(entries.len(), 2);
    assert!(entries[0].matcher.is_none());
    assert_eq!(entries[1].matcher.as_deref(), Some("Bash"));
}
#[test]
fn test_hooks_builder_add_post_tool_use() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_post_tool_use(noop_hook);
    let by_event = registry.build();

    assert!(by_event.contains_key(&HookEvent::PostToolUse));
    assert!(by_event[&HookEvent::PostToolUse][0].matcher.is_none());
}
#[test]
fn test_hooks_builder_add_post_tool_use_with_matcher() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_post_tool_use_with_matcher("Write", noop_hook);
    let by_event = registry.build();

    assert!(by_event.contains_key(&HookEvent::PostToolUse));
    let first = &by_event[&HookEvent::PostToolUse][0];
    assert_eq!(first.matcher.as_deref(), Some("Write"));
}
#[test]
fn test_hooks_builder_add_user_prompt_submit() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_user_prompt_submit(noop_hook);
    let by_event = registry.build();

    assert!(by_event.contains_key(&HookEvent::UserPromptSubmit));
    assert!(by_event[&HookEvent::UserPromptSubmit][0].matcher.is_none());
}
#[test]
fn test_hooks_builder_add_stop() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_stop(noop_hook);
    let by_event = registry.build();

    assert!(by_event.contains_key(&HookEvent::Stop));
    assert!(by_event[&HookEvent::Stop][0].matcher.is_none());
}
#[test]
fn test_hooks_builder_add_subagent_stop() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_subagent_stop(noop_hook);
    let by_event = registry.build();

    assert!(by_event.contains_key(&HookEvent::SubagentStop));
    assert!(by_event[&HookEvent::SubagentStop][0].matcher.is_none());
}
#[test]
fn test_hooks_builder_add_pre_compact() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_pre_compact(noop_hook);
    let by_event = registry.build();

    assert!(by_event.contains_key(&HookEvent::PreCompact));
    assert!(by_event[&HookEvent::PreCompact][0].matcher.is_none());
}
#[test]
fn test_hooks_builder_multiple_event_types() {
    // The same hook is registered under four events; each must get its own key.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_pre_tool_use(noop_hook);
    registry.add_post_tool_use(noop_hook);
    registry.add_user_prompt_submit(noop_hook);
    registry.add_stop(noop_hook);
    let by_event = registry.build();

    assert_eq!(by_event.len(), 4);
    let expected_events = [
        HookEvent::PreToolUse,
        HookEvent::PostToolUse,
        HookEvent::UserPromptSubmit,
        HookEvent::Stop,
    ];
    for event in &expected_events {
        assert!(by_event.contains_key(event));
    }
}
#[tokio::test]
async fn test_hook_execution_returns_sync_output() {
async fn test_hook(
_input: HookInput,
_tool_use_id: Option<String>,
_context: HookContext,
) -> HookJsonOutput {
HookJsonOutput::Sync(SyncHookJsonOutput {
continue_: Some(true),
..Default::default()
})
}
let mut hooks = Hooks::new();
hooks.add_pre_tool_use(test_hook);
let built = hooks.build();
let hook_callback = &built[&HookEvent::PreToolUse][0].hooks[0];
let input = HookInput::PreToolUse(PreToolUseHookInput {
session_id: "test".to_string(),
transcript_path: "/tmp/test".to_string(),
cwd: "/tmp".to_string(),
permission_mode: None,
tool_name: "Bash".to_string(),
tool_input: serde_json::json!({"command": "ls"}),
tool_use_id: "tool_789".to_string(),
});
let result = hook_callback(input, None, HookContext::default()).await;
match result {
HookJsonOutput::Sync(output) => {
assert_eq!(output.continue_, Some(true));
}
_ => panic!("Expected sync output"),
}
}
#[tokio::test]
async fn test_hook_execution_returns_async_output() {
async fn test_hook(
_input: HookInput,
_tool_use_id: Option<String>,
_context: HookContext,
) -> HookJsonOutput {
HookJsonOutput::Async(AsyncHookJsonOutput {
async_: true,
async_timeout: Some(5000),
})
}
let mut hooks = Hooks::new();
hooks.add_pre_tool_use(test_hook);
let built = hooks.build();
let hook_callback = &built[&HookEvent::PreToolUse][0].hooks[0];
let input = HookInput::PreToolUse(PreToolUseHookInput {
session_id: "test".to_string(),
transcript_path: "/tmp/test".to_string(),
cwd: "/tmp".to_string(),
permission_mode: None,
tool_name: "Bash".to_string(),
tool_input: serde_json::json!({"command": "ls"}),
tool_use_id: "tool_789".to_string(),
});
let result = hook_callback(input, None, HookContext::default()).await;
match result {
HookJsonOutput::Async(output) => {
assert!(output.async_);
assert_eq!(output.async_timeout, Some(5000));
}
_ => panic!("Expected async output"),
}
}
#[test]
fn test_hooks_builder_matcher_accepts_string_types() {
    // The matcher argument must accept both `&str` and an owned `String`.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_pre_tool_use_with_matcher("Bash", noop_hook);
    registry.add_pre_tool_use_with_matcher("Write".to_string(), noop_hook);
    let by_event = registry.build();

    let entries = &by_event[&HookEvent::PreToolUse];
    assert_eq!(entries.len(), 2);
    assert_eq!(entries[0].matcher.as_deref(), Some("Bash"));
    assert_eq!(entries[1].matcher.as_deref(), Some("Write"));
}
#[test]
fn test_new_hook_event_serialization() {
    // Each event must serialize to its bare variant name as a JSON string.
    let cases = [
        (HookEvent::PostToolUseFailure, "\"PostToolUseFailure\""),
        (HookEvent::Notification, "\"Notification\""),
        (HookEvent::SubagentStart, "\"SubagentStart\""),
        (HookEvent::PermissionRequest, "\"PermissionRequest\""),
    ];
    for &(event, expected) in &cases {
        assert_eq!(serde_json::to_string(&event).unwrap(), expected);
    }
}
#[test]
fn test_post_tool_use_failure_hook_input_deserialization() {
    // Covers the required `error` field and the optional `is_interrupt` flag.
    let raw = r#"{
"hook_event_name": "PostToolUseFailure",
"session_id": "test-session",
"transcript_path": "/path/to/transcript",
"cwd": "/working/dir",
"tool_name": "Bash",
"tool_input": {"command": "invalid"},
"tool_use_id": "tool_fail_1",
"error": "Command not found",
"is_interrupt": false
}"#;
    let parsed: HookInput = serde_json::from_str(raw).unwrap();
    let payload = match parsed {
        HookInput::PostToolUseFailure(inner) => inner,
        _ => panic!("Expected PostToolUseFailure variant"),
    };
    assert_eq!(payload.session_id, "test-session");
    assert_eq!(payload.tool_name, "Bash");
    assert_eq!(payload.tool_use_id, "tool_fail_1");
    assert_eq!(payload.error, "Command not found");
    assert_eq!(payload.is_interrupt, Some(false));
}
#[test]
fn test_notification_hook_input_deserialization() {
    // Covers the optional `title` alongside the required notification fields.
    let raw = r#"{
"hook_event_name": "Notification",
"session_id": "test-session",
"transcript_path": "/path/to/transcript",
"cwd": "/working/dir",
"message": "Task completed",
"title": "Done",
"notification_type": "info"
}"#;
    let parsed: HookInput = serde_json::from_str(raw).unwrap();
    let payload = match parsed {
        HookInput::Notification(inner) => inner,
        _ => panic!("Expected Notification variant"),
    };
    assert_eq!(payload.session_id, "test-session");
    assert_eq!(payload.message, "Task completed");
    assert_eq!(payload.title.as_deref(), Some("Done"));
    assert_eq!(payload.notification_type, "info");
}
#[test]
fn test_subagent_start_hook_input_deserialization() {
    // A `SubagentStart` tag must yield the matching variant with both agent fields.
    let raw = r#"{
"hook_event_name": "SubagentStart",
"session_id": "test-session",
"transcript_path": "/path/to/transcript",
"cwd": "/working/dir",
"agent_id": "agent-42",
"agent_type": "code"
}"#;
    let parsed: HookInput = serde_json::from_str(raw).unwrap();
    let payload = match parsed {
        HookInput::SubagentStart(inner) => inner,
        _ => panic!("Expected SubagentStart variant"),
    };
    assert_eq!(payload.session_id, "test-session");
    assert_eq!(payload.agent_id, "agent-42");
    assert_eq!(payload.agent_type, "code");
}
#[test]
fn test_permission_request_hook_input_deserialization() {
    // The optional suggestions list is supplied and must be retained.
    let raw = r#"{
"hook_event_name": "PermissionRequest",
"session_id": "test-session",
"transcript_path": "/path/to/transcript",
"cwd": "/working/dir",
"tool_name": "Write",
"tool_input": {"path": "/etc/hosts"},
"permission_suggestions": [{"type": "allow"}]
}"#;
    let parsed: HookInput = serde_json::from_str(raw).unwrap();
    let payload = match parsed {
        HookInput::PermissionRequest(inner) => inner,
        _ => panic!("Expected PermissionRequest variant"),
    };
    assert_eq!(payload.session_id, "test-session");
    assert_eq!(payload.tool_name, "Write");
    assert!(payload.permission_suggestions.is_some());
}
#[test]
fn test_new_hook_specific_outputs_serialization() {
    // Each variant is serialized on its own and checked for tag plus payload.
    let failure = HookSpecificOutput::PostToolUseFailure(
        PostToolUseFailureHookSpecificOutput::builder()
            .additional_context("Retry suggested")
            .build(),
    );
    let failure_json = serde_json::to_value(&failure).unwrap();
    assert_eq!(failure_json["hookEventName"], "PostToolUseFailure");
    assert_eq!(failure_json["additionalContext"], "Retry suggested");

    let notification = HookSpecificOutput::Notification(
        NotificationHookSpecificOutput::builder()
            .additional_context("Acknowledged")
            .build(),
    );
    let notification_json = serde_json::to_value(&notification).unwrap();
    assert_eq!(notification_json["hookEventName"], "Notification");
    assert_eq!(notification_json["additionalContext"], "Acknowledged");

    // No context is set for this variant, so only the tag is checked.
    let subagent_start =
        HookSpecificOutput::SubagentStart(SubagentStartHookSpecificOutput::builder().build());
    let subagent_start_json = serde_json::to_value(&subagent_start).unwrap();
    assert_eq!(subagent_start_json["hookEventName"], "SubagentStart");

    let permission = HookSpecificOutput::PermissionRequest(
        PermissionRequestHookSpecificOutput::builder()
            .decision(json!({"allow": true}))
            .build(),
    );
    let permission_json = serde_json::to_value(&permission).unwrap();
    assert_eq!(permission_json["hookEventName"], "PermissionRequest");
    assert_eq!(permission_json["decision"]["allow"], true);
}
#[test]
fn test_pretooluse_additional_context_serialization() {
    // `additional_context` on the pre-tool-use payload maps to `additionalContext`.
    let payload = PreToolUseHookSpecificOutput::builder()
        .permission_decision("allow")
        .additional_context("Extra info")
        .build();
    let serialized = serde_json::to_value(&HookSpecificOutput::PreToolUse(payload)).unwrap();
    assert_eq!(serialized["additionalContext"], "Extra info");
}
#[test]
fn test_posttooluse_updated_mcp_tool_output_serialization() {
    // The field keeps its irregular JSON name `updatedMCPToolOutput`.
    let payload = PostToolUseHookSpecificOutput::builder()
        .updated_mcp_tool_output(json!({"modified": true}))
        .build();
    let serialized = serde_json::to_value(&HookSpecificOutput::PostToolUse(payload)).unwrap();
    assert_eq!(serialized["updatedMCPToolOutput"]["modified"], true);
}
#[test]
fn test_hooks_builder_add_post_tool_use_failure() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_post_tool_use_failure(noop_hook);
    let by_event = registry.build();

    assert!(by_event.contains_key(&HookEvent::PostToolUseFailure));
    assert!(by_event[&HookEvent::PostToolUseFailure][0].matcher.is_none());
}
#[test]
fn test_hooks_builder_add_post_tool_use_failure_with_matcher() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_post_tool_use_failure_with_matcher("Bash", noop_hook);
    let by_event = registry.build();

    assert!(by_event.contains_key(&HookEvent::PostToolUseFailure));
    let first = &by_event[&HookEvent::PostToolUseFailure][0];
    assert_eq!(first.matcher.as_deref(), Some("Bash"));
}
#[test]
fn test_hooks_builder_add_notification() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_notification(noop_hook);
    let by_event = registry.build();

    assert!(by_event.contains_key(&HookEvent::Notification));
    assert!(by_event[&HookEvent::Notification][0].matcher.is_none());
}
#[test]
fn test_hooks_builder_add_subagent_start() {
    // Hook body is irrelevant here; only the registration result is checked.
    async fn noop_hook(
        _hook_input: HookInput,
        _id: Option<String>,
        _ctx: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut registry = Hooks::new();
    registry.add_subagent_start(noop_hook);
    let by_event = registry.build();

    assert!(by_event.contains_key(&HookEvent::SubagentStart));
    assert!(by_event[&HookEvent::SubagentStart][0].matcher.is_none());
}
#[test]
fn test_hooks_builder_add_permission_request() {
    // Minimal hook: ignores its arguments and returns the default sync output.
    async fn noop_hook(
        _input: HookInput,
        _tool_use_id: Option<String>,
        _context: HookContext,
    ) -> HookJsonOutput {
        HookJsonOutput::Sync(SyncHookJsonOutput::default())
    }

    let mut builder = Hooks::new();
    builder.add_permission_request(noop_hook);
    let registered = builder.build();

    // PermissionRequest hooks are registered under their own event, with no matcher.
    assert!(registered.contains_key(&HookEvent::PermissionRequest));
    let first = &registered[&HookEvent::PermissionRequest][0];
    assert!(first.matcher.is_none());
}
}
// Generates the public `add_*` hook-registration methods.
//
// The expansion produces methods taking `&mut self` that call
// `self.add_hook(event, matcher, callback)`, so the macro must be invoked
// inside an `impl` block of a type that provides `add_hook`.
//
// * Every entry under `with_matcher` yields two methods: `<method_name>`,
//   which registers the hook with no matcher, and
//   `<method_name>_with_matcher`, which registers it with `Some(matcher)`.
// * Every entry under `without_matcher` yields only `<method_name>`.
//
// Each generated method wraps the caller's function in a closure that pins
// and boxes the returned future into a `BoxFuture<'static, HookJsonOutput>`
// before handing it to `add_hook`.
macro_rules! generate_hook_methods {
    // Public entry point: fan each listed event out to the matching internal rule.
    (
        with_matcher: {
            $($event_m:ident => $method_name_m:ident: $doc_m:expr),* $(,)?
        },
        without_matcher: {
            $($event:ident => $method_name:ident: $doc:expr),* $(,)?
        } $(,)?
    ) => {
        $(
            generate_hook_methods!(@with_matcher $event_m, $method_name_m, $doc_m);
        )*
        $(
            generate_hook_methods!(@no_matcher $event, $method_name, $doc);
        )*
    };
    // Internal: events that support a matcher get the plain method plus a
    // `_with_matcher` sibling.
    (@with_matcher $event:ident, $method_name:ident, $doc:expr) => {
        // The matcher-less method is identical to the one produced for
        // matcher-less events, so reuse that rule instead of duplicating it.
        generate_hook_methods!(@no_matcher $event, $method_name, $doc);

        paste::paste! {
            #[doc = $doc]
            #[doc = ""]
            #[doc = "This variant registers the hook together with a matcher pattern."]
            #[doc = ""]
            #[doc = "# Arguments"]
            #[doc = "* `matcher` - Tool name to match (e.g., \"Bash\", \"Write\")"]
            #[doc = "* `hook_fn` - The hook function to call"]
            pub fn [<$method_name _with_matcher>]<F, Fut>(&mut self, matcher: impl Into<String>, hook_fn: F)
            where
                F: Fn(HookInput, Option<String>, HookContext) -> Fut + Send + Sync + 'static,
                Fut: std::future::Future<Output = HookJsonOutput> + Send + 'static,
            {
                // Box the caller's future so the callback has the uniform
                // signature `add_hook` expects.
                let wrapper = move |input: HookInput, tool_use_id: Option<String>, context: HookContext| {
                    Box::pin(hook_fn(input, tool_use_id, context)) as BoxFuture<'static, HookJsonOutput>
                };
                self.add_hook(HookEvent::$event, Some(matcher), wrapper);
            }
        }
    };
    // Internal: a single method that registers the hook with no matcher.
    (@no_matcher $event:ident, $method_name:ident, $doc:expr) => {
        #[doc = $doc]
        pub fn $method_name<F, Fut>(&mut self, hook_fn: F)
        where
            F: Fn(HookInput, Option<String>, HookContext) -> Fut + Send + Sync + 'static,
            Fut: std::future::Future<Output = HookJsonOutput> + Send + 'static,
        {
            // Box the caller's future so the callback has the uniform
            // signature `add_hook` expects.
            let wrapper = move |input: HookInput, tool_use_id: Option<String>, context: HookContext| {
                Box::pin(hook_fn(input, tool_use_id, context)) as BoxFuture<'static, HookJsonOutput>
            };
            // `None::<String>` pins the `impl Into<String>` type parameter,
            // which cannot be inferred from a bare `None`.
            self.add_hook(HookEvent::$event, None::<String>, wrapper);
        }
    };
}
/// Builder that accumulates hook registrations, grouped by [`HookEvent`].
///
/// Create one with [`Hooks::new`], register callbacks through the `add_*`
/// methods, then call [`Hooks::build`] to obtain the resulting map.
#[derive(Default)]
pub struct Hooks {
// Registrations collected so far; each registration appends one
// `HookMatcher` to the list stored under its event.
hooks: HashMap<HookEvent, Vec<HookMatcher>>,
}
impl Hooks {
pub fn new() -> Self {
Self::default()
}
pub fn build(self) -> HashMap<HookEvent, Vec<HookMatcher>> {
self.hooks
}
fn add_hook<F>(&mut self, event: HookEvent, matcher: Option<impl Into<String>>, hook_fn: F)
where
F: Fn(HookInput, Option<String>, HookContext) -> BoxFuture<'static, HookJsonOutput>
+ Send
+ Sync
+ 'static,
{
let matcher_string = matcher.map(|m| m.into());
let hook_callback = Arc::new(hook_fn);
self.hooks.entry(event).or_default().push(HookMatcher {
matcher: matcher_string,
hooks: vec![hook_callback],
timeout: None,
});
}
generate_hook_methods! {
with_matcher: {
PreToolUse => add_pre_tool_use: "Add a PreToolUse hook that fires before tool execution.",
PostToolUse => add_post_tool_use: "Add a PostToolUse hook that fires after tool execution.",
PostToolUseFailure => add_post_tool_use_failure: "Add a PostToolUseFailure hook that fires when a tool execution fails.",
},
without_matcher: {
UserPromptSubmit => add_user_prompt_submit: "Add a UserPromptSubmit hook that fires when user submits a prompt.",
Stop => add_stop: "Add a Stop hook that fires when execution stops.",
SubagentStop => add_subagent_stop: "Add a SubagentStop hook that fires when a subagent stops.",
PreCompact => add_pre_compact: "Add a PreCompact hook that fires before conversation compaction.",
Notification => add_notification: "Add a Notification hook that fires for notification events.",
SubagentStart => add_subagent_start: "Add a SubagentStart hook that fires when a subagent starts.",
PermissionRequest => add_permission_request: "Add a PermissionRequest hook that fires for permission request events.",
},
}
}