use std::sync::Arc;
use crate::error::Result;
use crate::types::{HookCallback, HookContext, HookDecision, HookMatcher, HookOutput};
/// Registry of hook matchers that dispatches events to the hooks whose matcher accepts
/// the event's tool name.
///
/// Matchers are evaluated in registration order, and the hooks inside a matcher in the
/// order they were added.
pub struct HookManager {
    matchers: Vec<HookMatcher>,
}

impl HookManager {
    /// Creates a manager with no registered matchers.
    #[must_use]
    pub fn new() -> Self {
        Self {
            matchers: Vec::new(),
        }
    }

    /// Registers `matcher`; it is evaluated after every previously registered matcher.
    pub fn register(&mut self, matcher: HookMatcher) {
        self.matchers.push(matcher);
    }

    /// Runs every hook whose matcher accepts `tool_name` and merges their outputs.
    ///
    /// Hooks run sequentially, each receiving its own clone of `event_data`, `tool_name`
    /// and `context`. For each of `decision`, `system_message` and
    /// `hook_specific_output`, a hook returning `Some(..)` overwrites the value merged so
    /// far (last writer wins), while `None` leaves it untouched.
    ///
    /// Once the merged decision is [`HookDecision::Block`], the remaining hooks are
    /// skipped and the merged output is returned immediately.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a hook; no further hooks are run.
    pub async fn invoke(
        &self,
        event_data: serde_json::Value,
        tool_name: Option<String>,
        context: HookContext,
    ) -> Result<HookOutput> {
        let mut output = HookOutput::default();

        let hooks = self
            .matchers
            .iter()
            .filter(|m| Self::matches(&m.matcher, &tool_name))
            .flat_map(|m| m.hooks.iter());

        for hook in hooks {
            let result = hook(event_data.clone(), tool_name.clone(), context.clone()).await?;

            if let Some(decision) = result.decision {
                output.decision = Some(decision);
            }
            if let Some(message) = result.system_message {
                output.system_message = Some(message);
            }
            if let Some(specific) = result.hook_specific_output {
                output.hook_specific_output = Some(specific);
            }

            // A blocking decision is final: stop before any later hook can run.
            if matches!(output.decision, Some(HookDecision::Block)) {
                return Ok(output);
            }
        }

        Ok(output)
    }

    /// Returns whether a matcher pattern applies to `tool_name`.
    ///
    /// * No pattern (`None`) matches every event, with or without a tool name.
    /// * A pattern never matches an event that carries no tool name.
    /// * Otherwise the pattern is a `|`-separated list of alternatives. It matches when
    ///   any alternative, after trimming surrounding whitespace, is the wildcard `*` or
    ///   equals the tool name exactly. Trimming lets `"Write | Edit"` behave like
    ///   `"Write|Edit"` instead of silently matching nothing.
    fn matches(matcher: &Option<String>, tool_name: &Option<String>) -> bool {
        match (matcher.as_deref(), tool_name.as_deref()) {
            (None, _) => true,
            (Some(_), None) => false,
            (Some(pattern), Some(name)) => pattern
                .split('|')
                .map(str::trim)
                .any(|alternative| alternative == "*" || alternative == name),
        }
    }

    /// Wraps an async closure as a [`HookCallback`].
    ///
    /// The closure's future is boxed and pinned, so hooks with different concrete future
    /// types can all be stored as the single `HookCallback` type.
    pub fn callback<F, Fut>(f: F) -> HookCallback
    where
        F: Fn(serde_json::Value, Option<String>, HookContext) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<HookOutput>> + Send + 'static,
    {
        Arc::new(move |event_data, tool_name, context| {
            Box::pin(f(event_data, tool_name, context))
        })
    }
}

impl Default for HookManager {
    /// Equivalent to [`HookManager::new`]: no matchers registered.
    fn default() -> Self {
        Self::new()
    }
}
/// Fluent builder for a [`HookMatcher`]: an optional tool-name pattern plus the hooks
/// that pattern triggers.
pub struct HookMatcherBuilder {
    matcher: Option<String>,
    hooks: Vec<HookCallback>,
}

impl HookMatcherBuilder {
    /// Starts a builder for `pattern`; a `None` pattern matches every event.
    ///
    /// A bare `None` argument needs a type annotation (e.g. `None::<&str>`) because the
    /// parameter is generic; [`HookMatcherBuilder::match_all`] avoids that.
    #[must_use]
    pub fn new(pattern: Option<impl Into<String>>) -> Self {
        Self {
            matcher: pattern.map(Into::into),
            hooks: Vec::new(),
        }
    }

    /// Starts a builder with no pattern, so its hooks run for every event.
    ///
    /// Equivalent to `HookMatcherBuilder::new(None::<String>)`.
    #[must_use]
    pub fn match_all() -> Self {
        Self {
            matcher: None,
            hooks: Vec::new(),
        }
    }

    /// Appends `hook`; hooks are stored (and later run) in the order they are added.
    #[must_use]
    pub fn add_hook(mut self, hook: HookCallback) -> Self {
        self.hooks.push(hook);
        self
    }

    /// Consumes the builder and produces the configured [`HookMatcher`].
    #[must_use]
    pub fn build(self) -> HookMatcher {
        HookMatcher {
            matcher: self.matcher,
            hooks: self.hooks,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a hook that increments `counter` every time it runs and returns `output()`.
    fn counting_hook(counter: &Arc<AtomicUsize>, output: fn() -> HookOutput) -> HookCallback {
        let counter = Arc::clone(counter);
        HookManager::callback(move |_event_data, _tool_name, _context| {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(output())
            }
        })
    }

    /// Hook output carrying a blocking decision.
    fn blocking_output() -> HookOutput {
        let mut output = HookOutput::default();
        output.decision = Some(HookDecision::Block);
        output
    }

    #[tokio::test]
    async fn test_hook_manager() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mut manager = HookManager::new();
        manager.register(
            HookMatcherBuilder::new(Some("*"))
                .add_hook(counting_hook(&calls, HookOutput::default))
                .build(),
        );

        let result = manager
            .invoke(serde_json::json!({}), Some("test".to_string()), HookContext {})
            .await;

        assert!(result.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 1, "matching hook must run exactly once");
    }

    #[tokio::test]
    async fn test_non_matching_hooks_are_skipped() {
        let calls = Arc::new(AtomicUsize::new(0));
        let mut manager = HookManager::new();
        manager.register(
            HookMatcherBuilder::new(Some("Bash"))
                .add_hook(counting_hook(&calls, HookOutput::default))
                .build(),
        );

        let result = manager
            .invoke(serde_json::json!({}), Some("Write".to_string()), HookContext {})
            .await;

        assert!(result.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_block_short_circuits_remaining_hooks() {
        let blocker_calls = Arc::new(AtomicUsize::new(0));
        let later_calls = Arc::new(AtomicUsize::new(0));
        let mut manager = HookManager::new();
        manager.register(
            HookMatcherBuilder::new(Some("*"))
                .add_hook(counting_hook(&blocker_calls, blocking_output))
                .add_hook(counting_hook(&later_calls, HookOutput::default))
                .build(),
        );
        manager.register(
            HookMatcherBuilder::new(None::<String>)
                .add_hook(counting_hook(&later_calls, HookOutput::default))
                .build(),
        );

        let result = manager
            .invoke(serde_json::json!({}), Some("Bash".to_string()), HookContext {})
            .await;
        let output = match result {
            Ok(output) => output,
            Err(_) => panic!("invoke should succeed"),
        };

        assert!(matches!(output.decision, Some(HookDecision::Block)));
        assert_eq!(blocker_calls.load(Ordering::SeqCst), 1);
        assert_eq!(
            later_calls.load(Ordering::SeqCst),
            0,
            "hooks after a Block decision must not run"
        );
    }

    #[test]
    fn test_matcher_wildcard() {
        assert!(HookManager::matches(
            &Some("*".to_string()),
            &Some("any_tool".to_string())
        ));
        assert!(HookManager::matches(&None, &Some("any_tool".to_string())));
        assert!(HookManager::matches(&None, &None));
    }

    #[test]
    fn test_matcher_requires_tool_name() {
        assert!(!HookManager::matches(&Some("*".to_string()), &None));
        assert!(!HookManager::matches(&Some("Bash".to_string()), &None));
    }

    #[test]
    fn test_matcher_specific() {
        assert!(HookManager::matches(
            &Some("Bash".to_string()),
            &Some("Bash".to_string())
        ));
        assert!(!HookManager::matches(
            &Some("Bash".to_string()),
            &Some("Write".to_string())
        ));
    }

    #[test]
    fn test_matcher_pattern() {
        assert!(HookManager::matches(
            &Some("Write|Edit".to_string()),
            &Some("Write".to_string())
        ));
        assert!(HookManager::matches(
            &Some("Write|Edit".to_string()),
            &Some("Edit".to_string())
        ));
        assert!(!HookManager::matches(
            &Some("Write|Edit".to_string()),
            &Some("Bash".to_string())
        ));
    }
}