Skip to main content

rucora_core/channel/
hooks.rs

1//! Hook(钩子)优先级系统
2//!
3//! 本模块提供增强的 Hook 系统,支持:
4//! - 优先级排序(priority: i32)
5//! - Void 钩子(并行 fire-and-forget,只读观察)
6//! - Modifying 钩子(按优先级顺序执行,可修改数据或取消操作)
7//!
8//! 参考实现: zeroclaw `HookHandler` trait
9
10use async_trait::async_trait;
11use serde_json::Value;
12use std::sync::Arc;
13
14use crate::{
15    agent::{AgentError, AgentInput, AgentOutput},
16    provider::types::{ChatMessage, ChatResponse},
17    tool::types::ToolResult,
18};
19
/// Outcome of a modifying hook: either proceed (carrying possibly-updated data) or abort.
#[derive(Debug, Clone)]
pub enum HookResult<T> {
    /// Proceed with the enclosed (possibly modified) value.
    Continue(T),
    /// Abort the operation; the string explains why.
    Cancel(String),
}

impl<T> HookResult<T> {
    /// Converts into an `Option`: `Continue(v)` becomes `Some(v)`, `Cancel(_)` becomes `None`.
    ///
    /// The cancellation reason is discarded.
    pub fn into_option(self) -> Option<T> {
        if let HookResult::Continue(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Returns `true` when this is the `Continue` variant.
    pub fn is_continue(&self) -> bool {
        match self {
            HookResult::Continue(_) => true,
            HookResult::Cancel(_) => false,
        }
    }

    /// Returns `true` when this is the `Cancel` variant.
    pub fn is_cancel(&self) -> bool {
        match self {
            HookResult::Cancel(_) => true,
            HookResult::Continue(_) => false,
        }
    }

    /// Applies `f` to the `Continue` payload; a `Cancel` passes through with its reason intact.
    pub fn map<F, U>(self, f: F) -> HookResult<U>
    where
        F: FnOnce(T) -> U,
    {
        match self {
            HookResult::Cancel(reason) => HookResult::Cancel(reason),
            HookResult::Continue(value) => HookResult::Continue(f(value)),
        }
    }
}
59
/// Ordering key for hooks: a larger number means a higher priority.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct HookPriority(pub i32);

impl HookPriority {
    /// Maximum possible priority (`i32::MAX`); runs first.
    pub const HIGHEST: Self = HookPriority(i32::MAX);
    /// Elevated priority (`100`).
    pub const HIGH: Self = HookPriority(100);
    /// Baseline priority (`0`); this is also the `Default`.
    pub const NORMAL: Self = HookPriority(0);
    /// Reduced priority (`-100`).
    pub const LOW: Self = HookPriority(-100);
    /// Minimum possible priority (`i32::MIN`); runs last.
    pub const LOWEST: Self = HookPriority(i32::MIN);
}

impl Default for HookPriority {
    /// Returns [`HookPriority::NORMAL`].
    fn default() -> Self {
        HookPriority::NORMAL
    }
}

impl From<i32> for HookPriority {
    /// Wraps a raw priority value as-is (no clamping).
    fn from(raw: i32) -> Self {
        HookPriority(raw)
    }
}
88
89/// Void Hook trait(只读观察型钩子)
90///
91/// Void 钩子是 fire-and-forget 类型的钩子,用于观察事件而不修改数据。
92/// 多个 Void 钩子可以并行执行。
93#[async_trait]
94pub trait VoidHook: Send + Sync {
95    /// 钩子名称
96    fn name(&self) -> &str;
97
98    /// 钩子优先级(数值越大优先级越高)
99    fn priority(&self) -> HookPriority {
100        HookPriority::NORMAL
101    }
102
103    /// 会话开始钩子
104    async fn on_session_start(&self, _session_id: &str) {}
105
106    /// 会话结束钩子
107    async fn on_session_end(&self, _session_id: &str) {}
108
109    /// LLM 输入钩子
110    async fn on_llm_input(&self, _messages: &[ChatMessage], _model: &str) {}
111
112    /// LLM 输出钩子
113    async fn on_llm_output(&self, _response: &ChatResponse) {}
114
115    /// 工具调用后钩子
116    async fn on_after_tool_call(&self, _tool: &str, _result: &ToolResult, _duration_ms: u64) {}
117
118    /// Agent 步骤完成钩子
119    async fn on_step_complete(&self, _step: usize, _output: &AgentOutput) {}
120
121    /// 错误钩子
122    async fn on_error(&self, _error: &AgentError) {}
123}
124
125/// Modifying Hook trait(修改型钩子)
126///
127/// Modifying 钩子可以修改数据或取消操作。钩子按优先级顺序执行,
128/// 每个钩子的输出作为下一个钩子的输入。
129#[async_trait]
130pub trait ModifyingHook: Send + Sync {
131    /// 钩子名称
132    fn name(&self) -> &str;
133
134    /// 钩子优先级(数值越大优先级越高)
135    fn priority(&self) -> HookPriority {
136        HookPriority::NORMAL
137    }
138
139    /// 模型解析前钩子
140    ///
141    /// 可以修改 provider 和 model 名称
142    async fn before_model_resolve(
143        &self,
144        provider: String,
145        model: String,
146    ) -> HookResult<(String, String)> {
147        HookResult::Continue((provider, model))
148    }
149
150    /// Prompt 构建前钩子
151    ///
152    /// 可以修改系统 prompt
153    async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
154        HookResult::Continue(prompt)
155    }
156
157    /// LLM 调用前钩子
158    ///
159    /// 可以修改消息列表和模型参数
160    async fn before_llm_call(
161        &self,
162        messages: Vec<ChatMessage>,
163        model: String,
164    ) -> HookResult<(Vec<ChatMessage>, String)> {
165        HookResult::Continue((messages, model))
166    }
167
168    /// 工具调用前钩子
169    ///
170    /// 可以修改工具名称和参数
171    async fn before_tool_call(&self, name: String, args: Value) -> HookResult<(String, Value)> {
172        HookResult::Continue((name, args))
173    }
174
175    /// Agent 输入处理钩子
176    ///
177    /// 可以修改用户输入
178    async fn on_input_received(&self, input: AgentInput) -> HookResult<AgentInput> {
179        HookResult::Continue(input)
180    }
181
182    /// Agent 输出生成钩子
183    ///
184    /// 可以修改输出
185    async fn on_output_generated(&self, output: AgentOutput) -> HookResult<AgentOutput> {
186        HookResult::Continue(output)
187    }
188}
189
190/// 钩子注册表
191#[derive(Default)]
192pub struct HookRegistry {
193    void_hooks: Vec<Arc<dyn VoidHook>>,
194    modifying_hooks: Vec<Arc<dyn ModifyingHook>>,
195}
196
197impl HookRegistry {
198    /// 创建新的钩子注册表
199    pub fn new() -> Self {
200        Self::default()
201    }
202
203    /// 注册 Void 钩子
204    pub fn register_void(&mut self, hook: Arc<dyn VoidHook>) {
205        self.void_hooks.push(hook);
206        // 按优先级排序(高优先级在前)
207        self.void_hooks
208            .sort_by_key(|h| std::cmp::Reverse(h.priority()));
209    }
210
211    /// 注册 Modifying 钩子
212    pub fn register_modifying(&mut self, hook: Arc<dyn ModifyingHook>) {
213        self.modifying_hooks.push(hook);
214        // 按优先级排序(高优先级在前)
215        self.modifying_hooks
216            .sort_by_key(|h| std::cmp::Reverse(h.priority()));
217    }
218
219    /// 执行 Void 钩子(并行)
220    pub async fn run_void<F, Fut>(&self, f: F)
221    where
222        F: Fn(&dyn VoidHook) -> Fut + Send + Sync,
223        Fut: std::future::Future<Output = ()> + Send,
224    {
225        use futures_util::future::join_all;
226
227        let futures: Vec<_> = self.void_hooks.iter().map(|hook| f(&**hook)).collect();
228        join_all(futures).await;
229    }
230
231    /// 执行 Modifying 钩子(顺序,高优先级先执行)
232    pub async fn run_modifying<T, F, Fut>(&self, initial: T, f: F) -> HookResult<T>
233    where
234        T: Clone,
235        F: Fn(&dyn ModifyingHook, T) -> Fut + Send + Sync,
236        Fut: std::future::Future<Output = HookResult<T>> + Send,
237    {
238        let mut current = initial;
239
240        for hook in &self.modifying_hooks {
241            match f(&**hook, current.clone()).await {
242                HookResult::Continue(v) => current = v,
243                HookResult::Cancel(msg) => return HookResult::Cancel(msg),
244            }
245        }
246
247        HookResult::Continue(current)
248    }
249
250    /// 获取 Void 钩子数量
251    pub fn void_hook_count(&self) -> usize {
252        self.void_hooks.len()
253    }
254
255    /// 获取 Modifying 钩子数量
256    pub fn modifying_hook_count(&self) -> usize {
257        self.modifying_hooks.len()
258    }
259
260    /// 清空所有钩子
261    pub fn clear(&mut self) {
262        self.void_hooks.clear();
263        self.modifying_hooks.clear();
264    }
265}
266
267/// 组合钩子(同时实现 VoidHook 和 ModifyingHook)
268pub struct CombinedHook {
269    void_hooks: Vec<Arc<dyn VoidHook>>,
270    modifying_hooks: Vec<Arc<dyn ModifyingHook>>,
271}
272
273impl Default for CombinedHook {
274    fn default() -> Self {
275        Self::new()
276    }
277}
278
279impl CombinedHook {
280    /// 创建新的组合钩子
281    pub fn new() -> Self {
282        Self {
283            void_hooks: Vec::new(),
284            modifying_hooks: Vec::new(),
285        }
286    }
287
288    /// 添加 Void 钩子
289    pub fn add_void(mut self, hook: Arc<dyn VoidHook>) -> Self {
290        self.void_hooks.push(hook);
291        self
292    }
293
294    /// 添加 Modifying 钩子
295    pub fn add_modifying(mut self, hook: Arc<dyn ModifyingHook>) -> Self {
296        self.modifying_hooks.push(hook);
297        self
298    }
299
300    /// 构建钩子注册表
301    pub fn build(self) -> HookRegistry {
302        let mut registry = HookRegistry::new();
303        for hook in self.void_hooks {
304            registry.register_void(hook);
305        }
306        for hook in self.modifying_hooks {
307            registry.register_modifying(hook);
308        }
309        registry
310    }
311}
312
313/// 日志 Void Hook(示例实现)
314pub struct LoggingVoidHook {
315    name: String,
316    priority: HookPriority,
317}
318
319impl LoggingVoidHook {
320    /// 创建新的日志钩子
321    pub fn new() -> Self {
322        Self {
323            name: "logging".to_string(),
324            priority: HookPriority::NORMAL,
325        }
326    }
327
328    /// 设置优先级
329    pub fn with_priority(mut self, priority: HookPriority) -> Self {
330        self.priority = priority;
331        self
332    }
333}
334
335impl Default for LoggingVoidHook {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341#[async_trait]
342impl VoidHook for LoggingVoidHook {
343    fn name(&self) -> &str {
344        &self.name
345    }
346
347    fn priority(&self) -> HookPriority {
348        self.priority
349    }
350
351    async fn on_session_start(&self, session_id: &str) {
352        tracing::info!(session_id, "hook.session.start");
353    }
354
355    async fn on_session_end(&self, session_id: &str) {
356        tracing::info!(session_id, "hook.session.end");
357    }
358
359    async fn on_llm_input(&self, messages: &[ChatMessage], model: &str) {
360        tracing::debug!(message_count = messages.len(), model, "hook.llm.input");
361    }
362
363    async fn on_llm_output(&self, response: &ChatResponse) {
364        tracing::debug!(
365            content_len = response.message.content.len(),
366            "hook.llm.output"
367        );
368    }
369
370    async fn on_after_tool_call(&self, tool: &str, _result: &ToolResult, duration_ms: u64) {
371        tracing::info!(tool_name = tool, duration_ms, "hook.tool_call.complete");
372    }
373
374    async fn on_error(&self, error: &AgentError) {
375        tracing::error!(error = %error, "hook.error");
376    }
377}
378
379/// 验证 Modifying Hook(示例实现)
380pub struct ValidationModifyingHook {
381    name: String,
382    priority: HookPriority,
383    max_prompt_length: usize,
384}
385
386impl ValidationModifyingHook {
387    /// 创建新的验证钩子
388    pub fn new() -> Self {
389        Self {
390            name: "validation".to_string(),
391            priority: HookPriority::HIGH, // 高优先级,尽早执行
392            max_prompt_length: 10000,
393        }
394    }
395
396    /// 设置最大 prompt 长度
397    pub fn with_max_prompt_length(mut self, max: usize) -> Self {
398        self.max_prompt_length = max;
399        self
400    }
401}
402
403impl Default for ValidationModifyingHook {
404    fn default() -> Self {
405        Self::new()
406    }
407}
408
409#[async_trait]
410impl ModifyingHook for ValidationModifyingHook {
411    fn name(&self) -> &str {
412        &self.name
413    }
414
415    fn priority(&self) -> HookPriority {
416        self.priority
417    }
418
419    async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
420        if prompt.len() > self.max_prompt_length {
421            return HookResult::Cancel(format!(
422                "Prompt 长度 {} 超过最大限制 {}",
423                prompt.len(),
424                self.max_prompt_length
425            ));
426        }
427        HookResult::Continue(prompt)
428    }
429
430    async fn before_tool_call(&self, name: String, args: Value) -> HookResult<(String, Value)> {
431        // 示例:禁止调用某些危险工具
432        let forbidden_tools = ["rm", "del", "delete"];
433        if forbidden_tools.contains(&name.as_str()) {
434            return HookResult::Cancel(format!("工具 '{name}' 被禁止调用"));
435        }
436        HookResult::Continue((name, args))
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test-only modifying hook with a fixed name and priority. It keeps every default
    /// (pass-through) hook method, so it is only useful for ordering and chaining checks.
    struct NamedHook {
        name: &'static str,
        priority: HookPriority,
    }

    impl NamedHook {
        fn new(name: &'static str, priority: HookPriority) -> Arc<Self> {
            Arc::new(Self { name, priority })
        }
    }

    #[async_trait]
    impl ModifyingHook for NamedHook {
        fn name(&self) -> &str {
            self.name
        }

        fn priority(&self) -> HookPriority {
            self.priority
        }
    }

    /// Test-only modifying hook whose `before_prompt_build` always cancels.
    struct CancelHook;

    #[async_trait]
    impl ModifyingHook for CancelHook {
        fn name(&self) -> &str {
            "cancel"
        }

        async fn before_prompt_build(&self, _prompt: String) -> HookResult<String> {
            HookResult::Cancel("test cancel".to_string())
        }
    }

    #[test]
    fn test_hook_priority() {
        assert!(HookPriority::HIGHEST > HookPriority::HIGH);
        assert!(HookPriority::HIGH > HookPriority::NORMAL);
        assert!(HookPriority::NORMAL > HookPriority::LOW);
        assert!(HookPriority::LOW > HookPriority::LOWEST);
    }

    #[test]
    fn test_hook_result() {
        let result: HookResult<i32> = HookResult::Continue(42);
        assert!(result.is_continue());
        assert!(!result.is_cancel());
        assert_eq!(result.into_option(), Some(42));

        let result: HookResult<i32> = HookResult::Cancel("error".to_string());
        assert!(!result.is_continue());
        assert!(result.is_cancel());
        assert_eq!(result.into_option(), None);
    }

    #[test]
    fn test_hook_result_map() {
        let result: HookResult<i32> = HookResult::Continue(21);
        let mapped = result.map(|x| x * 2);
        assert!(matches!(mapped, HookResult::Continue(42)));

        let result: HookResult<i32> = HookResult::Cancel("error".to_string());
        let mapped = result.map(|x| x * 2);
        assert!(matches!(mapped, HookResult::Cancel(_)));
    }

    #[test]
    fn test_hook_registry_priority_order() {
        let mut registry = HookRegistry::new();
        registry.register_modifying(NamedHook::new("low", HookPriority::LOW));
        registry.register_modifying(NamedHook::new("normal-a", HookPriority::NORMAL));
        registry.register_modifying(NamedHook::new("high", HookPriority::HIGH));
        registry.register_modifying(NamedHook::new("normal-b", HookPriority::NORMAL));

        // Highest priority first; equal priorities keep registration order.
        let names: Vec<&str> = registry.modifying_hooks.iter().map(|h| h.name()).collect();
        assert_eq!(names, ["high", "normal-a", "normal-b", "low"]);

        registry.clear();
        assert_eq!(registry.modifying_hook_count(), 0);
        assert_eq!(registry.void_hook_count(), 0);
    }

    #[tokio::test]
    async fn test_hook_registry_void() {
        let mut registry = HookRegistry::new();
        registry.register_void(Arc::new(LoggingVoidHook::new()));
        registry.register_void(Arc::new(
            LoggingVoidHook::new().with_priority(HookPriority::HIGH),
        ));

        assert_eq!(registry.void_hook_count(), 2);
        assert_eq!(registry.void_hooks[0].priority(), HookPriority::HIGH);

        // The counter is bumped inside each future, so this also proves every hook's
        // future is awaited, not just created.
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_closure = Arc::clone(&calls);
        registry
            .run_void(move |_hook| {
                let calls = Arc::clone(&calls_in_closure);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                }
            })
            .await;

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_hook_registry_modifying() {
        let mut registry = HookRegistry::new();
        registry.register_modifying(Arc::new(ValidationModifyingHook::new()));

        // A plain `async fn` avoids borrowing the hook inside the returned future (see the
        // lifetime note on `run_modifying`). It never awaits, hence the lint allowance: the
        // registry API still requires a future.
        #[allow(clippy::unused_async)]
        async fn modify_string(s: String) -> HookResult<String> {
            HookResult::Continue(s + " modified")
        }

        let result = registry
            .run_modifying("test".to_string(), |_hook, s| modify_string(s))
            .await;

        assert!(matches!(result, HookResult::Continue(s) if s == "test modified"));
    }

    #[tokio::test]
    async fn test_hook_registry_modifying_chains_outputs() {
        let mut registry = HookRegistry::new();
        registry.register_modifying(NamedHook::new("first", HookPriority::HIGH));
        registry.register_modifying(NamedHook::new("second", HookPriority::LOW));

        // Each hook's output must become the next hook's input.
        let result = registry
            .run_modifying("x".to_string(), |_hook, s| async move {
                HookResult::Continue(s + "!")
            })
            .await;

        assert!(matches!(result, HookResult::Continue(s) if s == "x!!"));
    }

    #[tokio::test]
    async fn test_hook_registry_modifying_empty() {
        let registry = HookRegistry::new();
        let result = registry
            .run_modifying(5, |_hook, v| async move { HookResult::Continue(v + 1) })
            .await;

        assert!(matches!(result, HookResult::Continue(5)));
    }

    #[tokio::test]
    async fn test_hook_registry_cancel() {
        // The hook itself reports cancellation.
        let result = CancelHook.before_prompt_build("test".to_string()).await;
        assert!(matches!(result, HookResult::Cancel(msg) if msg == "test cancel"));

        // The registry stops at the first `Cancel` and never invokes later hooks.
        let mut registry = HookRegistry::new();
        registry.register_modifying(Arc::new(CancelHook));
        registry.register_modifying(NamedHook::new("after", HookPriority::LOW));
        assert_eq!(registry.modifying_hook_count(), 2);

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_closure = Arc::clone(&calls);
        let result = registry
            .run_modifying("test".to_string(), move |_hook, _prompt| {
                calls_in_closure.fetch_add(1, Ordering::SeqCst);
                async { HookResult::Cancel("stop".to_string()) }
            })
            .await;

        assert!(matches!(result, HookResult::Cancel(msg) if msg == "stop"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_combined_hook_build() {
        let registry = CombinedHook::new()
            .add_void(Arc::new(LoggingVoidHook::new()))
            .add_modifying(Arc::new(ValidationModifyingHook::new()))
            .build();

        assert_eq!(registry.void_hook_count(), 1);
        assert_eq!(registry.modifying_hook_count(), 1);
    }

    #[tokio::test]
    async fn test_validation_hook() {
        let hook = ValidationModifyingHook::new().with_max_prompt_length(5);
        assert_eq!(hook.priority(), HookPriority::HIGH);

        // ASCII prompts: length is the same whether counted in bytes or chars.
        assert!(hook.before_prompt_build("abc".to_string()).await.is_continue());
        assert!(hook.before_prompt_build("abcde".to_string()).await.is_continue());
        assert!(hook.before_prompt_build("abcdef".to_string()).await.is_cancel());

        assert!(hook
            .before_tool_call("rm".to_string(), Value::Null)
            .await
            .is_cancel());
        assert!(hook
            .before_tool_call("ls".to_string(), Value::Null)
            .await
            .is_continue());
    }
}