use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::tool::ToolOutput;
use crate::types::{SessionId, ToolCallId};
/// Identifies the kind of lifecycle event a hook can react to.
///
/// With `rename_all = "snake_case"` serde encodes each variant as a
/// snake_case string (for example `PreToolUse` <-> `"pre_tool_use"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookEvent {
    /// Tool event; its payload is `HookInput::PreToolUse`.
    PreToolUse,
    /// Tool event; its payload is `HookInput::PostToolUse`.
    PostToolUse,
    /// Tool event; its payload is `HookInput::PostToolFailure`.
    PostToolFailure,
    /// Payload is `HookInput::UserPromptSubmit`.
    UserPromptSubmit,
    /// Payload is `HookInput::SessionStart`.
    SessionStart,
    /// Payload is `HookInput::SessionEnd`.
    SessionEnd,
    /// Payload is `HookInput::Stop`.
    Stop,
    /// Payload is `HookInput::SubAgentStart`.
    SubAgentStart,
    /// Payload is `HookInput::PreCompact`.
    PreCompact,
    /// Payload is `HookInput::PostCompact`.
    PostCompact,
}
/// Every [`HookEvent`] variant, listed in declaration order.
///
/// Must be extended by hand whenever a variant is added to [`HookEvent`].
pub const ALL_HOOK_EVENTS: &[HookEvent] = &[
    // Tool-related events.
    HookEvent::PreToolUse,
    HookEvent::PostToolUse,
    HookEvent::PostToolFailure,
    // Remaining (non-tool) events.
    HookEvent::UserPromptSubmit,
    HookEvent::SessionStart,
    HookEvent::SessionEnd,
    HookEvent::Stop,
    HookEvent::SubAgentStart,
    HookEvent::PreCompact,
    HookEvent::PostCompact,
];
impl HookEvent {
    /// Returns `true` for `PreToolUse`, `PostToolUse` and `PostToolFailure`,
    /// and `false` for every other event.
    pub fn is_tool_event(&self) -> bool {
        match self {
            Self::PreToolUse | Self::PostToolUse | Self::PostToolFailure => true,
            Self::UserPromptSubmit
            | Self::SessionStart
            | Self::SessionEnd
            | Self::Stop
            | Self::SubAgentStart
            | Self::PreCompact
            | Self::PostCompact => false,
        }
    }
}
/// Formats the event as its snake_case name (e.g. `"pre_tool_use"`).
impl std::fmt::Display for HookEvent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The label is written verbatim; formatter width/padding flags are not applied.
        f.write_str(match self {
            Self::PreToolUse => "pre_tool_use",
            Self::PostToolUse => "post_tool_use",
            Self::PostToolFailure => "post_tool_failure",
            Self::UserPromptSubmit => "user_prompt_submit",
            Self::SessionStart => "session_start",
            Self::SessionEnd => "session_end",
            Self::Stop => "stop",
            Self::SubAgentStart => "sub_agent_start",
            Self::PreCompact => "pre_compact",
            Self::PostCompact => "post_compact",
        })
    }
}
/// Payload handed to a hook, with one variant per [`HookEvent`].
///
/// Serialized as an internally tagged object: the variant name is stored in
/// snake_case under the `"hook_event"` key, next to the variant's fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "hook_event", rename_all = "snake_case")]
pub enum HookInput {
    /// Payload for [`HookEvent::PreToolUse`].
    PreToolUse {
        /// Tool name; this is what hook matcher patterns are tested against.
        tool_name: String,
        /// JSON input associated with the tool call.
        tool_input: serde_json::Value,
        /// Identifier of the tool call.
        call_id: ToolCallId,
    },
    /// Payload for [`HookEvent::PostToolUse`].
    PostToolUse {
        /// Tool name; this is what hook matcher patterns are tested against.
        tool_name: String,
        /// JSON input associated with the tool call.
        tool_input: serde_json::Value,
        /// Output associated with the tool call.
        tool_output: ToolOutput,
        /// Identifier of the tool call.
        call_id: ToolCallId,
    },
    /// Payload for [`HookEvent::PostToolFailure`].
    PostToolFailure {
        /// Tool name; this is what hook matcher patterns are tested against.
        tool_name: String,
        /// JSON input associated with the tool call.
        tool_input: serde_json::Value,
        /// Error text describing the failure.
        error: String,
        /// Identifier of the tool call.
        call_id: ToolCallId,
    },
    /// Payload for [`HookEvent::UserPromptSubmit`].
    UserPromptSubmit {
        /// The prompt text.
        prompt: String,
    },
    /// Payload for [`HookEvent::SessionStart`].
    SessionStart {
        /// Identifier of the session.
        session_id: SessionId,
    },
    /// Payload for [`HookEvent::SessionEnd`].
    SessionEnd {
        /// Identifier of the session.
        session_id: SessionId,
        /// Free-form reason string; the set of values is not defined in this file.
        reason: String,
    },
    /// Payload for [`HookEvent::Stop`].
    Stop {
        /// Free-form finish reason; the set of values is not defined in this file.
        finish_reason: String,
    },
    /// Payload for [`HookEvent::SubAgentStart`].
    SubAgentStart {
        /// Name of the agent.
        agent_name: String,
    },
    /// Payload for [`HookEvent::PreCompact`].
    PreCompact {
        /// Free-form trigger string; the set of values is not defined in this file.
        trigger: String,
        /// Token count before compaction (NOTE(review): what is counted is
        /// decided by the caller; confirm there).
        tokens_before: u64,
    },
    /// Payload for [`HookEvent::PostCompact`].
    PostCompact {
        /// Free-form trigger string; the set of values is not defined in this file.
        trigger: String,
        /// Token count after compaction (NOTE(review): what is counted is
        /// decided by the caller; confirm there).
        tokens_after: u64,
    },
}
impl HookInput {
    /// Maps this payload to the [`HookEvent`] variant of the same name.
    pub fn event(&self) -> HookEvent {
        match self {
            // Tool-related payloads.
            Self::PreToolUse { .. } => HookEvent::PreToolUse,
            Self::PostToolUse { .. } => HookEvent::PostToolUse,
            Self::PostToolFailure { .. } => HookEvent::PostToolFailure,
            // Everything else.
            Self::UserPromptSubmit { .. } => HookEvent::UserPromptSubmit,
            Self::SessionStart { .. } => HookEvent::SessionStart,
            Self::SessionEnd { .. } => HookEvent::SessionEnd,
            Self::Stop { .. } => HookEvent::Stop,
            Self::SubAgentStart { .. } => HookEvent::SubAgentStart,
            Self::PreCompact { .. } => HookEvent::PreCompact,
            Self::PostCompact { .. } => HookEvent::PostCompact,
        }
    }

    /// Returns the tool name for the three tool payloads and `None` for all
    /// other variants.
    pub fn tool_name(&self) -> Option<&str> {
        if let Self::PreToolUse { tool_name, .. }
        | Self::PostToolUse { tool_name, .. }
        | Self::PostToolFailure { tool_name, .. } = self
        {
            Some(tool_name.as_str())
        } else {
            None
        }
    }

    /// Returns the tool call id for the three tool payloads and `None` for
    /// all other variants.
    pub fn call_id(&self) -> Option<&ToolCallId> {
        if let Self::PreToolUse { call_id, .. }
        | Self::PostToolUse { call_id, .. }
        | Self::PostToolFailure { call_id, .. } = self
        {
            Some(call_id)
        } else {
            None
        }
    }
}
/// Permission verdict a hook can attach to its output.
///
/// Serialized as an internally tagged object with the snake_case variant name
/// under `"type"`; the optional text fields are omitted when `None` and
/// default to `None` when absent on input.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum HookPermission {
    /// Least strict verdict.
    Allow,
    /// Strictest verdict.
    Deny {
        /// Optional explanation for the denial.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reason: Option<String>,
    },
    /// Verdict ranked between `Allow` and `Deny`.
    Ask {
        /// Optional message accompanying the request.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        message: Option<String>,
    },
}
impl HookPermission {
    /// Builds a `Deny` without a reason.
    pub fn deny() -> Self {
        Self::Deny { reason: None }
    }

    /// Builds a `Deny` carrying `reason`.
    pub fn deny_with_reason(reason: impl Into<String>) -> Self {
        let reason = Some(reason.into());
        Self::Deny { reason }
    }

    /// Builds an `Ask` without a message.
    pub fn ask() -> Self {
        Self::Ask { message: None }
    }

    /// Builds an `Ask` carrying `message`.
    pub fn ask_with_message(message: impl Into<String>) -> Self {
        let message = Some(message.into());
        Self::Ask { message }
    }

    /// `true` only for the `Allow` variant.
    pub fn is_allow(&self) -> bool {
        match self {
            Self::Allow => true,
            Self::Deny { .. } | Self::Ask { .. } => false,
        }
    }

    /// `true` only for the `Deny` variant, with or without a reason.
    pub fn is_deny(&self) -> bool {
        match self {
            Self::Deny { .. } => true,
            Self::Allow | Self::Ask { .. } => false,
        }
    }

    /// `true` only for the `Ask` variant, with or without a message.
    pub fn is_ask(&self) -> bool {
        match self {
            Self::Ask { .. } => true,
            Self::Allow | Self::Deny { .. } => false,
        }
    }

    /// Numeric rank used to compare verdicts: `Allow` (0) < `Ask` (1) < `Deny` (2).
    fn strictness(&self) -> u8 {
        match self {
            Self::Deny { .. } => 2,
            Self::Ask { .. } => 1,
            Self::Allow => 0,
        }
    }
}
/// Result returned by a single hook invocation.
///
/// Every field is optional on the wire: absent fields deserialize to their
/// default, and `None` / empty values are skipped when serializing
/// (`prevent_continuation` is always written).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HookOutput {
    /// Permission verdict, if the hook issued one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub permission: Option<HookPermission>,
    /// Replacement JSON input, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub updated_input: Option<serde_json::Value>,
    /// Replacement tool output, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub updated_output: Option<ToolOutput>,
    /// Extra context strings contributed by the hook.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub additional_context: Vec<String>,
    /// Set to `true` by `with_stop`.
    #[serde(default)]
    pub prevent_continuation: bool,
    /// Reason recorded by `with_stop`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    /// Blocking error text, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub blocking_error: Option<String>,
    /// System message, if any; does not count as a decision in `has_decision`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system_message: Option<String>,
}
impl HookOutput {
pub fn passthrough() -> Self {
Self::default()
}
pub fn allow() -> Self {
Self {
permission: Some(HookPermission::Allow),
..Default::default()
}
}
pub fn deny(reason: impl Into<String>) -> Self {
Self {
permission: Some(HookPermission::deny_with_reason(reason)),
..Default::default()
}
}
pub fn ask(message: impl Into<String>) -> Self {
Self {
permission: Some(HookPermission::ask_with_message(message)),
..Default::default()
}
}
pub fn with_updated_input(mut self, input: serde_json::Value) -> Self {
self.updated_input = Some(input);
self
}
pub fn with_updated_output(mut self, output: ToolOutput) -> Self {
self.updated_output = Some(output);
self
}
pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
self.additional_context.push(ctx.into());
self
}
pub fn with_stop(mut self, reason: impl Into<String>) -> Self {
self.prevent_continuation = true;
self.stop_reason = Some(reason.into());
self
}
pub fn with_blocking_error(mut self, error: impl Into<String>) -> Self {
self.blocking_error = Some(error.into());
self
}
pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
self.system_message = Some(message.into());
self
}
pub fn has_decision(&self) -> bool {
self.permission.is_some()
|| self.updated_input.is_some()
|| self.updated_output.is_some()
|| !self.additional_context.is_empty()
|| self.prevent_continuation
|| self.blocking_error.is_some()
}
}
/// A named, asynchronous callback that reacts to hook events.
///
/// Implementations must be `Send + Sync`; [`HookRegistry`] stores them as
/// `Arc<dyn Hook>`.
#[async_trait]
pub trait Hook: Send + Sync {
    /// Name of the hook; [`HookRegistry::remove`] deletes hooks by this name.
    fn name(&self) -> &str;

    /// Events this hook subscribes to.
    ///
    /// The default is an empty slice, which [`HookRegistry`] treats as
    /// "subscribed to every event".
    fn events(&self) -> &[HookEvent] {
        &[]
    }

    /// Optional tool-name pattern in the syntax accepted by [`matches_pattern`].
    ///
    /// The default `None` applies no tool-name filter. When a pattern is
    /// present, [`HookRegistry::matching`] skips the hook for inputs that
    /// carry no tool name.
    fn matcher(&self) -> Option<&str> {
        None
    }

    /// Handles `input` and returns the hook's output.
    async fn on_event(&self, input: &HookInput) -> HookOutput;
}
/// Origin label stored with each hook in [`RegisteredHook`].
///
/// Serialized as an internally tagged object with the snake_case variant name
/// under `"type"`. This file only stores the label; it attaches no behavior
/// to the individual variants.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum HookSource {
    Settings,
    Project,
    Plugin {
        /// Name associated with the plugin (presumably the plugin's own
        /// identifier; confirm against the code that registers plugin hooks).
        name: String,
    },
    Programmatic,
    Session,
}
/// Collection of registered hooks.
pub struct HookRegistry {
    /// Registered hooks; `register` re-sorts this by ascending priority after
    /// every insertion.
    hooks: Vec<RegisteredHook>,
}
/// A hook together with the metadata supplied at registration time.
pub struct RegisteredHook {
    /// The hook implementation.
    pub hook: Arc<dyn Hook>,
    /// Origin label given when the hook was registered.
    pub source: HookSource,
    /// Ordering key inside [`HookRegistry`]: lower values come first.
    pub priority: i32,
}
impl HookRegistry {
    /// Creates a registry with no hooks.
    pub fn new() -> Self {
        Self { hooks: Vec::new() }
    }

    /// Adds `hook` with the given `source` and `priority`.
    ///
    /// The list is re-sorted by ascending priority with a stable sort, so
    /// hooks sharing a priority stay in registration order.
    pub fn register(&mut self, hook: Arc<dyn Hook>, source: HookSource, priority: i32) {
        let entry = RegisteredHook {
            hook,
            source,
            priority,
        };
        self.hooks.push(entry);
        self.hooks.sort_by_key(|registered| registered.priority);
    }

    /// Removes every hook whose `name()` equals `name`.
    pub fn remove(&mut self, name: &str) {
        self.hooks.retain(|registered| registered.hook.name() != name);
    }

    /// Number of registered hooks.
    pub fn len(&self) -> usize {
        self.hooks.len()
    }

    /// `true` when no hooks are registered.
    pub fn is_empty(&self) -> bool {
        self.hooks.is_empty()
    }

    /// Returns, in priority order, the hooks that apply to `input`.
    ///
    /// A hook applies when it subscribes to the input's event (an empty
    /// `events()` list subscribes to everything) and either has no matcher or
    /// its matcher pattern matches the input's tool name. A hook with a
    /// matcher never applies to an input that has no tool name.
    pub fn matching(&self, input: &HookInput) -> Vec<&RegisteredHook> {
        let event = input.event();
        let tool_name = input.tool_name();

        let mut selected = Vec::new();
        for registered in &self.hooks {
            if !Self::subscribes_to(registered, event) {
                continue;
            }
            let accepted = match registered.hook.matcher() {
                None => true,
                Some(pattern) => match tool_name {
                    Some(name) => matches_pattern(name, pattern),
                    None => false,
                },
            };
            if accepted {
                selected.push(registered);
            }
        }
        selected
    }

    /// Returns `true` when at least one hook subscribes to `event`.
    ///
    /// Matcher patterns are not consulted here.
    pub fn has_hooks_for(&self, event: HookEvent) -> bool {
        self.hooks
            .iter()
            .any(|registered| Self::subscribes_to(registered, event))
    }

    /// Event-subscription test shared by `matching` and `has_hooks_for`.
    fn subscribes_to(registered: &RegisteredHook, event: HookEvent) -> bool {
        let events = registered.hook.events();
        events.is_empty() || events.contains(&event)
    }
}
/// The default registry is empty, exactly like [`HookRegistry::new`].
impl Default for HookRegistry {
    fn default() -> Self {
        Self { hooks: Vec::new() }
    }
}
/// Combined result of several [`HookOutput`] values, built up with `merge`.
#[derive(Debug, Clone, Default)]
pub struct AggregatedHookOutput {
    /// Strictest permission merged so far.
    pub permission: Option<HookPermission>,
    /// Most recently merged replacement input.
    pub updated_input: Option<serde_json::Value>,
    /// Most recently merged replacement output.
    pub updated_output: Option<ToolOutput>,
    /// Context strings from all merged outputs, in merge order.
    pub additional_context: Vec<String>,
    /// `true` once any merged output requested a stop.
    pub prevent_continuation: bool,
    /// First stop reason recorded among outputs that requested a stop.
    pub stop_reason: Option<String>,
    /// Blocking errors from all merged outputs, each prefixed with `[hook_name]`.
    pub blocking_errors: Vec<String>,
    /// System messages from all merged outputs, in merge order.
    pub system_messages: Vec<String>,
}
impl AggregatedHookOutput {
pub fn merge(&mut self, output: HookOutput, hook_name: &str) {
if let Some(ref new_perm) = output.permission {
match &self.permission {
Some(existing) if existing.strictness() >= new_perm.strictness() => {
}
_ => {
self.permission = output.permission.clone();
}
}
}
if output.updated_input.is_some() {
self.updated_input = output.updated_input;
}
if output.updated_output.is_some() {
self.updated_output = output.updated_output;
}
self.additional_context.extend(output.additional_context);
if output.prevent_continuation {
self.prevent_continuation = true;
if self.stop_reason.is_none() {
self.stop_reason = output.stop_reason;
}
}
if let Some(err) = output.blocking_error {
self.blocking_errors.push(format!("[{hook_name}] {err}"));
}
if let Some(msg) = output.system_message {
self.system_messages.push(msg);
}
}
pub fn has_decision(&self) -> bool {
self.permission.is_some()
|| self.updated_input.is_some()
|| self.updated_output.is_some()
|| !self.additional_context.is_empty()
|| self.prevent_continuation
|| !self.blocking_errors.is_empty()
}
pub fn has_blocking_errors(&self) -> bool {
!self.blocking_errors.is_empty()
}
pub fn is_denied(&self) -> bool {
matches!(&self.permission, Some(p) if p.is_deny())
}
}
/// Tests `value` against `pattern`, where `pattern` is either a single
/// pattern or several alternatives separated by `|`.
///
/// Each alternative is an exact string or a glob using `*` as wildcard.
/// Alternatives are trimmed of surrounding whitespace; a pattern without `|`
/// is used as-is, without trimming.
pub fn matches_pattern(value: &str, pattern: &str) -> bool {
    match pattern.find('|') {
        None => matches_single_pattern(value, pattern),
        Some(_) => pattern
            .split('|')
            .map(str::trim)
            .any(|alternative| matches_single_pattern(value, alternative)),
    }
}
/// Matches `value` against one glob `pattern` in which `*` stands for any
/// (possibly empty) run of characters. All other characters are literal.
///
/// * No `*`: the strings must be equal.
/// * Otherwise the text before the first `*` must be a prefix of `value`,
///   the text after the last `*` must be a suffix of what remains, and the
///   literal segments between wildcards must occur in order, without
///   overlapping, in the part left in between.
///
/// Unlike a plain left-to-right scan, the final segment is anchored to the
/// end of `value`, so `"a*b*c"` does not match `"abcX"`.
fn matches_single_pattern(value: &str, pattern: &str) -> bool {
    // Split off the literal prefix; no wildcard means exact comparison.
    let (prefix, rest) = match pattern.split_once('*') {
        Some(split) => split,
        None => return value == pattern,
    };
    // Split off the literal suffix; with a single `*` there is no middle part.
    let (middle, suffix) = match rest.rsplit_once('*') {
        Some(split) => split,
        None => ("", rest),
    };

    // Anchor both ends. Stripping the prefix first guarantees that the
    // prefix and suffix cannot overlap inside `value`.
    let after_prefix = match value.strip_prefix(prefix) {
        Some(tail) => tail,
        None => return false,
    };
    let mut remaining = match after_prefix.strip_suffix(suffix) {
        Some(inner) => inner,
        None => return false,
    };

    // Inner segments: leftmost match for each, in order. Empty segments
    // come from consecutive `*` and impose no constraint.
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match remaining.find(segment) {
            Some(pos) => remaining = &remaining[pos + segment.len()..],
            None => return false,
        }
    }
    true
}
// Unit tests for the hook types, the registry, output aggregation and the
// pattern matcher defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// --- HookEvent ---
// Only the three tool variants report `is_tool_event() == true`.
#[test]
fn test_hook_event_is_tool_event() {
assert!(HookEvent::PreToolUse.is_tool_event());
assert!(HookEvent::PostToolUse.is_tool_event());
assert!(HookEvent::PostToolFailure.is_tool_event());
assert!(!HookEvent::SessionStart.is_tool_event());
assert!(!HookEvent::Stop.is_tool_event());
}
// `Display` renders the snake_case event name.
#[test]
fn test_hook_event_display() {
assert_eq!(HookEvent::PreToolUse.to_string(), "pre_tool_use");
assert_eq!(HookEvent::PostToolUse.to_string(), "post_tool_use");
assert_eq!(HookEvent::SessionStart.to_string(), "session_start");
}
// Every event survives a JSON serialize/deserialize round trip.
#[test]
fn test_hook_event_serde_roundtrip() {
for event in ALL_HOOK_EVENTS {
let json_str = serde_json::to_string(event).unwrap();
let restored: HookEvent = serde_json::from_str(&json_str).unwrap();
assert_eq!(*event, restored);
}
}
// Guards against `ALL_HOOK_EVENTS` drifting from the ten enum variants.
#[test]
fn test_all_hook_events_count() {
assert_eq!(ALL_HOOK_EVENTS.len(), 10);
}
// --- HookInput ---
// `event()` maps a payload to the matching `HookEvent`.
#[test]
fn test_hook_input_event() {
let input = HookInput::PreToolUse {
tool_name: "bash".into(),
tool_input: json!({}),
call_id: ToolCallId::new("c1"),
};
assert_eq!(input.event(), HookEvent::PreToolUse);
}
// `tool_name()` is `Some` for tool payloads and `None` otherwise.
#[test]
fn test_hook_input_tool_name() {
let tool_input = HookInput::PreToolUse {
tool_name: "bash".into(),
tool_input: json!({}),
call_id: ToolCallId::new("c1"),
};
assert_eq!(tool_input.tool_name(), Some("bash"));
let non_tool = HookInput::SessionStart {
session_id: SessionId::new(),
};
assert_eq!(non_tool.tool_name(), None);
}
// `call_id()` is `Some` for tool payloads and `None` otherwise.
#[test]
fn test_hook_input_call_id() {
let input = HookInput::PostToolFailure {
tool_name: "bash".into(),
tool_input: json!({}),
error: "exit code 1".into(),
call_id: ToolCallId::new("c2"),
};
assert_eq!(input.call_id().unwrap().as_str(), "c2");
let non_tool = HookInput::Stop {
finish_reason: "completed".into(),
};
assert!(non_tool.call_id().is_none());
}
// The serialized form contains the snake_case tag and round-trips.
#[test]
fn test_hook_input_serde_roundtrip() {
let input = HookInput::PreToolUse {
tool_name: "read_file".into(),
tool_input: json!({"path": "/tmp/test.txt"}),
call_id: ToolCallId::new("call_42"),
};
let json_str = serde_json::to_string(&input).unwrap();
assert!(json_str.contains("pre_tool_use"));
let restored: HookInput = serde_json::from_str(&json_str).unwrap();
assert_eq!(restored.event(), HookEvent::PreToolUse);
assert_eq!(restored.tool_name(), Some("read_file"));
}
// --- HookPermission ---
// Each constructor yields the variant its predicate recognizes.
#[test]
fn test_hook_permission_variants() {
assert!(HookPermission::Allow.is_allow());
assert!(HookPermission::deny().is_deny());
assert!(HookPermission::ask().is_ask());
}
// `deny_with_reason` stores the reason inside the `Deny` variant.
#[test]
fn test_hook_permission_with_reason() {
let deny = HookPermission::deny_with_reason("unsafe");
match deny {
HookPermission::Deny { reason } => assert_eq!(reason, Some("unsafe".into())),
_ => panic!("expected Deny"),
}
}
// Strictness ordering: Allow < Ask < Deny.
#[test]
fn test_hook_permission_strictness() {
assert!(HookPermission::Allow.strictness() < HookPermission::ask().strictness());
assert!(HookPermission::ask().strictness() < HookPermission::deny().strictness());
}
// All variants, with and without optional text, round-trip through JSON.
#[test]
fn test_hook_permission_serde_roundtrip() {
for perm in [
HookPermission::Allow,
HookPermission::deny(),
HookPermission::deny_with_reason("test"),
HookPermission::ask(),
HookPermission::ask_with_message("confirm?"),
] {
let json_str = serde_json::to_string(&perm).unwrap();
let restored: HookPermission = serde_json::from_str(&json_str).unwrap();
assert_eq!(perm, restored);
}
}
// --- HookOutput ---
// A passthrough output carries no decision.
#[test]
fn test_hook_output_passthrough() {
let out = HookOutput::passthrough();
assert!(!out.has_decision());
assert!(out.permission.is_none());
assert!(out.additional_context.is_empty());
}
#[test]
fn test_hook_output_allow() {
let out = HookOutput::allow();
assert!(out.has_decision());
assert!(out.permission.as_ref().unwrap().is_allow());
}
#[test]
fn test_hook_output_deny() {
let out = HookOutput::deny("bad command");
assert!(out.has_decision());
assert!(out.permission.as_ref().unwrap().is_deny());
}
#[test]
fn test_hook_output_ask() {
let out = HookOutput::ask("are you sure?");
assert!(out.has_decision());
assert!(out.permission.as_ref().unwrap().is_ask());
}
// Builder methods compose and each sets its own field.
#[test]
fn test_hook_output_builder() {
let out = HookOutput::allow()
.with_updated_input(json!({"command": "ls"}))
.with_context("working directory: /tmp")
.with_system_message("Input sanitized");
assert!(out.permission.as_ref().unwrap().is_allow());
assert_eq!(out.updated_input.as_ref().unwrap()["command"], "ls");
assert_eq!(out.additional_context.len(), 1);
assert_eq!(out.system_message, Some("Input sanitized".into()));
}
// `with_stop` sets both the flag and the reason, and counts as a decision.
#[test]
fn test_hook_output_with_stop() {
let out = HookOutput::passthrough().with_stop("loop detected");
assert!(out.prevent_continuation);
assert_eq!(out.stop_reason, Some("loop detected".into()));
assert!(out.has_decision());
}
// A blocking error alone counts as a decision.
#[test]
fn test_hook_output_with_blocking_error() {
let out = HookOutput::passthrough().with_blocking_error("lint failed");
assert!(out.has_decision());
assert_eq!(out.blocking_error, Some("lint failed".into()));
}
#[test]
fn test_hook_output_serde_roundtrip() {
let out = HookOutput::deny("test")
.with_context("ctx1")
.with_system_message("msg1");
let json_str = serde_json::to_string(&out).unwrap();
let restored: HookOutput = serde_json::from_str(&json_str).unwrap();
assert_eq!(restored.additional_context, vec!["ctx1"]);
assert_eq!(restored.system_message, Some("msg1".into()));
}
// --- AggregatedHookOutput ---
// Once a Deny is merged, a later Allow does not downgrade it.
#[test]
fn test_aggregated_merge_permission_deny_wins() {
let mut agg = AggregatedHookOutput::default();
agg.merge(HookOutput::allow(), "hook_a");
assert!(agg.permission.as_ref().unwrap().is_allow());
agg.merge(HookOutput::deny("nope"), "hook_b");
assert!(agg.permission.as_ref().unwrap().is_deny());
agg.merge(HookOutput::allow(), "hook_c");
assert!(agg.permission.as_ref().unwrap().is_deny());
}
// Ask replaces Allow and is not downgraded by a later Allow.
#[test]
fn test_aggregated_merge_permission_ask_beats_allow() {
let mut agg = AggregatedHookOutput::default();
agg.merge(HookOutput::allow(), "hook_a");
agg.merge(HookOutput::ask("confirm?"), "hook_b");
assert!(agg.permission.as_ref().unwrap().is_ask());
agg.merge(HookOutput::allow(), "hook_c");
assert!(agg.permission.as_ref().unwrap().is_ask());
}
// Context entries accumulate in merge order.
#[test]
fn test_aggregated_merge_context() {
let mut agg = AggregatedHookOutput::default();
agg.merge(
HookOutput::passthrough().with_context("ctx1"),
"hook_a",
);
agg.merge(
HookOutput::passthrough().with_context("ctx2"),
"hook_b",
);
assert_eq!(agg.additional_context, vec!["ctx1", "ctx2"]);
}
// Blocking errors accumulate and are tagged with the hook name in brackets.
#[test]
fn test_aggregated_merge_blocking_errors() {
let mut agg = AggregatedHookOutput::default();
agg.merge(
HookOutput::passthrough().with_blocking_error("err1"),
"linter",
);
agg.merge(
HookOutput::passthrough().with_blocking_error("err2"),
"validator",
);
assert_eq!(agg.blocking_errors.len(), 2);
assert!(agg.blocking_errors[0].contains("[linter]"));
assert!(agg.blocking_errors[1].contains("[validator]"));
assert!(agg.has_blocking_errors());
}
// The first stop reason is kept; later stop requests do not overwrite it.
#[test]
fn test_aggregated_merge_stop() {
let mut agg = AggregatedHookOutput::default();
agg.merge(HookOutput::passthrough(), "hook_a");
assert!(!agg.prevent_continuation);
agg.merge(
HookOutput::passthrough().with_stop("first reason"),
"hook_b",
);
assert!(agg.prevent_continuation);
assert_eq!(agg.stop_reason, Some("first reason".into()));
agg.merge(
HookOutput::passthrough().with_stop("second reason"),
"hook_c",
);
assert_eq!(agg.stop_reason, Some("first reason".into()));
}
// For `updated_input` the most recently merged value wins.
#[test]
fn test_aggregated_merge_updated_input_last_wins() {
let mut agg = AggregatedHookOutput::default();
agg.merge(
HookOutput::allow().with_updated_input(json!({"a": 1})),
"hook_a",
);
agg.merge(
HookOutput::allow().with_updated_input(json!({"b": 2})),
"hook_b",
);
assert_eq!(agg.updated_input, Some(json!({"b": 2})));
}
#[test]
fn test_aggregated_has_decision() {
let agg = AggregatedHookOutput::default();
assert!(!agg.has_decision());
let mut agg2 = AggregatedHookOutput::default();
agg2.merge(HookOutput::allow(), "h");
assert!(agg2.has_decision());
}
#[test]
fn test_aggregated_is_denied() {
let mut agg = AggregatedHookOutput::default();
assert!(!agg.is_denied());
agg.merge(HookOutput::deny("no"), "h");
assert!(agg.is_denied());
}
// --- matches_pattern ---
// Without wildcards or pipes the comparison is exact.
#[test]
fn test_matches_pattern_exact() {
assert!(matches_pattern("bash", "bash"));
assert!(!matches_pattern("bash", "write_file"));
}
// `|` separates alternatives; any one of them may match.
#[test]
fn test_matches_pattern_pipe_separated() {
assert!(matches_pattern("bash", "bash|write_file"));
assert!(matches_pattern("write_file", "bash|write_file"));
assert!(!matches_pattern("read_file", "bash|write_file"));
}
// Trailing `*` is a prefix match.
#[test]
fn test_matches_pattern_wildcard_star() {
assert!(matches_pattern("read_file", "read_*"));
assert!(matches_pattern("read_dir", "read_*"));
assert!(!matches_pattern("write_file", "read_*"));
}
// Leading `*` is a suffix match.
#[test]
fn test_matches_pattern_wildcard_suffix() {
assert!(matches_pattern("read_file", "*_file"));
assert!(matches_pattern("write_file", "*_file"));
assert!(!matches_pattern("read_dir", "*_file"));
}
// A `*` in the middle requires both the prefix and the suffix.
#[test]
fn test_matches_pattern_wildcard_middle() {
assert!(matches_pattern("pre_tool_use", "pre_*_use"));
assert!(matches_pattern("pre_compact_use", "pre_*_use"));
assert!(!matches_pattern("pre_tool_fail", "pre_*_use"));
}
// A lone `*` matches everything, including the empty string.
#[test]
fn test_matches_pattern_star_matches_all() {
assert!(matches_pattern("anything", "*"));
assert!(matches_pattern("", "*"));
}
// Pipes and wildcards can be combined.
#[test]
fn test_matches_pattern_pipe_with_wildcard() {
assert!(matches_pattern("read_file", "bash|read_*"));
assert!(matches_pattern("bash", "bash|read_*"));
assert!(!matches_pattern("write_file", "bash|read_*"));
}
// --- HookRegistry ---
// Test double: a hook with configurable name, events and matcher that
// always returns a passthrough output.
struct PassthroughHook {
hook_name: String,
hook_events: Vec<HookEvent>,
hook_matcher: Option<String>,
}
impl PassthroughHook {
// No events (subscribes to all) and no matcher by default.
fn new(name: &str) -> Self {
Self {
hook_name: name.into(),
hook_events: vec![],
hook_matcher: None,
}
}
fn with_events(mut self, events: Vec<HookEvent>) -> Self {
self.hook_events = events;
self
}
fn with_matcher(mut self, matcher: &str) -> Self {
self.hook_matcher = Some(matcher.into());
self
}
}
#[async_trait]
impl Hook for PassthroughHook {
fn name(&self) -> &str {
&self.hook_name
}
fn events(&self) -> &[HookEvent] {
&self.hook_events
}
fn matcher(&self) -> Option<&str> {
self.hook_matcher.as_deref()
}
async fn on_event(&self, _input: &HookInput) -> HookOutput {
HookOutput::passthrough()
}
}
#[test]
fn test_registry_new_empty() {
let reg = HookRegistry::new();
assert!(reg.is_empty());
assert_eq!(reg.len(), 0);
}
#[test]
fn test_registry_register_and_len() {
let mut reg = HookRegistry::new();
reg.register(
Arc::new(PassthroughHook::new("a")),
HookSource::Programmatic,
0,
);
reg.register(
Arc::new(PassthroughHook::new("b")),
HookSource::Programmatic,
0,
);
assert_eq!(reg.len(), 2);
}
// `remove` deletes by hook name and leaves the other hooks in place.
#[test]
fn test_registry_remove() {
let mut reg = HookRegistry::new();
reg.register(
Arc::new(PassthroughHook::new("a")),
HookSource::Programmatic,
0,
);
reg.register(
Arc::new(PassthroughHook::new("b")),
HookSource::Programmatic,
0,
);
reg.remove("a");
assert_eq!(reg.len(), 1);
assert_eq!(reg.hooks[0].hook.name(), "b");
}
// Hooks are selected by subscribed event; an empty event list matches all.
#[test]
fn test_registry_matching_by_event() {
let mut reg = HookRegistry::new();
reg.register(
Arc::new(PassthroughHook::new("pre_only").with_events(vec![HookEvent::PreToolUse])),
HookSource::Programmatic,
0,
);
reg.register(
Arc::new(PassthroughHook::new("post_only").with_events(vec![HookEvent::PostToolUse])),
HookSource::Programmatic,
0,
);
reg.register(
Arc::new(PassthroughHook::new("all_events")),
HookSource::Programmatic,
0,
);
let input = HookInput::PreToolUse {
tool_name: "bash".into(),
tool_input: json!({}),
call_id: ToolCallId::new("c1"),
};
let matched = reg.matching(&input);
assert_eq!(matched.len(), 2);
let names: Vec<&str> = matched.iter().map(|h| h.hook.name()).collect();
assert!(names.contains(&"pre_only"));
assert!(names.contains(&"all_events"));
assert!(!names.contains(&"post_only"));
}
// The matcher pattern filters hooks by the input's tool name.
#[test]
fn test_registry_matching_by_matcher() {
let mut reg = HookRegistry::new();
reg.register(
Arc::new(
PassthroughHook::new("bash_only")
.with_events(vec![HookEvent::PreToolUse])
.with_matcher("bash"),
),
HookSource::Programmatic,
0,
);
reg.register(
Arc::new(
PassthroughHook::new("write_family")
.with_events(vec![HookEvent::PreToolUse])
.with_matcher("write_*"),
),
HookSource::Programmatic,
0,
);
let input_bash = HookInput::PreToolUse {
tool_name: "bash".into(),
tool_input: json!({}),
call_id: ToolCallId::new("c1"),
};
let matched = reg.matching(&input_bash);
assert_eq!(matched.len(), 1);
assert_eq!(matched[0].hook.name(), "bash_only");
let input_write = HookInput::PreToolUse {
tool_name: "write_file".into(),
tool_input: json!({}),
call_id: ToolCallId::new("c2"),
};
let matched = reg.matching(&input_write);
assert_eq!(matched.len(), 1);
assert_eq!(matched[0].hook.name(), "write_family");
let input_read = HookInput::PreToolUse {
tool_name: "read_file".into(),
tool_input: json!({}),
call_id: ToolCallId::new("c3"),
};
let matched = reg.matching(&input_read);
assert!(matched.is_empty());
}
// A hook with a matcher never matches an input that has no tool name.
#[test]
fn test_registry_matching_non_tool_event_with_matcher() {
let mut reg = HookRegistry::new();
reg.register(
Arc::new(
PassthroughHook::new("h")
.with_events(vec![HookEvent::SessionStart])
.with_matcher("bash"),
),
HookSource::Programmatic,
0,
);
let input = HookInput::SessionStart {
session_id: SessionId::new(),
};
let matched = reg.matching(&input);
assert!(matched.is_empty());
}
// Matching hooks come back in ascending priority order (lowest value first),
// regardless of registration order.
#[test]
fn test_registry_priority_order() {
let mut reg = HookRegistry::new();
reg.register(
Arc::new(PassthroughHook::new("low")),
HookSource::Programmatic,
10,
);
reg.register(
Arc::new(PassthroughHook::new("high")),
HookSource::Programmatic,
-10,
);
reg.register(
Arc::new(PassthroughHook::new("mid")),
HookSource::Programmatic,
0,
);
let input = HookInput::SessionStart {
session_id: SessionId::new(),
};
let matched = reg.matching(&input);
assert_eq!(matched[0].hook.name(), "high");
assert_eq!(matched[1].hook.name(), "mid");
assert_eq!(matched[2].hook.name(), "low");
}
#[test]
fn test_registry_has_hooks_for() {
let mut reg = HookRegistry::new();
reg.register(
Arc::new(PassthroughHook::new("pre_only").with_events(vec![HookEvent::PreToolUse])),
HookSource::Programmatic,
0,
);
assert!(reg.has_hooks_for(HookEvent::PreToolUse));
assert!(!reg.has_hooks_for(HookEvent::PostToolUse));
}
// A hook with an empty event list counts as subscribed to every event.
#[test]
fn test_registry_has_hooks_for_all_events() {
let mut reg = HookRegistry::new();
reg.register(
Arc::new(PassthroughHook::new("global")),
HookSource::Programmatic,
0,
);
for event in ALL_HOOK_EVENTS {
assert!(reg.has_hooks_for(*event));
}
}
// --- HookSource ---
#[test]
fn test_hook_source_serde_roundtrip() {
for source in [
HookSource::Settings,
HookSource::Project,
HookSource::Plugin {
name: "linter".into(),
},
HookSource::Programmatic,
HookSource::Session,
] {
let json_str = serde_json::to_string(&source).unwrap();
let restored: HookSource = serde_json::from_str(&json_str).unwrap();
assert_eq!(source, restored);
}
}
// --- Hook trait (async) ---
// Calls `on_event` directly on a hook that denies bash commands containing
// "rm -rf"; the registry is not involved.
#[tokio::test]
async fn test_hook_trait_async_execution() {
struct DenyBashHook;
#[async_trait]
impl Hook for DenyBashHook {
fn name(&self) -> &str {
"deny_bash"
}
fn events(&self) -> &[HookEvent] {
&[HookEvent::PreToolUse]
}
fn matcher(&self) -> Option<&str> {
Some("bash")
}
async fn on_event(&self, input: &HookInput) -> HookOutput {
if let HookInput::PreToolUse { tool_input, .. } = input {
// A missing or non-string "command" is treated as an empty command.
let cmd = tool_input["command"].as_str().unwrap_or("");
if cmd.contains("rm -rf") {
return HookOutput::deny("dangerous command");
}
}
HookOutput::passthrough()
}
}
let hook: Arc<dyn Hook> = Arc::new(DenyBashHook);
let safe_input = HookInput::PreToolUse {
tool_name: "bash".into(),
tool_input: json!({"command": "ls -la"}),
call_id: ToolCallId::new("c1"),
};
let output = hook.on_event(&safe_input).await;
assert!(!output.has_decision());
let dangerous_input = HookInput::PreToolUse {
tool_name: "bash".into(),
tool_input: json!({"command": "rm -rf /"}),
call_id: ToolCallId::new("c2"),
};
let output = hook.on_event(&dangerous_input).await;
assert!(output.permission.as_ref().unwrap().is_deny());
}
// The trait is usable through `Arc<dyn Hook>` dynamic dispatch.
#[tokio::test]
async fn test_hook_trait_dyn_dispatch() {
let hook: Arc<dyn Hook> = Arc::new(PassthroughHook::new("test"));
assert_eq!(hook.name(), "test");
let input = HookInput::SessionStart {
session_id: SessionId::new(),
};
let output = hook.on_event(&input).await;
assert!(!output.has_decision());
}
}