use async_trait::async_trait;
use serde_json::Value;
use std::sync::Arc;
use crate::{
agent::{AgentError, AgentInput, AgentOutput},
provider::types::{ChatMessage, ChatResponse},
tool::types::ToolResult,
};
/// Outcome of a hook invocation: either keep going with a value or stop with a reason.
///
/// `Continue(v)` carries the (possibly modified) value on to the next stage;
/// `Cancel(reason)` aborts and carries a message describing why.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookResult<T> {
    /// Proceed with the contained value.
    Continue(T),
    /// Abort; the string explains the cancellation.
    Cancel(String),
}

impl<T> HookResult<T> {
    /// Converts into an `Option`, discarding the cancellation reason.
    ///
    /// Returns `Some(v)` for `Continue(v)` and `None` for `Cancel(_)`.
    #[must_use]
    pub fn into_option(self) -> Option<T> {
        match self {
            HookResult::Continue(v) => Some(v),
            HookResult::Cancel(_) => None,
        }
    }

    /// Returns `true` if this is `Continue`.
    #[must_use]
    pub fn is_continue(&self) -> bool {
        matches!(self, HookResult::Continue(_))
    }

    /// Returns `true` if this is `Cancel`.
    #[must_use]
    pub fn is_cancel(&self) -> bool {
        matches!(self, HookResult::Cancel(_))
    }

    /// Transforms the `Continue` payload with `f`; a `Cancel` passes through
    /// untouched and `f` is not called.
    #[must_use]
    pub fn map<F, U>(self, f: F) -> HookResult<U>
    where
        F: FnOnce(T) -> U,
    {
        match self {
            HookResult::Continue(v) => HookResult::Continue(f(v)),
            HookResult::Cancel(msg) => HookResult::Cancel(msg),
        }
    }
}
/// Ordering key for hooks; `HookRegistry` sorts larger values first.
///
/// Wraps a plain `i32`. The named constants cover the common tiers, and any
/// other value can be built with `HookPriority(n)` or `HookPriority::from(n)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct HookPriority(pub i32);

impl HookPriority {
    /// Smallest representable priority (`i32::MIN`).
    pub const LOWEST: HookPriority = HookPriority(i32::MIN);
    /// Below-default tier (`-100`).
    pub const LOW: HookPriority = HookPriority(-100);
    /// Default tier (`0`).
    pub const NORMAL: HookPriority = HookPriority(0);
    /// Above-default tier (`100`).
    pub const HIGH: HookPriority = HookPriority(100);
    /// Largest representable priority (`i32::MAX`).
    pub const HIGHEST: HookPriority = HookPriority(i32::MAX);
}

impl Default for HookPriority {
    /// Returns [`HookPriority::NORMAL`].
    fn default() -> HookPriority {
        HookPriority::NORMAL
    }
}

impl From<i32> for HookPriority {
    /// Wraps a raw integer priority without any clamping.
    fn from(raw: i32) -> HookPriority {
        HookPriority(raw)
    }
}
/// Observer hook: receives lifecycle events but returns nothing, so it cannot
/// alter or cancel the flow.
///
/// Every event method has an empty default body; implementors override only
/// the events they care about. Implementors must be `Send + Sync`.
#[async_trait]
pub trait VoidHook: Send + Sync {
/// Name identifying this hook.
fn name(&self) -> &str;
/// Ordering key; `HookRegistry` sorts hooks with larger values first.
/// Defaults to `HookPriority::NORMAL`.
fn priority(&self) -> HookPriority {
HookPriority::NORMAL
}
/// Session-start event; receives the session id. No-op by default.
async fn on_session_start(&self, _session_id: &str) {}
/// Session-end event; receives the session id. No-op by default.
async fn on_session_end(&self, _session_id: &str) {}
/// LLM-input event; receives the outgoing messages and the model name. No-op by default.
async fn on_llm_input(&self, _messages: &[ChatMessage], _model: &str) {}
/// LLM-output event; receives the model response. No-op by default.
async fn on_llm_output(&self, _response: &ChatResponse) {}
/// Tool-call completion event: tool name, its result, and the elapsed time
/// (milliseconds, per the parameter name). No-op by default.
async fn on_after_tool_call(&self, _tool: &str, _result: &ToolResult, _duration_ms: u64) {}
/// Step-completion event: step index and that step's output. No-op by default.
async fn on_step_complete(&self, _step: usize, _output: &AgentOutput) {}
/// Error event; receives the agent error. No-op by default.
async fn on_error(&self, _error: &AgentError) {}
}
/// Hook that can rewrite or veto values passing through the agent pipeline.
///
/// Each method takes an owned value and returns `HookResult::Continue` with
/// the (possibly modified) value, or `HookResult::Cancel` with a reason. Every
/// default implementation returns its input unchanged inside `Continue`.
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait ModifyingHook: Send + Sync {
/// Name identifying this hook.
fn name(&self) -> &str;
/// Ordering key; `HookRegistry` sorts hooks with larger values first.
/// Defaults to `HookPriority::NORMAL`.
fn priority(&self) -> HookPriority {
HookPriority::NORMAL
}
/// Receives the `(provider, model)` pair and may replace either value or cancel.
async fn before_model_resolve(
&self,
provider: String,
model: String,
) -> HookResult<(String, String)> {
HookResult::Continue((provider, model))
}
/// Receives the prompt text and may rewrite it or cancel.
async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
HookResult::Continue(prompt)
}
/// Receives the message list and model name and may rewrite either or cancel.
async fn before_llm_call(
&self,
messages: Vec<ChatMessage>,
model: String,
) -> HookResult<(Vec<ChatMessage>, String)> {
HookResult::Continue((messages, model))
}
/// Receives a tool name and its JSON arguments and may rewrite either or cancel.
async fn before_tool_call(&self, name: String, args: Value) -> HookResult<(String, Value)> {
HookResult::Continue((name, args))
}
/// Receives the agent input and may rewrite it or cancel.
async fn on_input_received(&self, input: AgentInput) -> HookResult<AgentInput> {
HookResult::Continue(input)
}
/// Receives the agent output and may rewrite it or cancel.
async fn on_output_generated(&self, output: AgentOutput) -> HookResult<AgentOutput> {
HookResult::Continue(output)
}
}
/// Priority-ordered collection of [`VoidHook`]s and [`ModifyingHook`]s.
///
/// Both lists are re-sorted by descending [`HookPriority`] on every
/// registration. The sort is stable, so hooks of equal priority keep their
/// registration order.
#[derive(Default)]
pub struct HookRegistry {
    void_hooks: Vec<Arc<dyn VoidHook>>,
    modifying_hooks: Vec<Arc<dyn ModifyingHook>>,
}

impl HookRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds an observer hook, keeping the list sorted highest priority first.
    pub fn register_void(&mut self, hook: Arc<dyn VoidHook>) {
        self.void_hooks.push(hook);
        // Stable sort: equal priorities stay in registration order.
        self.void_hooks.sort_by_key(|h| std::cmp::Reverse(h.priority()));
    }

    /// Adds a modifying hook, keeping the list sorted highest priority first.
    pub fn register_modifying(&mut self, hook: Arc<dyn ModifyingHook>) {
        self.modifying_hooks.push(hook);
        // Stable sort: equal priorities stay in registration order.
        self.modifying_hooks.sort_by_key(|h| std::cmp::Reverse(h.priority()));
    }

    /// Calls `f` for every void hook and awaits all returned futures together.
    ///
    /// `f` is invoked synchronously in priority order to create the futures.
    /// The futures are then polled concurrently by `join_all`, so completion
    /// order is not tied to priority. Completes immediately when no void hooks
    /// are registered.
    ///
    /// NOTE(review): `Fut` is a single caller-chosen type, so it cannot borrow
    /// from the `&dyn VoidHook` handed to `f`. Returning `hook.on_…(…)`
    /// directly (an `async_trait` future that borrows the hook) is therefore
    /// not expected to type-check. Confirm how callers dispatch to hook methods.
    pub async fn run_void<F, Fut>(&self, f: F)
    where
        F: Fn(&dyn VoidHook) -> Fut + Send + Sync,
        Fut: std::future::Future<Output = ()> + Send,
    {
        use futures_util::future::join_all;
        // `join_all` accepts any `IntoIterator`; no intermediate Vec needed.
        join_all(self.void_hooks.iter().map(|hook| f(&**hook))).await;
    }

    /// Threads `initial` through every modifying hook in priority order.
    ///
    /// Each call to `f` receives the value produced by the previous one. The
    /// first [`HookResult::Cancel`] stops the chain and is returned as-is, so
    /// later hooks are not invoked. Otherwise the final value is returned in
    /// [`HookResult::Continue`]. With no hooks registered, `initial` comes back
    /// unchanged.
    ///
    /// The value is moved into each call rather than cloned, so `T` does not
    /// need to implement `Clone`. The same note about `Fut` borrowing from the
    /// hook as on [`HookRegistry::run_void`] applies here.
    pub async fn run_modifying<T, F, Fut>(&self, initial: T, f: F) -> HookResult<T>
    where
        F: Fn(&dyn ModifyingHook, T) -> Fut + Send + Sync,
        Fut: std::future::Future<Output = HookResult<T>> + Send,
    {
        let mut current = initial;
        for hook in &self.modifying_hooks {
            // `current` is moved in and immediately replaced by the hook's output.
            current = match f(&**hook, current).await {
                HookResult::Continue(next) => next,
                HookResult::Cancel(reason) => return HookResult::Cancel(reason),
            };
        }
        HookResult::Continue(current)
    }

    /// Number of registered void hooks.
    pub fn void_hook_count(&self) -> usize {
        self.void_hooks.len()
    }

    /// Number of registered modifying hooks.
    pub fn modifying_hook_count(&self) -> usize {
        self.modifying_hooks.len()
    }

    /// Removes every registered hook of both kinds.
    pub fn clear(&mut self) {
        self.void_hooks.clear();
        self.modifying_hooks.clear();
    }
}
/// Builder that collects hooks and then produces a [`HookRegistry`].
///
/// Hooks are only accumulated here, in insertion order. Priority ordering is
/// applied by the registry's `register_*` methods when [`CombinedHook::build`]
/// runs.
#[derive(Default)]
pub struct CombinedHook {
    void_hooks: Vec<Arc<dyn VoidHook>>,
    modifying_hooks: Vec<Arc<dyn ModifyingHook>>,
}

impl CombinedHook {
    /// Starts an empty builder.
    pub fn new() -> Self {
        Self::default()
    }

    /// Queues an observer hook and returns the builder for chaining.
    pub fn add_void(self, hook: Arc<dyn VoidHook>) -> Self {
        let CombinedHook {
            mut void_hooks,
            modifying_hooks,
        } = self;
        void_hooks.push(hook);
        CombinedHook {
            void_hooks,
            modifying_hooks,
        }
    }

    /// Queues a modifying hook and returns the builder for chaining.
    pub fn add_modifying(self, hook: Arc<dyn ModifyingHook>) -> Self {
        let CombinedHook {
            void_hooks,
            mut modifying_hooks,
        } = self;
        modifying_hooks.push(hook);
        CombinedHook {
            void_hooks,
            modifying_hooks,
        }
    }

    /// Consumes the builder and registers every queued hook: void hooks first,
    /// then modifying hooks.
    pub fn build(self) -> HookRegistry {
        let CombinedHook {
            void_hooks,
            modifying_hooks,
        } = self;
        let mut registry = HookRegistry::new();
        void_hooks
            .into_iter()
            .for_each(|hook| registry.register_void(hook));
        modifying_hooks
            .into_iter()
            .for_each(|hook| registry.register_modifying(hook));
        registry
    }
}
/// [`VoidHook`] that reports lifecycle events through `tracing`.
///
/// Logging levels by event:
/// - `info`: session start/end and completed tool calls.
/// - `debug`: LLM input/output summaries.
/// - `error`: errors.
///
/// Step completion is not overridden, so the trait's no-op default applies.
pub struct LoggingVoidHook {
    name: String,
    priority: HookPriority,
}

impl LoggingVoidHook {
    /// Creates the hook named `"logging"` at [`HookPriority::NORMAL`].
    pub fn new() -> Self {
        LoggingVoidHook {
            name: String::from("logging"),
            priority: HookPriority::NORMAL,
        }
    }

    /// Returns the hook with its priority replaced by `priority`.
    pub fn with_priority(self, priority: HookPriority) -> Self {
        LoggingVoidHook { priority, ..self }
    }
}

impl Default for LoggingVoidHook {
    fn default() -> Self {
        LoggingVoidHook::new()
    }
}

#[async_trait]
impl VoidHook for LoggingVoidHook {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn priority(&self) -> HookPriority {
        self.priority
    }

    async fn on_session_start(&self, session_id: &str) {
        tracing::info!(session_id = session_id, "hook.session.start");
    }

    async fn on_session_end(&self, session_id: &str) {
        tracing::info!(session_id = session_id, "hook.session.end");
    }

    async fn on_llm_input(&self, messages: &[ChatMessage], model: &str) {
        // Only the count is logged, never the message contents.
        let message_count = messages.len();
        tracing::debug!(message_count, model = model, "hook.llm.input");
    }

    async fn on_llm_output(&self, response: &ChatResponse) {
        // Only the content length is logged, never the content itself.
        tracing::debug!(
            content_len = response.message.content.len(),
            "hook.llm.output"
        );
    }

    async fn on_after_tool_call(&self, tool: &str, _result: &ToolResult, duration_ms: u64) {
        let tool_name = tool;
        tracing::info!(tool_name, duration_ms, "hook.tool_call.complete");
    }

    async fn on_error(&self, error: &AgentError) {
        tracing::error!(error = %error, "hook.error");
    }
}
/// [`ModifyingHook`] that rejects oversized prompts and calls to deny-listed tools.
///
/// Defaults to [`HookPriority::HIGH`]. Only `before_prompt_build` and
/// `before_tool_call` are overridden; every other hook point keeps the trait's
/// pass-through default.
pub struct ValidationModifyingHook {
    name: String,
    priority: HookPriority,
    /// Upper bound compared against `prompt.len()`, i.e. measured in UTF-8 bytes.
    max_prompt_length: usize,
    /// Tool names whose invocation is cancelled (exact, case-sensitive match).
    forbidden_tools: Vec<String>,
}

impl ValidationModifyingHook {
    /// Tool names rejected by a hook created with [`ValidationModifyingHook::new`].
    const DEFAULT_FORBIDDEN_TOOLS: [&'static str; 3] = ["rm", "del", "delete"];

    /// Creates the hook with these defaults:
    /// - name `"validation"`, priority [`HookPriority::HIGH`];
    /// - a 10000-byte prompt limit;
    /// - forbidden tools `rm`, `del`, `delete`.
    pub fn new() -> Self {
        Self {
            name: "validation".to_string(),
            priority: HookPriority::HIGH,
            max_prompt_length: 10000,
            forbidden_tools: Self::DEFAULT_FORBIDDEN_TOOLS
                .iter()
                .map(|tool| (*tool).to_owned())
                .collect(),
        }
    }

    /// Sets the maximum accepted prompt length (in bytes, see `before_prompt_build`).
    pub fn with_max_prompt_length(mut self, max: usize) -> Self {
        self.max_prompt_length = max;
        self
    }

    /// Overrides the hook's priority (default [`HookPriority::HIGH`]).
    pub fn with_priority(mut self, priority: HookPriority) -> Self {
        self.priority = priority;
        self
    }

    /// Replaces the forbidden-tool list; an empty iterator allows every tool.
    pub fn with_forbidden_tools<I, S>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.forbidden_tools = tools.into_iter().map(Into::into).collect();
        self
    }
}

impl Default for ValidationModifyingHook {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl ModifyingHook for ValidationModifyingHook {
    fn name(&self) -> &str {
        &self.name
    }

    fn priority(&self) -> HookPriority {
        self.priority
    }

    /// Cancels when the prompt is longer than `max_prompt_length`.
    async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
        // NOTE(review): `String::len` counts UTF-8 bytes, so non-ASCII (e.g. CJK,
        // 3 bytes per char) prompts hit the limit at far fewer characters.
        // Confirm bytes rather than chars is the intended unit.
        if prompt.len() > self.max_prompt_length {
            return HookResult::Cancel(format!(
                "Prompt 长度 {} 超过最大限制 {}",
                prompt.len(),
                self.max_prompt_length
            ));
        }
        HookResult::Continue(prompt)
    }

    /// Cancels calls to any tool on the forbidden list; other calls pass unchanged.
    async fn before_tool_call(&self, name: String, args: Value) -> HookResult<(String, Value)> {
        if self.forbidden_tools.contains(&name) {
            return HookResult::Cancel(format!("工具 '{name}' 被禁止调用"));
        }
        HookResult::Continue((name, args))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_hook_priority() {
        assert!(HookPriority::HIGHEST > HookPriority::HIGH);
        assert!(HookPriority::HIGH > HookPriority::NORMAL);
        assert!(HookPriority::NORMAL > HookPriority::LOW);
        assert!(HookPriority::LOW > HookPriority::LOWEST);
        assert_eq!(HookPriority::default(), HookPriority::NORMAL);
        assert_eq!(HookPriority::from(42), HookPriority(42));
    }

    #[test]
    fn test_hook_result() {
        let result: HookResult<i32> = HookResult::Continue(42);
        assert!(result.is_continue());
        assert!(!result.is_cancel());
        assert_eq!(result.into_option(), Some(42));
        let result: HookResult<i32> = HookResult::Cancel("error".to_string());
        assert!(!result.is_continue());
        assert!(result.is_cancel());
        assert_eq!(result.into_option(), None);
    }

    #[test]
    fn test_hook_result_map() {
        let result: HookResult<i32> = HookResult::Continue(21);
        let mapped = result.map(|x| x * 2);
        assert!(matches!(mapped, HookResult::Continue(42)));
        let result: HookResult<i32> = HookResult::Cancel("error".to_string());
        let mapped = result.map(|x| x * 2);
        assert!(matches!(mapped, HookResult::Cancel(_)));
    }

    #[test]
    fn test_hook_registry_orders_by_priority() {
        let mut registry = HookRegistry::new();
        registry.register_void(Arc::new(
            LoggingVoidHook::new().with_priority(HookPriority::LOW),
        ));
        registry.register_void(Arc::new(
            LoggingVoidHook::new().with_priority(HookPriority::HIGH),
        ));
        registry.register_void(Arc::new(LoggingVoidHook::new()));
        let order: Vec<HookPriority> = registry.void_hooks.iter().map(|h| h.priority()).collect();
        assert_eq!(
            order,
            vec![HookPriority::HIGH, HookPriority::NORMAL, HookPriority::LOW]
        );
    }

    #[test]
    fn test_combined_hook_build_and_clear() {
        let mut registry = CombinedHook::new()
            .add_void(Arc::new(LoggingVoidHook::new()))
            .add_modifying(Arc::new(ValidationModifyingHook::new()))
            .build();
        assert_eq!(registry.void_hook_count(), 1);
        assert_eq!(registry.modifying_hook_count(), 1);
        registry.clear();
        assert_eq!(registry.void_hook_count(), 0);
        assert_eq!(registry.modifying_hook_count(), 0);
    }

    #[tokio::test]
    async fn test_hook_registry_void() {
        let mut registry = HookRegistry::new();
        registry.register_void(Arc::new(LoggingVoidHook::new()));
        registry.register_void(Arc::new(LoggingVoidHook::new()));
        assert_eq!(registry.void_hook_count(), 2);
        // The callback must run exactly once per registered hook.
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_clone = Arc::clone(&calls);
        registry
            .run_void(move |_hook| {
                calls_clone.fetch_add(1, Ordering::SeqCst);
                async {}
            })
            .await;
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_hook_registry_modifying() {
        let mut registry = HookRegistry::new();
        registry.register_modifying(Arc::new(ValidationModifyingHook::new()));
        // `run_modifying` expects a future-returning callback, so this helper is
        // async even though it never awaits.
        #[allow(clippy::unused_async)]
        async fn modify_string(s: String) -> HookResult<String> {
            HookResult::Continue(s + " modified")
        }
        let result = registry
            .run_modifying("test".to_string(), |_hook, s| modify_string(s))
            .await;
        assert!(matches!(result, HookResult::Continue(s) if s == "test modified"));
    }

    #[tokio::test]
    async fn test_hook_registry_cancel() {
        struct CancelHook;
        #[async_trait]
        impl ModifyingHook for CancelHook {
            fn name(&self) -> &str {
                "cancel"
            }
            async fn before_prompt_build(&self, _prompt: String) -> HookResult<String> {
                HookResult::Cancel("test cancel".to_string())
            }
        }
        let hook = CancelHook;
        let result = hook.before_prompt_build("test".to_string()).await;
        assert!(matches!(result, HookResult::Cancel(msg) if msg == "test cancel"));

        let mut registry = HookRegistry::new();
        registry.register_modifying(Arc::new(CancelHook));
        registry.register_modifying(Arc::new(CancelHook));
        assert_eq!(registry.modifying_hook_count(), 2);

        // The first Cancel must short-circuit: the second hook is never reached.
        let calls = AtomicUsize::new(0);
        let result = registry
            .run_modifying("test".to_string(), |_hook, _s| {
                calls.fetch_add(1, Ordering::SeqCst);
                async { HookResult::<String>::Cancel("stopped".to_string()) }
            })
            .await;
        assert!(matches!(result, HookResult::Cancel(msg) if msg == "stopped"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_validation_hook_prompt_length() {
        let hook = ValidationModifyingHook::new().with_max_prompt_length(4);
        let at_limit = hook.before_prompt_build("abcd".to_string()).await;
        assert!(matches!(at_limit, HookResult::Continue(s) if s == "abcd"));
        let too_long = hook.before_prompt_build("abcde".to_string()).await;
        assert!(
            matches!(too_long, HookResult::Cancel(msg) if msg == "Prompt 长度 5 超过最大限制 4")
        );
    }

    #[tokio::test]
    async fn test_validation_hook_forbidden_tool() {
        let hook = ValidationModifyingHook::new();
        let denied = hook.before_tool_call("rm".to_string(), Value::Null).await;
        assert!(matches!(denied, HookResult::Cancel(msg) if msg == "工具 'rm' 被禁止调用"));
        let allowed = hook.before_tool_call("ls".to_string(), Value::Null).await;
        assert!(matches!(allowed, HookResult::Continue((name, Value::Null)) if name == "ls"));
    }
}