use super::events::{HookEvent, HookEventType};
use super::matcher::HookMatcher;
use super::{HookAction, HookResponse};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::mpsc;
use crate::error::{read_or_recover, write_or_recover};
/// Execution settings attached to a [`Hook`].
///
/// Every field has a serde default, so a config can be deserialized from a partial or empty
/// object.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HookConfig {
    /// Ordering key: `HookEngine` runs matching hooks in ascending order of this value.
    /// Defaults to 100.
    #[serde(default = "default_priority")]
    pub priority: i32,
    /// Time budget in milliseconds that `HookEngine` applies to non-async handlers.
    /// Defaults to 30000.
    #[serde(default = "default_timeout")]
    pub timeout_ms: u64,
    /// When true, `HookEngine` starts the handler as a background task and ignores its response.
    /// Defaults to false.
    #[serde(default)]
    pub async_execution: bool,
    /// Maximum retry count. Defaults to 0.
    /// NOTE(review): not read by `HookEngine` in this module.
    #[serde(default)]
    pub max_retries: u32,
}
/// Serde default for [`HookConfig::priority`].
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 100;
    DEFAULT_PRIORITY
}
/// Serde default for [`HookConfig::timeout_ms`]: 30 seconds, expressed in milliseconds.
fn default_timeout() -> u64 {
    30 * 1_000
}
impl Default for HookConfig {
    /// Produces the same values serde uses for missing fields: priority 100, a 30000 ms timeout,
    /// synchronous execution and no retries.
    fn default() -> Self {
        let priority = default_priority();
        let timeout_ms = default_timeout();
        Self {
            priority,
            timeout_ms,
            async_execution: bool::default(),
            max_retries: u32::default(),
        }
    }
}
/// A hook registration: the event type it listens for, an optional finer-grained matcher, and
/// its execution settings.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Hook {
    /// Registry key; `HookEngine` stores hooks and looks up their handlers by this id.
    pub id: String,
    /// Event type this hook listens for.
    pub event_type: HookEventType,
    /// Optional extra filter; `None` accepts every event of `event_type`.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matcher: Option<HookMatcher>,
    /// Execution settings; `HookConfig::default()` when absent from the input.
    #[serde(default)]
    pub config: HookConfig,
}
impl Hook {
pub fn new(id: impl Into<String>, event_type: HookEventType) -> Self {
Self {
id: id.into(),
event_type,
matcher: None,
config: HookConfig::default(),
}
}
pub fn with_matcher(mut self, matcher: HookMatcher) -> Self {
self.matcher = Some(matcher);
self
}
pub fn with_config(mut self, config: HookConfig) -> Self {
self.config = config;
self
}
pub fn matches(&self, event: &HookEvent) -> bool {
if event.event_type() != self.event_type {
return false;
}
if let Some(ref matcher) = self.matcher {
matcher.matches(event)
} else {
true
}
}
}
/// Outcome of running a hook, or of a whole chain via `HookEngine::fire`.
#[derive(Clone, Debug)]
pub enum HookResult {
    /// Proceed, optionally carrying a modified payload produced by a handler.
    Continue(Option<serde_json::Value>),
    /// Stop, with the reason given by the handler.
    Block(String),
    /// Request a retry after the given delay in milliseconds.
    Retry(u64),
    /// Stop running the remaining hooks in the chain.
    Skip,
}
impl HookResult {
pub fn continue_() -> Self {
Self::Continue(None)
}
pub fn continue_with(modified: serde_json::Value) -> Self {
Self::Continue(Some(modified))
}
pub fn block(reason: impl Into<String>) -> Self {
Self::Block(reason.into())
}
pub fn retry(delay_ms: u64) -> Self {
Self::Retry(delay_ms)
}
pub fn skip() -> Self {
Self::Skip
}
pub fn is_continue(&self) -> bool {
matches!(self, Self::Continue(_))
}
pub fn is_block(&self) -> bool {
matches!(self, Self::Block(_))
}
}
/// Synchronous callback bound to a hook id through `HookEngine::register_handler`.
///
/// Implementations must be `Send + Sync` because the engine shares them through `Arc` and may
/// invoke them from spawned tasks.
pub trait HookHandler: Sync + Send {
    /// Inspects `event` and returns the handler's decision.
    fn handle(&self, event: &HookEvent) -> HookResponse;
}
#[async_trait::async_trait]
pub trait HookExecutor: Send + Sync + std::fmt::Debug {
async fn fire(&self, event: &HookEvent) -> HookResult;
}
/// Registry of hooks and their handlers, with optional forwarding of fired events.
///
/// Hooks and handlers live in separate maps keyed by hook id; a hook with no handler yields
/// `Continue` when fired.
pub struct HookEngine {
    /// Registered hooks keyed by `Hook::id`.
    hooks: Arc<RwLock<HashMap<String, Hook>>>,
    /// Handlers keyed by the id of the hook they serve.
    handlers: Arc<RwLock<HashMap<String, Arc<dyn HookHandler>>>>,
    /// When set, every fired event is cloned and sent here before any hook runs.
    event_tx: Option<mpsc::Sender<HookEvent>>,
}
impl std::fmt::Debug for HookEngine {
    /// Summarizes the engine as counts rather than dumping handlers (which are not `Debug`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each count is read in its own statement, so the two lock guards are never held at
        // the same time.
        let hooks_count = read_or_recover(&self.hooks).len();
        let handlers_count = read_or_recover(&self.handlers).len();
        let has_event_channel = self.event_tx.is_some();

        let mut out = f.debug_struct("HookEngine");
        out.field("hooks_count", &hooks_count);
        out.field("handlers_count", &handlers_count);
        out.field("has_event_channel", &has_event_channel);
        out.finish()
    }
}
impl Default for HookEngine {
fn default() -> Self {
Self::new()
}
}
impl HookEngine {
    /// Creates an engine with no hooks, no handlers and no event channel.
    pub fn new() -> Self {
        Self {
            hooks: Arc::new(RwLock::new(HashMap::new())),
            handlers: Arc::new(RwLock::new(HashMap::new())),
            event_tx: None,
        }
    }

    /// Builder: forwards a clone of every event passed to [`HookEngine::fire`] to `tx`.
    pub fn with_event_channel(mut self, tx: mpsc::Sender<HookEvent>) -> Self {
        self.event_tx = Some(tx);
        self
    }

    /// Registers `hook` under its id, replacing any hook previously registered with that id.
    pub fn register(&self, hook: Hook) {
        let mut hooks = write_or_recover(&self.hooks);
        hooks.insert(hook.id.clone(), hook);
    }

    /// Removes and returns the hook registered as `hook_id`, if any.
    ///
    /// A handler bound to the same id is left in place; remove it with
    /// [`HookEngine::unregister_handler`].
    pub fn unregister(&self, hook_id: &str) -> Option<Hook> {
        let mut hooks = write_or_recover(&self.hooks);
        hooks.remove(hook_id)
    }

    /// Binds `handler` to `hook_id`, replacing any handler previously bound to that id.
    pub fn register_handler(&self, hook_id: &str, handler: Arc<dyn HookHandler>) {
        let mut handlers = write_or_recover(&self.handlers);
        handlers.insert(hook_id.to_string(), handler);
    }

    /// Removes the handler bound to `hook_id`, if any.
    pub fn unregister_handler(&self, hook_id: &str) {
        let mut handlers = write_or_recover(&self.handlers);
        handlers.remove(hook_id);
    }

    /// Returns clones of every hook that matches `event`, in execution order.
    ///
    /// Hooks run in ascending `config.priority`. Ties are broken by hook id, so the order never
    /// depends on `HashMap` iteration order.
    pub fn matching_hooks(&self, event: &HookEvent) -> Vec<Hook> {
        let hooks = read_or_recover(&self.hooks);
        let mut matching: Vec<Hook> = hooks
            .values()
            .filter(|h| h.matches(event))
            .cloned()
            .collect();
        // The clones no longer need the registry; release the read lock before sorting.
        drop(hooks);
        matching.sort_by(|a, b| {
            a.config
                .priority
                .cmp(&b.config.priority)
                .then_with(|| a.id.cmp(&b.id))
        });
        matching
    }

    /// Runs every matching hook in priority order and combines their results.
    ///
    /// - `Block` or `Retry` from any hook is returned immediately.
    /// - `Skip` stops the chain but keeps any payload already produced by earlier hooks.
    /// - Otherwise the result is `Continue`, carrying the last non-`None` modified payload.
    ///
    /// Each hook receives the original `event`, not an earlier hook's modified payload.
    pub async fn fire(&self, event: &HookEvent) -> HookResult {
        if let Some(ref tx) = self.event_tx {
            // Best-effort notification: a closed receiver is ignored. `send` waits for channel
            // capacity, so a full channel delays hook execution.
            let _ = tx.send(event.clone()).await;
        }

        let mut last_modified: Option<serde_json::Value> = None;
        for hook in self.matching_hooks(event) {
            match self.execute_hook(&hook, event).await {
                HookResult::Continue(modified) => {
                    if modified.is_some() {
                        last_modified = modified;
                    }
                }
                HookResult::Block(reason) => return HookResult::Block(reason),
                HookResult::Retry(delay) => return HookResult::Retry(delay),
                // Stop running the remaining hooks, but don't discard what earlier hooks
                // already produced.
                HookResult::Skip => break,
            }
        }
        HookResult::Continue(last_modified)
    }

    /// Runs the handler bound to `hook` (if any) and converts its response.
    ///
    /// Handlers are synchronous, so they run on Tokio's blocking thread pool rather than on an
    /// async worker thread.
    /// - With `async_execution`, the handler is started in the background and the hook yields
    ///   `Continue` without waiting; the handler's response is discarded.
    /// - Otherwise the engine waits up to `timeout_ms`. On timeout the hook yields `Continue`.
    ///   A timed-out handler cannot be cancelled and keeps running in the background. A handler
    ///   panic is re-raised in the caller.
    ///
    /// A hook with no registered handler yields `Continue`.
    /// NOTE(review): `config.max_retries` is not consulted here.
    async fn execute_hook(&self, hook: &Hook, event: &HookEvent) -> HookResult {
        // Clone the handler out so the read lock is released before any `.await`.
        let handler = {
            let handlers = read_or_recover(&self.handlers);
            handlers.get(&hook.id).cloned()
        };
        let handler = match handler {
            Some(handler) => handler,
            None => return HookResult::continue_(),
        };
        let event = event.clone();

        let response = if hook.config.async_execution {
            // Fire-and-forget: dropping the JoinHandle detaches the task.
            tokio::task::spawn_blocking(move || {
                handler.handle(&event);
            });
            HookResponse::continue_()
        } else {
            // `tokio::time::timeout` only checks its deadline when the wrapped future yields.
            // A synchronous handler called inline would always finish first, so the timeout
            // could never fire. Running it on the blocking pool lets the deadline elapse.
            let timeout = std::time::Duration::from_millis(hook.config.timeout_ms);
            let task = tokio::task::spawn_blocking(move || handler.handle(&event));
            match tokio::time::timeout(timeout, task).await {
                Ok(Ok(response)) => response,
                // Preserve the previous behavior of a panicking handler panicking the caller.
                Ok(Err(join_err)) if join_err.is_panic() => {
                    std::panic::resume_unwind(join_err.into_panic())
                }
                // The task was cancelled (runtime shutting down) or the deadline passed.
                Ok(Err(_)) | Err(_) => HookResponse::continue_(),
            }
        };
        self.response_to_result(response)
    }

    /// Maps a handler's [`HookResponse`] onto a [`HookResult`].
    ///
    /// A missing block reason becomes `"Blocked"`; a missing retry delay becomes 1000 ms.
    fn response_to_result(&self, response: HookResponse) -> HookResult {
        match response.action {
            HookAction::Continue => HookResult::Continue(response.modified),
            HookAction::Block => {
                HookResult::Block(response.reason.unwrap_or_else(|| "Blocked".to_string()))
            }
            HookAction::Retry => HookResult::Retry(response.retry_delay_ms.unwrap_or(1000)),
            HookAction::Skip => HookResult::Skip,
        }
    }

    /// Number of registered hooks.
    pub fn hook_count(&self) -> usize {
        read_or_recover(&self.hooks).len()
    }

    /// Returns a clone of the hook registered as `id`, if any.
    pub fn get_hook(&self, id: &str) -> Option<Hook> {
        read_or_recover(&self.hooks).get(id).cloned()
    }

    /// Returns clones of all registered hooks, in unspecified order.
    pub fn all_hooks(&self) -> Vec<Hook> {
        read_or_recover(&self.hooks).values().cloned().collect()
    }
}
#[async_trait]
impl HookExecutor for HookEngine {
async fn fire(&self, event: &HookEvent) -> HookResult {
self.fire(event).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::events::PreToolUseEvent;

    /// Builds a `PreToolUse` event for `tool` in session `session_id`.
    fn make_pre_tool_event(session_id: &str, tool: &str) -> HookEvent {
        HookEvent::PreToolUse(PreToolUseEvent {
            session_id: session_id.to_string(),
            tool: tool.to_string(),
            args: serde_json::json!({}),
            working_directory: "/workspace".to_string(),
            recent_tools: vec![],
        })
    }

    #[test]
    fn test_hook_config_default() {
        let config = HookConfig::default();
        assert_eq!(config.priority, 100);
        assert_eq!(config.timeout_ms, 30000);
        assert!(!config.async_execution);
        assert_eq!(config.max_retries, 0);
    }

    #[test]
    fn test_hook_new() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        assert_eq!(hook.id, "test-hook");
        assert_eq!(hook.event_type, HookEventType::PreToolUse);
        assert!(hook.matcher.is_none());
    }

    #[test]
    fn test_hook_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));
        assert!(hook.matcher.is_some());
        assert_eq!(hook.matcher.unwrap().tool, Some("Bash".to_string()));
    }

    #[test]
    fn test_hook_matches_event_type() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        let pre_event = make_pre_tool_event("s1", "Bash");
        assert!(hook.matches(&pre_event));
        let post_event = HookEvent::PostToolUse(crate::hooks::events::PostToolUseEvent {
            session_id: "s1".to_string(),
            tool: "Bash".to_string(),
            args: serde_json::json!({}),
            result: crate::hooks::events::ToolResultData {
                success: true,
                output: "".to_string(),
                exit_code: Some(0),
                duration_ms: 100,
            },
        });
        assert!(!hook.matches(&post_event));
    }

    #[test]
    fn test_hook_matches_with_matcher() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"));
        let bash_event = make_pre_tool_event("s1", "Bash");
        let read_event = make_pre_tool_event("s1", "Read");
        assert!(hook.matches(&bash_event));
        assert!(!hook.matches(&read_event));
    }

    #[test]
    fn test_hook_result_constructors() {
        let cont = HookResult::continue_();
        assert!(cont.is_continue());
        assert!(!cont.is_block());
        let cont_with = HookResult::continue_with(serde_json::json!({"key": "value"}));
        assert!(cont_with.is_continue());
        let block = HookResult::block("Blocked");
        assert!(block.is_block());
        assert!(!block.is_continue());
        let retry = HookResult::retry(1000);
        assert!(!retry.is_continue());
        assert!(!retry.is_block());
        let skip = HookResult::skip();
        assert!(!skip.is_continue());
        assert!(!skip.is_block());
    }

    #[test]
    fn test_engine_register_unregister() {
        let engine = HookEngine::new();
        let hook = Hook::new("test-hook", HookEventType::PreToolUse);
        engine.register(hook);
        assert_eq!(engine.hook_count(), 1);
        assert!(engine.get_hook("test-hook").is_some());
        let removed = engine.unregister("test-hook");
        assert!(removed.is_some());
        assert_eq!(engine.hook_count(), 0);
        // Once removed, the hook is gone and a second removal finds nothing.
        assert!(engine.get_hook("test-hook").is_none());
        assert!(engine.unregister("test-hook").is_none());
    }

    #[test]
    fn test_engine_matching_hooks() {
        let engine = HookEngine::new();
        engine.register(
            Hook::new("hook-1", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 10,
                ..Default::default()
            }),
        );
        engine.register(
            Hook::new("hook-2", HookEventType::PreToolUse)
                .with_matcher(HookMatcher::tool("Bash"))
                .with_config(HookConfig {
                    priority: 5,
                    ..Default::default()
                }),
        );
        engine.register(Hook::new("hook-3", HookEventType::PostToolUse));
        let event = make_pre_tool_event("s1", "Bash");
        let matching = engine.matching_hooks(&event);
        assert_eq!(matching.len(), 2);
        assert_eq!(matching[0].id, "hook-2");
        assert_eq!(matching[1].id, "hook-1");
    }

    #[tokio::test]
    async fn test_engine_fire_no_hooks() {
        let engine = HookEngine::new();
        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_no_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;
        assert!(result.is_continue());
    }

    /// Handler that always lets the event through unchanged.
    struct ContinueHandler;

    impl HookHandler for ContinueHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::continue_()
        }
    }

    /// Handler that always blocks with a fixed reason.
    struct BlockHandler {
        reason: String,
    }

    impl HookHandler for BlockHandler {
        fn handle(&self, _event: &HookEvent) -> HookResponse {
            HookResponse::block(&self.reason)
        }
    }

    #[tokio::test]
    async fn test_engine_fire_with_continue_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler("test-hook", Arc::new(ContinueHandler));
        let event = make_pre_tool_event("s1", "Bash");
        let result = engine.fire(&event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_with_block_handler() {
        let engine = HookEngine::new();
        engine.register(Hook::new("test-hook", HookEventType::PreToolUse));
        engine.register_handler(
            "test-hook",
            Arc::new(BlockHandler {
                reason: "Dangerous command".to_string(),
            }),
        );
        let event = make_pre_tool_event("s1", "Bash");
        match engine.fire(&event).await {
            HookResult::Block(reason) => assert_eq!(reason, "Dangerous command"),
            other => panic!("expected Block, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_engine_fire_priority_order() {
        // Two hooks block with different reasons, so the reason reveals which one ran first.
        // A leading Continue hook checks that Continue does not stop the chain.
        let engine = HookEngine::new();
        engine.register(
            Hook::new("late-block", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 10,
                ..Default::default()
            }),
        );
        engine.register(
            Hook::new("early-block", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 5,
                ..Default::default()
            }),
        );
        engine.register(
            Hook::new("continue-hook", HookEventType::PreToolUse).with_config(HookConfig {
                priority: 1,
                ..Default::default()
            }),
        );
        engine.register_handler(
            "late-block",
            Arc::new(BlockHandler {
                reason: "Blocked second".to_string(),
            }),
        );
        engine.register_handler(
            "early-block",
            Arc::new(BlockHandler {
                reason: "Blocked first".to_string(),
            }),
        );
        engine.register_handler("continue-hook", Arc::new(ContinueHandler));
        let event = make_pre_tool_event("s1", "Bash");
        match engine.fire(&event).await {
            HookResult::Block(reason) => assert_eq!(reason, "Blocked first"),
            other => panic!("expected Block, got {:?}", other),
        }
    }

    #[test]
    fn test_hook_serialization() {
        let hook = Hook::new("test-hook", HookEventType::PreToolUse)
            .with_matcher(HookMatcher::tool("Bash"))
            .with_config(HookConfig {
                priority: 50,
                timeout_ms: 5000,
                async_execution: true,
                max_retries: 3,
            });
        let json = serde_json::to_string(&hook).unwrap();
        assert!(json.contains("test-hook"));
        assert!(json.contains("pre_tool_use"));
        assert!(json.contains("Bash"));
        let parsed: Hook = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "test-hook");
        assert_eq!(parsed.event_type, HookEventType::PreToolUse);
        assert_eq!(parsed.config.priority, 50);
        assert_eq!(parsed.config.timeout_ms, 5000);
        assert!(parsed.config.async_execution);
        assert_eq!(parsed.config.max_retries, 3);
    }

    #[test]
    fn test_all_hooks() {
        let engine = HookEngine::new();
        engine.register(Hook::new("hook-1", HookEventType::PreToolUse));
        engine.register(Hook::new("hook-2", HookEventType::PostToolUse));
        let all = engine.all_hooks();
        assert_eq!(all.len(), 2);
    }

    /// Builds a `SkillLoad` event for `skill_name` exposing `tools`.
    fn make_skill_load_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillLoad(crate::hooks::events::SkillLoadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            version: Some("1.0.0".to_string()),
            description: Some("Test skill".to_string()),
            loaded_at: 1234567890,
        })
    }

    /// Builds a `SkillUnload` event for `skill_name` exposing `tools`.
    fn make_skill_unload_event(skill_name: &str, tools: Vec<&str>) -> HookEvent {
        HookEvent::SkillUnload(crate::hooks::events::SkillUnloadEvent {
            skill_name: skill_name.to_string(),
            tool_names: tools.iter().map(|s| s.to_string()).collect(),
            duration_ms: 60000,
        })
    }

    #[tokio::test]
    async fn test_engine_fire_skill_load() {
        let engine = HookEngine::new();
        engine.register(Hook::new("skill-load-hook", HookEventType::SkillLoad));
        engine.register_handler("skill-load-hook", Arc::new(ContinueHandler));
        let event = make_skill_load_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_fire_skill_unload() {
        let engine = HookEngine::new();
        engine.register(Hook::new("skill-unload-hook", HookEventType::SkillUnload));
        engine.register_handler("skill-unload-hook", Arc::new(ContinueHandler));
        let event = make_skill_unload_event("my-skill", vec!["tool1", "tool2"]);
        let result = engine.fire(&event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_with_matcher() {
        let engine = HookEngine::new();
        engine.register(
            Hook::new("specific-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("my-skill")),
        );
        engine.register_handler(
            "specific-skill-hook",
            Arc::new(BlockHandler {
                reason: "Skill blocked".to_string(),
            }),
        );
        let matching_event = make_skill_load_event("my-skill", vec!["tool1"]);
        let result = engine.fire(&matching_event).await;
        assert!(result.is_block());
        let non_matching_event = make_skill_load_event("other-skill", vec!["tool1"]);
        let result = engine.fire(&non_matching_event).await;
        assert!(result.is_continue());
    }

    #[tokio::test]
    async fn test_engine_skill_hook_pattern_matcher() {
        let engine = HookEngine::new();
        engine.register(
            Hook::new("test-skill-hook", HookEventType::SkillLoad)
                .with_matcher(HookMatcher::skill("test-*")),
        );
        engine.register_handler(
            "test-skill-hook",
            Arc::new(BlockHandler {
                reason: "Test skill blocked".to_string(),
            }),
        );
        let test_skill = make_skill_load_event("test-alpha", vec!["tool1"]);
        let result = engine.fire(&test_skill).await;
        assert!(result.is_block());
        let test_skill2 = make_skill_load_event("test-beta", vec!["tool1"]);
        let result = engine.fire(&test_skill2).await;
        assert!(result.is_block());
        let prod_skill = make_skill_load_event("prod-skill", vec!["tool1"]);
        let result = engine.fire(&prod_skill).await;
        assert!(result.is_continue());
    }
}