1use async_trait::async_trait;
11use serde_json::Value;
12use std::sync::Arc;
13
14use crate::{
15 agent::{AgentError, AgentInput, AgentOutput},
16 provider::types::{ChatMessage, ChatResponse},
17 tool::types::ToolResult,
18};
19
/// Outcome returned by a modifying hook.
///
/// Either the (possibly transformed) value travels on, or the whole operation
/// is vetoed with a human-readable reason.
#[derive(Clone, Debug)]
pub enum HookResult<T> {
    /// Proceed with this value.
    Continue(T),
    /// Abort; the string explains why.
    Cancel(String),
}

impl<T> HookResult<T> {
    /// Converts into an `Option`, dropping the cancellation reason.
    pub fn into_option(self) -> Option<T> {
        if let Self::Continue(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Returns `true` when this is [`HookResult::Continue`].
    pub fn is_continue(&self) -> bool {
        // Only two variants exist, so "not cancelled" means "continue".
        !self.is_cancel()
    }

    /// Returns `true` when this is [`HookResult::Cancel`].
    pub fn is_cancel(&self) -> bool {
        matches!(self, Self::Cancel(_))
    }

    /// Applies `f` to a continued value; a cancellation passes through untouched.
    pub fn map<F, U>(self, f: F) -> HookResult<U>
    where
        F: FnOnce(T) -> U,
    {
        match self {
            Self::Continue(value) => HookResult::Continue(f(value)),
            Self::Cancel(reason) => HookResult::Cancel(reason),
        }
    }
}
59
/// Execution priority of a hook; larger values sort earlier.
///
/// `HookRegistry` orders hooks by descending priority, so `HIGHEST` comes
/// first and `LOWEST` last.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct HookPriority(pub i32);

impl Default for HookPriority {
    /// Defaults to [`HookPriority::NORMAL`].
    fn default() -> Self {
        HookPriority::NORMAL
    }
}

impl From<i32> for HookPriority {
    /// Wraps a raw priority value.
    fn from(raw: i32) -> Self {
        HookPriority(raw)
    }
}

impl HookPriority {
    /// Largest representable priority.
    pub const HIGHEST: HookPriority = HookPriority(i32::MAX);
    /// Above normal.
    pub const HIGH: HookPriority = HookPriority(100);
    /// Default priority.
    pub const NORMAL: HookPriority = HookPriority(0);
    /// Below normal.
    pub const LOW: HookPriority = HookPriority(-100);
    /// Smallest representable priority.
    pub const LOWEST: HookPriority = HookPriority(i32::MIN);
}
88
89#[async_trait]
94pub trait VoidHook: Send + Sync {
95 fn name(&self) -> &str;
97
98 fn priority(&self) -> HookPriority {
100 HookPriority::NORMAL
101 }
102
103 async fn on_session_start(&self, _session_id: &str) {}
105
106 async fn on_session_end(&self, _session_id: &str) {}
108
109 async fn on_llm_input(&self, _messages: &[ChatMessage], _model: &str) {}
111
112 async fn on_llm_output(&self, _response: &ChatResponse) {}
114
115 async fn on_after_tool_call(&self, _tool: &str, _result: &ToolResult, _duration_ms: u64) {}
117
118 async fn on_step_complete(&self, _step: usize, _output: &AgentOutput) {}
120
121 async fn on_error(&self, _error: &AgentError) {}
123}
124
125#[async_trait]
130pub trait ModifyingHook: Send + Sync {
131 fn name(&self) -> &str;
133
134 fn priority(&self) -> HookPriority {
136 HookPriority::NORMAL
137 }
138
139 async fn before_model_resolve(
143 &self,
144 provider: String,
145 model: String,
146 ) -> HookResult<(String, String)> {
147 HookResult::Continue((provider, model))
148 }
149
150 async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
154 HookResult::Continue(prompt)
155 }
156
157 async fn before_llm_call(
161 &self,
162 messages: Vec<ChatMessage>,
163 model: String,
164 ) -> HookResult<(Vec<ChatMessage>, String)> {
165 HookResult::Continue((messages, model))
166 }
167
168 async fn before_tool_call(&self, name: String, args: Value) -> HookResult<(String, Value)> {
172 HookResult::Continue((name, args))
173 }
174
175 async fn on_input_received(&self, input: AgentInput) -> HookResult<AgentInput> {
179 HookResult::Continue(input)
180 }
181
182 async fn on_output_generated(&self, output: AgentOutput) -> HookResult<AgentOutput> {
186 HookResult::Continue(output)
187 }
188}
189
190#[derive(Default)]
192pub struct HookRegistry {
193 void_hooks: Vec<Arc<dyn VoidHook>>,
194 modifying_hooks: Vec<Arc<dyn ModifyingHook>>,
195}
196
197impl HookRegistry {
198 pub fn new() -> Self {
200 Self::default()
201 }
202
203 pub fn register_void(&mut self, hook: Arc<dyn VoidHook>) {
205 self.void_hooks.push(hook);
206 self.void_hooks
208 .sort_by_key(|h| std::cmp::Reverse(h.priority()));
209 }
210
211 pub fn register_modifying(&mut self, hook: Arc<dyn ModifyingHook>) {
213 self.modifying_hooks.push(hook);
214 self.modifying_hooks
216 .sort_by_key(|h| std::cmp::Reverse(h.priority()));
217 }
218
219 pub async fn run_void<F, Fut>(&self, f: F)
221 where
222 F: Fn(&dyn VoidHook) -> Fut + Send + Sync,
223 Fut: std::future::Future<Output = ()> + Send,
224 {
225 use futures_util::future::join_all;
226
227 let futures: Vec<_> = self.void_hooks.iter().map(|hook| f(&**hook)).collect();
228 join_all(futures).await;
229 }
230
231 pub async fn run_modifying<T, F, Fut>(&self, initial: T, f: F) -> HookResult<T>
233 where
234 T: Clone,
235 F: Fn(&dyn ModifyingHook, T) -> Fut + Send + Sync,
236 Fut: std::future::Future<Output = HookResult<T>> + Send,
237 {
238 let mut current = initial;
239
240 for hook in &self.modifying_hooks {
241 match f(&**hook, current.clone()).await {
242 HookResult::Continue(v) => current = v,
243 HookResult::Cancel(msg) => return HookResult::Cancel(msg),
244 }
245 }
246
247 HookResult::Continue(current)
248 }
249
250 pub fn void_hook_count(&self) -> usize {
252 self.void_hooks.len()
253 }
254
255 pub fn modifying_hook_count(&self) -> usize {
257 self.modifying_hooks.len()
258 }
259
260 pub fn clear(&mut self) {
262 self.void_hooks.clear();
263 self.modifying_hooks.clear();
264 }
265}
266
267pub struct CombinedHook {
269 void_hooks: Vec<Arc<dyn VoidHook>>,
270 modifying_hooks: Vec<Arc<dyn ModifyingHook>>,
271}
272
273impl Default for CombinedHook {
274 fn default() -> Self {
275 Self::new()
276 }
277}
278
279impl CombinedHook {
280 pub fn new() -> Self {
282 Self {
283 void_hooks: Vec::new(),
284 modifying_hooks: Vec::new(),
285 }
286 }
287
288 pub fn add_void(mut self, hook: Arc<dyn VoidHook>) -> Self {
290 self.void_hooks.push(hook);
291 self
292 }
293
294 pub fn add_modifying(mut self, hook: Arc<dyn ModifyingHook>) -> Self {
296 self.modifying_hooks.push(hook);
297 self
298 }
299
300 pub fn build(self) -> HookRegistry {
302 let mut registry = HookRegistry::new();
303 for hook in self.void_hooks {
304 registry.register_void(hook);
305 }
306 for hook in self.modifying_hooks {
307 registry.register_modifying(hook);
308 }
309 registry
310 }
311}
312
313pub struct LoggingVoidHook {
315 name: String,
316 priority: HookPriority,
317}
318
319impl LoggingVoidHook {
320 pub fn new() -> Self {
322 Self {
323 name: "logging".to_string(),
324 priority: HookPriority::NORMAL,
325 }
326 }
327
328 pub fn with_priority(mut self, priority: HookPriority) -> Self {
330 self.priority = priority;
331 self
332 }
333}
334
335impl Default for LoggingVoidHook {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
341#[async_trait]
342impl VoidHook for LoggingVoidHook {
343 fn name(&self) -> &str {
344 &self.name
345 }
346
347 fn priority(&self) -> HookPriority {
348 self.priority
349 }
350
351 async fn on_session_start(&self, session_id: &str) {
352 tracing::info!(session_id, "hook.session.start");
353 }
354
355 async fn on_session_end(&self, session_id: &str) {
356 tracing::info!(session_id, "hook.session.end");
357 }
358
359 async fn on_llm_input(&self, messages: &[ChatMessage], model: &str) {
360 tracing::debug!(message_count = messages.len(), model, "hook.llm.input");
361 }
362
363 async fn on_llm_output(&self, response: &ChatResponse) {
364 tracing::debug!(
365 content_len = response.message.content.len(),
366 "hook.llm.output"
367 );
368 }
369
370 async fn on_after_tool_call(&self, tool: &str, _result: &ToolResult, duration_ms: u64) {
371 tracing::info!(tool_name = tool, duration_ms, "hook.tool_call.complete");
372 }
373
374 async fn on_error(&self, error: &AgentError) {
375 tracing::error!(error = %error, "hook.error");
376 }
377}
378
379pub struct ValidationModifyingHook {
381 name: String,
382 priority: HookPriority,
383 max_prompt_length: usize,
384}
385
386impl ValidationModifyingHook {
387 pub fn new() -> Self {
389 Self {
390 name: "validation".to_string(),
391 priority: HookPriority::HIGH, max_prompt_length: 10000,
393 }
394 }
395
396 pub fn with_max_prompt_length(mut self, max: usize) -> Self {
398 self.max_prompt_length = max;
399 self
400 }
401}
402
403impl Default for ValidationModifyingHook {
404 fn default() -> Self {
405 Self::new()
406 }
407}
408
409#[async_trait]
410impl ModifyingHook for ValidationModifyingHook {
411 fn name(&self) -> &str {
412 &self.name
413 }
414
415 fn priority(&self) -> HookPriority {
416 self.priority
417 }
418
419 async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
420 if prompt.len() > self.max_prompt_length {
421 return HookResult::Cancel(format!(
422 "Prompt 长度 {} 超过最大限制 {}",
423 prompt.len(),
424 self.max_prompt_length
425 ));
426 }
427 HookResult::Continue(prompt)
428 }
429
430 async fn before_tool_call(&self, name: String, args: Value) -> HookResult<(String, Value)> {
431 let forbidden_tools = ["rm", "del", "delete"];
433 if forbidden_tools.contains(&name.as_str()) {
434 return HookResult::Cancel(format!("工具 '{name}' 被禁止调用"));
435 }
436 HookResult::Continue((name, args))
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Modifying hook with a fixed name and priority; every hook point uses the
    /// trait's pass-through default. Used to observe registry ordering.
    struct NamedHook {
        name: &'static str,
        priority: HookPriority,
    }

    #[async_trait]
    impl ModifyingHook for NamedHook {
        fn name(&self) -> &str {
            self.name
        }

        fn priority(&self) -> HookPriority {
            self.priority
        }
    }

    #[test]
    fn test_hook_priority() {
        assert!(HookPriority::HIGHEST > HookPriority::HIGH);
        assert!(HookPriority::HIGH > HookPriority::NORMAL);
        assert!(HookPriority::NORMAL > HookPriority::LOW);
        assert!(HookPriority::LOW > HookPriority::LOWEST);
        assert_eq!(HookPriority::default(), HookPriority::NORMAL);
        assert_eq!(HookPriority::from(42), HookPriority(42));
    }

    #[test]
    fn test_hook_result() {
        let result: HookResult<i32> = HookResult::Continue(42);
        assert!(result.is_continue());
        assert!(!result.is_cancel());
        assert_eq!(result.into_option(), Some(42));

        let result: HookResult<i32> = HookResult::Cancel("error".to_string());
        assert!(!result.is_continue());
        assert!(result.is_cancel());
        assert_eq!(result.into_option(), None);
    }

    #[test]
    fn test_hook_result_map() {
        let result: HookResult<i32> = HookResult::Continue(21);
        let mapped = result.map(|x| x * 2);
        assert!(matches!(mapped, HookResult::Continue(42)));

        let result: HookResult<i32> = HookResult::Cancel("error".to_string());
        let mapped = result.map(|x| x * 2);
        assert!(matches!(mapped, HookResult::Cancel(_)));
    }

    #[tokio::test]
    async fn test_hook_registry_void() {
        let mut registry = HookRegistry::new();
        registry.register_void(Arc::new(LoggingVoidHook::new()));
        registry.register_void(Arc::new(
            LoggingVoidHook::new().with_priority(HookPriority::HIGH),
        ));

        assert_eq!(registry.void_hook_count(), 2);

        // Every registered void hook must be visited exactly once.
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        registry
            .run_void(move |_hook| {
                counter.fetch_add(1, Ordering::SeqCst);
                async {}
            })
            .await;

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_hook_registry_modifying() {
        let mut registry = HookRegistry::new();
        registry.register_modifying(Arc::new(ValidationModifyingHook::new()));

        // Deliberately contains no `.await`: it only has to produce the future
        // that `run_modifying` expects its closure to return.
        #[allow(clippy::unused_async)]
        async fn modify_string(s: String) -> HookResult<String> {
            HookResult::Continue(s + " modified")
        }

        let result = registry
            .run_modifying("test".to_string(), |_hook, s| modify_string(s))
            .await;

        assert!(matches!(result, HookResult::Continue(s) if s == "test modified"));
    }

    #[tokio::test]
    async fn test_modifying_hooks_run_in_priority_order() {
        let mut registry = HookRegistry::new();
        registry.register_modifying(Arc::new(NamedHook {
            name: "low",
            priority: HookPriority::LOW,
        }));
        registry.register_modifying(Arc::new(NamedHook {
            name: "high",
            priority: HookPriority::HIGH,
        }));
        registry.register_modifying(Arc::new(NamedHook {
            name: "normal",
            priority: HookPriority::NORMAL,
        }));

        // Each hook appends its name, so the final value records visit order.
        let result = registry
            .run_modifying(Vec::<String>::new(), |hook, mut seen| {
                seen.push(hook.name().to_string());
                async move { HookResult::Continue(seen) }
            })
            .await;

        match result {
            HookResult::Continue(order) => assert_eq!(order, vec!["high", "normal", "low"]),
            HookResult::Cancel(msg) => panic!("unexpected cancel: {msg}"),
        }
    }

    #[tokio::test]
    async fn test_hook_registry_cancel() {
        struct CancelHook;

        #[async_trait]
        impl ModifyingHook for CancelHook {
            fn name(&self) -> &str {
                "cancel"
            }

            async fn before_prompt_build(&self, _prompt: String) -> HookResult<String> {
                HookResult::Cancel("test cancel".to_string())
            }
        }

        // The hook itself reports the cancellation.
        let hook = CancelHook;
        let result = hook.before_prompt_build("test".to_string()).await;
        assert!(matches!(result, HookResult::Cancel(msg) if msg == "test cancel"));

        let mut registry = HookRegistry::new();
        registry.register_modifying(Arc::new(CancelHook));
        registry.register_modifying(Arc::new(NamedHook {
            name: "after",
            priority: HookPriority::LOWEST,
        }));
        assert_eq!(registry.modifying_hook_count(), 2);

        // The registry must stop at the first Cancel; the LOWEST-priority hook
        // must never be reached.
        let calls = AtomicUsize::new(0);
        let result = registry
            .run_modifying(0_u32, |hook, value| {
                calls.fetch_add(1, Ordering::SeqCst);
                let cancel = hook.name() == "cancel";
                async move {
                    if cancel {
                        HookResult::Cancel("stopped".to_string())
                    } else {
                        HookResult::Continue(value + 1)
                    }
                }
            })
            .await;

        assert!(matches!(result, HookResult::Cancel(msg) if msg == "stopped"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_validation_hook() {
        let hook = ValidationModifyingHook::new().with_max_prompt_length(5);

        // Exactly at the limit passes; over the limit is cancelled.
        assert!(hook.before_prompt_build("short".to_string()).await.is_continue());
        assert!(hook.before_prompt_build("too long".to_string()).await.is_cancel());

        assert!(hook
            .before_tool_call("rm".to_string(), Value::Null)
            .await
            .is_cancel());
        assert!(hook
            .before_tool_call("read_file".to_string(), Value::Null)
            .await
            .is_continue());
    }

    #[test]
    fn test_combined_hook_build() {
        let registry = CombinedHook::new()
            .add_void(Arc::new(LoggingVoidHook::new()))
            .add_modifying(Arc::new(ValidationModifyingHook::new()))
            .add_modifying(Arc::new(NamedHook {
                name: "extra",
                priority: HookPriority::LOW,
            }))
            .build();

        assert_eq!(registry.void_hook_count(), 1);
        assert_eq!(registry.modifying_hook_count(), 2);
    }
}