use serde_json::Value;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
/// Outcome of a `before_*` hook: either hand a (possibly rewritten) value on to
/// the next stage, or stop with a reason string.
#[derive(Debug, Clone)]
pub enum HookResult<T> {
    /// Proceed, passing the contained value onward.
    Continue(T),
    /// Stop processing; the string describes why.
    Cancel(String),
}

impl<T> HookResult<T> {
    /// Returns `true` for [`HookResult::Cancel`] and `false` for
    /// [`HookResult::Continue`].
    pub fn is_cancel(&self) -> bool {
        match self {
            HookResult::Continue(_) => false,
            HookResult::Cancel(_) => true,
        }
    }
}
/// Extension point invoked by `HookRegistry` around LLM calls, tool calls and
/// session lifecycle events.
///
/// Every method except [`HookHandler::name`] has a default: the `before_*`
/// hooks pass their inputs through unchanged and the `on_*` notifications do
/// nothing, so implementors override only what they need.
#[async_trait::async_trait]
pub trait HookHandler: Send + Sync {
    /// Identifier for this hook; the registry includes it in log output when
    /// the hook cancels a call.
    fn name(&self) -> &str;

    /// Ordering key. The registry keeps hooks sorted ascending by this value
    /// and runs them in that order. Defaults to `0`.
    fn priority(&self) -> i32 {
        0
    }

    /// Inspect or rewrite the outgoing messages and model name.
    ///
    /// Returning [`HookResult::Cancel`] makes `HookRegistry::run_before_llm_call`
    /// return `None`. The default passes both values through untouched.
    async fn before_llm_call(
        &self,
        messages: Vec<Value>,
        model: String,
    ) -> HookResult<(Vec<Value>, String)> {
        let unchanged = (messages, model);
        HookResult::Continue(unchanged)
    }

    /// Inspect or rewrite a tool invocation (tool name and JSON arguments).
    ///
    /// Returning [`HookResult::Cancel`] makes `HookRegistry::run_before_tool_call`
    /// return `None`. The default passes both values through untouched.
    async fn before_tool_call(
        &self,
        tool_name: String,
        args: Value,
    ) -> HookResult<(String, Value)> {
        let unchanged = (tool_name, args);
        HookResult::Continue(unchanged)
    }

    /// Notification that a session identified by `_session_id` started on
    /// `_channel`. No-op by default.
    async fn on_session_start(&self, _session_id: &str, _channel: &str) {}

    /// Notification that a session ended. No-op by default.
    async fn on_session_end(&self, _session_id: &str) {}

    /// Notification after a tool call completed, with its outcome and elapsed
    /// time. No-op by default.
    async fn on_after_tool_call(&self, _tool: &str, _success: bool, _duration: Duration) {}

    /// Notification after an LLM response arrived, with a token count and
    /// elapsed time. No-op by default.
    async fn on_llm_response(&self, _model: &str, _token_count: u32, _duration: Duration) {}
}
/// Ordered collection of [`HookHandler`]s.
///
/// Hooks are kept sorted by ascending [`HookHandler::priority`]; hooks with
/// equal priority run in the order they were registered.
pub struct HookRegistry {
    hooks: Vec<Arc<dyn HookHandler>>,
}

impl HookRegistry {
    /// Creates a registry with no hooks.
    pub fn new() -> Self {
        Self { hooks: Vec::new() }
    }

    /// Adds `handler`, keeping the hook list ordered by ascending priority.
    ///
    /// `sort_by_key` is stable, so the new hook lands after any existing hooks
    /// with the same priority.
    pub fn register(&mut self, handler: Arc<dyn HookHandler>) {
        // Log the priority of the handler being registered. After sorting,
        // `self.hooks.last()` is the largest-priority hook in the registry,
        // which is generally not the one just added.
        let priority = handler.priority();
        tracing::debug!(hook = handler.name(), "hook registered (priority {})", priority);
        self.hooks.push(handler);
        self.hooks.sort_by_key(|h| h.priority());
    }

    /// Threads `messages` and `model` through every hook's
    /// [`HookHandler::before_llm_call`] in priority order. Each hook receives
    /// the previous hook's output.
    ///
    /// Returns the final `(messages, model)`, or `None` as soon as a hook
    /// returns [`HookResult::Cancel`]; hooks after it are not invoked.
    pub async fn run_before_llm_call(
        &self,
        messages: Vec<Value>,
        model: String,
    ) -> Option<(Vec<Value>, String)> {
        let mut current = (messages, model);
        for hook in &self.hooks {
            // Move the values into the hook rather than cloning them: `current`
            // is always overwritten by the hook's output (or we return), so the
            // message history need not be copied once per hook.
            let (messages, model) = current;
            current = match hook.before_llm_call(messages, model).await {
                HookResult::Continue(next) => next,
                HookResult::Cancel(reason) => {
                    tracing::info!(hook = hook.name(), "before_llm_call cancelled: {}", reason);
                    return None;
                }
            };
        }
        Some(current)
    }

    /// Threads `tool_name` and `args` through every hook's
    /// [`HookHandler::before_tool_call`] in priority order. Each hook receives
    /// the previous hook's output.
    ///
    /// Returns the final `(tool_name, args)`, or `None` as soon as a hook
    /// returns [`HookResult::Cancel`]; hooks after it are not invoked.
    pub async fn run_before_tool_call(
        &self,
        tool_name: String,
        args: Value,
    ) -> Option<(String, Value)> {
        let mut current = (tool_name, args);
        for hook in &self.hooks {
            // Moved, not cloned; see `run_before_llm_call`.
            let (tool_name, args) = current;
            current = match hook.before_tool_call(tool_name, args).await {
                HookResult::Continue(next) => next,
                HookResult::Cancel(reason) => {
                    tracing::info!(hook = hook.name(), "before_tool_call cancelled: {}", reason);
                    return None;
                }
            };
        }
        Some(current)
    }

    /// Calls [`HookHandler::on_session_start`] on each hook in priority order,
    /// awaiting each one before calling the next.
    pub async fn fire_session_start(&self, session_id: &str, channel: &str) {
        for hook in &self.hooks {
            hook.on_session_start(session_id, channel).await;
        }
    }

    /// Calls [`HookHandler::on_session_end`] on each hook in priority order,
    /// awaiting each one before calling the next.
    pub async fn fire_session_end(&self, session_id: &str) {
        for hook in &self.hooks {
            hook.on_session_end(session_id).await;
        }
    }

    /// Calls [`HookHandler::on_after_tool_call`] on each hook in priority
    /// order, awaiting each one before calling the next.
    pub async fn fire_after_tool_call(&self, tool: &str, success: bool, duration: Duration) {
        for hook in &self.hooks {
            hook.on_after_tool_call(tool, success, duration).await;
        }
    }

    /// Calls [`HookHandler::on_llm_response`] on each hook in priority order,
    /// awaiting each one before calling the next.
    pub async fn fire_llm_response(&self, model: &str, token_count: u32, duration: Duration) {
        for hook in &self.hooks {
            hook.on_llm_response(model, token_count, duration).await;
        }
    }

    /// Number of registered hooks.
    pub fn len(&self) -> usize {
        self.hooks.len()
    }

    /// Returns `true` when no hooks are registered.
    pub fn is_empty(&self) -> bool {
        self.hooks.is_empty()
    }
}
impl Default for HookRegistry {
fn default() -> Self {
Self::new()
}
}
/// Reference-counted handle to a [`HookRegistry`] guarded by an async
/// `tokio::sync::RwLock`.
pub type SharedHookRegistry = Arc<RwLock<HookRegistry>>;

/// Creates a [`SharedHookRegistry`] wrapping an empty registry.
pub fn new_hook_registry() -> SharedHookRegistry {
    let empty = HookRegistry::new();
    let guarded = RwLock::new(empty);
    Arc::new(guarded)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hook that overrides nothing except its name and priority.
    struct NoopHook {
        name: String,
        priority: i32,
    }

    #[async_trait::async_trait]
    impl HookHandler for NoopHook {
        fn name(&self) -> &str {
            &self.name
        }
        fn priority(&self) -> i32 {
            self.priority
        }
    }

    #[tokio::test]
    async fn empty_registry_passes_through() {
        let reg = HookRegistry::new();
        let result = reg
            .run_before_tool_call("shell".into(), serde_json::json!({}))
            .await;
        assert!(result.is_some());
        let (name, _) = result.unwrap();
        assert_eq!(name, "shell");
    }

    #[tokio::test]
    async fn hook_result_cancel_short_circuits() {
        struct CancelHook;
        #[async_trait::async_trait]
        impl HookHandler for CancelHook {
            fn name(&self) -> &str {
                "cancel"
            }
            async fn before_tool_call(
                &self,
                _name: String,
                _args: Value,
            ) -> HookResult<(String, Value)> {
                HookResult::Cancel("blocked".into())
            }
        }

        // Sorts after CancelHook (priority 1 > 0). If the registry kept going
        // after a cancellation, this hook would panic and fail the test.
        struct MustNotRun;
        #[async_trait::async_trait]
        impl HookHandler for MustNotRun {
            fn name(&self) -> &str {
                "must-not-run"
            }
            fn priority(&self) -> i32 {
                1
            }
            async fn before_tool_call(
                &self,
                _name: String,
                _args: Value,
            ) -> HookResult<(String, Value)> {
                panic!("hook ran after an earlier hook cancelled");
            }
        }

        let mut reg = HookRegistry::new();
        reg.register(Arc::new(MustNotRun) as Arc<dyn HookHandler>);
        reg.register(Arc::new(CancelHook) as Arc<dyn HookHandler>);
        let result = reg
            .run_before_tool_call("shell".into(), serde_json::json!({}))
            .await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn before_llm_call_chains_in_priority_order() {
        /// Appends `suffix` to the model name and pushes it onto the messages.
        struct AppendHook {
            suffix: &'static str,
            priority: i32,
        }
        #[async_trait::async_trait]
        impl HookHandler for AppendHook {
            fn name(&self) -> &str {
                self.suffix
            }
            fn priority(&self) -> i32 {
                self.priority
            }
            async fn before_llm_call(
                &self,
                messages: Vec<Value>,
                model: String,
            ) -> HookResult<(Vec<Value>, String)> {
                let mut messages = messages;
                messages.push(serde_json::json!(self.suffix));
                HookResult::Continue((messages, format!("{}{}", model, self.suffix)))
            }
        }

        let mut reg = HookRegistry::new();
        reg.register(Arc::new(AppendHook {
            suffix: "-b",
            priority: 5,
        }) as Arc<dyn HookHandler>);
        reg.register(Arc::new(AppendHook {
            suffix: "-a",
            priority: 1,
        }) as Arc<dyn HookHandler>);

        let (messages, model) = reg
            .run_before_llm_call(Vec::new(), "gpt".into())
            .await
            .expect("no hook cancels");
        // The lower priority value runs first, and each hook sees the
        // previous hook's output.
        assert_eq!(model, "gpt-a-b");
        assert_eq!(
            messages,
            vec![serde_json::json!("-a"), serde_json::json!("-b")]
        );
    }

    #[tokio::test]
    async fn hooks_sorted_by_priority() {
        let mut reg = HookRegistry::new();
        reg.register(Arc::new(NoopHook {
            name: "low".into(),
            priority: 10,
        }) as Arc<dyn HookHandler>);
        reg.register(Arc::new(NoopHook {
            name: "high".into(),
            priority: -10,
        }) as Arc<dyn HookHandler>);
        assert_eq!(reg.hooks[0].priority(), -10);
        assert_eq!(reg.hooks[1].priority(), 10);
    }
}