use anyhow::Result;
use async_trait::async_trait;
use serde_json::Value;
use std::sync::Arc;
/// Outcome of a [`ToolHook::pre_execute`] call.
///
/// `PartialEq` is derived (`serde_json::Value` supports it) so results can be
/// compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq)]
pub enum HookResult {
    /// Proceed without changing the parameters.
    Continue,
    /// Stop the hook chain; [`HookRegistry::pre_execute`] returns this result
    /// immediately without consulting later hooks.
    Block {
        /// Explanation attached to the block.
        reason: String,
        /// Optional additional context for the block.
        details: Option<String>,
    },
    /// Replace the parameters with this value; later hooks in the chain see the
    /// replacement instead of the original parameters.
    Modify(Value),
}
/// A hook that can inspect, rewrite, or veto tool invocations.
///
/// Hooks are stored as trait objects in a [`HookRegistry`] and invoked around
/// tool execution; they must be `Send + Sync` so they can be shared.
#[async_trait]
pub trait ToolHook: Send + Sync {
    /// Name of this hook (not consulted by `HookRegistry` itself).
    fn name(&self) -> &str;

    /// Whether this hook should run at all; disabled hooks are skipped by the registry.
    fn is_enabled(&self) -> bool;

    /// Tool names this hook targets. The default empty list means "every tool".
    fn applies_to(&self) -> Vec<&str> {
        Vec::new()
    }

    /// Returns `true` if this hook should run for `tool_name`.
    ///
    /// An empty [`applies_to`](Self::applies_to) list matches every tool;
    /// otherwise an exact name match is required.
    fn applies_to_tool(&self, tool_name: &str) -> bool {
        let targets = self.applies_to();
        if targets.is_empty() {
            return true;
        }
        targets.into_iter().any(|target| target == tool_name)
    }

    /// Called before the tool runs.
    ///
    /// Return `Continue` to leave the parameters alone, `Modify` to replace the
    /// parameters seen by subsequent hooks, or `Block` to stop the chain. An
    /// `Err` aborts the registry's hook chain.
    async fn pre_execute(&self, tool_name: &str, params: &Value) -> Result<HookResult>;

    /// Called after the tool runs with its textual result.
    ///
    /// Returns the (possibly rewritten) result that is handed to the next hook.
    async fn post_execute(&self, tool_name: &str, params: &Value, result: &str) -> Result<String>;
}
/// Ordered collection of [`ToolHook`]s applied around tool execution.
///
/// The derived `Default` produces an empty registry, identical to
/// `HookRegistry::new()`.
#[derive(Default)]
pub struct HookRegistry {
    /// Registered hooks, kept in registration order.
    hooks: Vec<Box<dyn ToolHook>>,
}
impl HookRegistry {
    /// Creates an empty registry with no hooks.
    pub fn new() -> Self {
        Self { hooks: Vec::new() }
    }

    /// Creates the registry used as the process-wide default.
    ///
    /// No built-in hooks are registered here, so this is currently identical
    /// to [`HookRegistry::new`].
    pub fn with_defaults() -> Self {
        Self::new()
    }

    /// Appends `hook` to the end of the chain; hooks run in registration order.
    pub fn register(&mut self, hook: Box<dyn ToolHook>) {
        self.hooks.push(hook);
    }

    /// Returns every registered hook (enabled or not) in registration order.
    pub fn hooks(&self) -> &[Box<dyn ToolHook>] {
        &self.hooks
    }

    /// Runs `pre_execute` on every enabled hook that applies to `tool_name`, in order.
    ///
    /// Each `Modify` replaces the parameters passed to subsequent hooks. The
    /// first `Block` is returned as-is and later hooks are not run.
    ///
    /// Returns `Modify` carrying the final parameters if they differ from
    /// `params`, and `Continue` otherwise.
    ///
    /// # Errors
    /// Propagates the first error returned by a hook; later hooks are not run.
    pub async fn pre_execute(&self, tool_name: &str, params: &Value) -> Result<HookResult> {
        // Only hold an owned copy once a hook actually rewrites the params, so
        // the common "nothing modified" path never clones `params`.
        let mut rewritten: Option<Value> = None;
        let active = self
            .hooks
            .iter()
            .filter(|hook| hook.is_enabled() && hook.applies_to_tool(tool_name));
        for hook in active {
            let outcome = {
                let current = rewritten.as_ref().unwrap_or(params);
                hook.pre_execute(tool_name, current).await?
            };
            match outcome {
                blocked @ HookResult::Block { .. } => return Ok(blocked),
                HookResult::Modify(new_params) => rewritten = Some(new_params),
                HookResult::Continue => {}
            }
        }
        // Rewrites that net out to the original params are reported as
        // `Continue`, matching the "no effective change" case.
        Ok(match rewritten {
            Some(new_params) if new_params != *params => HookResult::Modify(new_params),
            _ => HookResult::Continue,
        })
    }

    /// Threads `result` through `post_execute` of every enabled hook that
    /// applies to `tool_name`, in order, and returns the final text.
    ///
    /// # Errors
    /// Propagates the first error returned by a hook; later hooks are not run.
    pub async fn post_execute(&self, tool_name: &str, params: &Value, result: &str) -> Result<String> {
        let mut current_result = result.to_string();
        let active = self
            .hooks
            .iter()
            .filter(|hook| hook.is_enabled() && hook.applies_to_tool(tool_name));
        for hook in active {
            current_result = hook.post_execute(tool_name, params, &current_result).await?;
        }
        Ok(current_result)
    }
}
/// Lazily-initialised process-wide hook registry; once set it never changes.
static GLOBAL_HOOK_REGISTRY: std::sync::OnceLock<Arc<HookRegistry>> = std::sync::OnceLock::new();

/// Returns a shared handle to the process-wide hook registry.
///
/// If no registry has been installed yet, the first call initialises it with
/// [`HookRegistry::with_defaults`]. After that, attempts to install a different
/// registry are rejected.
pub fn global_hook_registry() -> Arc<HookRegistry> {
    Arc::clone(GLOBAL_HOOK_REGISTRY.get_or_init(|| Arc::new(HookRegistry::with_defaults())))
}

/// Installs `registry` as the process-wide hook registry if none is installed yet.
///
/// This is a silent no-op when the global registry is already initialised.
/// That includes implicit initialisation by an earlier [`global_hook_registry`]
/// call. Use [`try_set_global_hook_registry`] to detect that case.
pub fn set_global_hook_registry(registry: HookRegistry) {
    // Best-effort by design: a rejected registry is simply dropped.
    let _ = try_set_global_hook_registry(registry);
}

/// Installs `registry` as the process-wide hook registry.
///
/// # Errors
/// Returns the rejected `registry` unchanged if the global registry was already
/// initialised, either by a previous set or by [`global_hook_registry`].
pub fn try_set_global_hook_registry(registry: HookRegistry) -> Result<(), HookRegistry> {
    GLOBAL_HOOK_REGISTRY
        .set(Arc::new(registry))
        .map_err(|rejected| match Arc::try_unwrap(rejected) {
            Ok(registry) => registry,
            // `OnceLock::set` hands back the very `Arc` created above, which was
            // never shared, so its strong count is exactly one.
            Err(_) => unreachable!("rejected global hook registry Arc was never shared"),
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hook scoped to the `write` tool: optionally blocks in `pre_execute`
    /// and tags results in `post_execute`.
    struct TestHook {
        enabled: bool,
        block: bool,
    }

    #[async_trait]
    impl ToolHook for TestHook {
        fn name(&self) -> &str {
            "test_hook"
        }
        fn is_enabled(&self) -> bool {
            self.enabled
        }
        fn applies_to(&self) -> Vec<&str> {
            vec!["write"]
        }
        async fn pre_execute(&self, _tool_name: &str, _params: &Value) -> Result<HookResult> {
            if self.block {
                Ok(HookResult::Block {
                    reason: "Test block".to_string(),
                    details: Some("Test details".to_string()),
                })
            } else {
                Ok(HookResult::Continue)
            }
        }
        async fn post_execute(&self, _tool_name: &str, _params: &Value, result: &str) -> Result<String> {
            Ok(format!("{} [hooked]", result))
        }
    }

    /// Hook for every tool that unconditionally replaces the params.
    struct ModifyHook {
        replacement: Value,
    }

    #[async_trait]
    impl ToolHook for ModifyHook {
        fn name(&self) -> &str {
            "modify_hook"
        }
        fn is_enabled(&self) -> bool {
            true
        }
        async fn pre_execute(&self, _tool_name: &str, _params: &Value) -> Result<HookResult> {
            Ok(HookResult::Modify(self.replacement.clone()))
        }
        async fn post_execute(&self, _tool_name: &str, _params: &Value, result: &str) -> Result<String> {
            Ok(result.to_string())
        }
    }

    /// Hook for every tool that blocks, reporting the params it received as the reason.
    struct EchoParamsHook;

    #[async_trait]
    impl ToolHook for EchoParamsHook {
        fn name(&self) -> &str {
            "echo_params_hook"
        }
        fn is_enabled(&self) -> bool {
            true
        }
        async fn pre_execute(&self, _tool_name: &str, params: &Value) -> Result<HookResult> {
            Ok(HookResult::Block {
                reason: params.to_string(),
                details: None,
            })
        }
        async fn post_execute(&self, _tool_name: &str, _params: &Value, result: &str) -> Result<String> {
            Ok(result.to_string())
        }
    }

    #[tokio::test]
    async fn test_hook_registry_pre_execute_continue() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: false }));
        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));
    }

    #[tokio::test]
    async fn test_hook_registry_pre_execute_block() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: true }));
        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Block { .. }));
    }

    #[tokio::test]
    async fn test_hook_registry_disabled_hook() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: false, block: true }));
        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));
    }

    #[tokio::test]
    async fn test_hook_registry_tool_filter() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: true }));
        let result = registry.pre_execute("read", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));
        let result = registry.pre_execute("write", &serde_json::json!({})).await;
        assert!(matches!(result.unwrap(), HookResult::Block { .. }));
    }

    #[tokio::test]
    async fn test_hook_registry_post_execute() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: false }));
        let result = registry.post_execute("write", &serde_json::json!({}), "original").await;
        assert_eq!(result.unwrap(), "original [hooked]");
    }

    #[tokio::test]
    async fn test_hook_registry_post_execute_tool_filter() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(TestHook { enabled: true, block: false }));
        let result = registry.post_execute("read", &serde_json::json!({}), "original").await;
        assert_eq!(result.unwrap(), "original");
    }

    #[tokio::test]
    async fn test_empty_registry_passes_through() {
        let registry = HookRegistry::new();
        let result = registry.pre_execute("write", &serde_json::json!({"x": 1})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));
        let result = registry.post_execute("write", &serde_json::json!({}), "original").await;
        assert_eq!(result.unwrap(), "original");
    }

    #[tokio::test]
    async fn test_hook_registry_pre_execute_modify() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(ModifyHook {
            replacement: serde_json::json!({"a": 1}),
        }));
        let result = registry.pre_execute("read", &serde_json::json!({})).await;
        match result.unwrap() {
            HookResult::Modify(params) => assert_eq!(params, serde_json::json!({"a": 1})),
            other => panic!("expected Modify, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_hook_registry_modify_to_identical_params_is_continue() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(ModifyHook {
            replacement: serde_json::json!({"a": 1}),
        }));
        let result = registry.pre_execute("read", &serde_json::json!({"a": 1})).await;
        assert!(matches!(result.unwrap(), HookResult::Continue));
    }

    #[tokio::test]
    async fn test_hook_registry_later_hooks_see_modified_params() {
        let mut registry = HookRegistry::new();
        registry.register(Box::new(ModifyHook {
            replacement: serde_json::json!({"a": 1}),
        }));
        registry.register(Box::new(EchoParamsHook));
        let result = registry.pre_execute("read", &serde_json::json!({})).await;
        match result.unwrap() {
            HookResult::Block { reason, .. } => assert_eq!(reason, r#"{"a":1}"#),
            other => panic!("expected Block, got {:?}", other),
        }
    }
}