// ai_lib_rust/plugins/hooks.rs
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, PoisonError, RwLock};

use async_trait::async_trait;

use super::base::PluginContext;
use crate::Result;

/// Named hook points (request, response, error, streaming, retry, fallback
/// and cache events) under which hooks are registered in a [`HookManager`].
///
/// This module only stores and dispatches hooks by type; deciding when each
/// type is fired is up to whoever calls [`HookManager::trigger`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum HookType {
    BeforeRequest,
    AfterResponse,
    OnError,
    OnStreamEvent,
    OnRetry,
    OnFallback,
    OnCacheHit,
    OnCacheMiss,
}

21#[async_trait]
22pub trait AsyncHook: Send + Sync {
23 async fn call(&self, ctx: &mut PluginContext) -> Result<()>;
24}
25
26pub struct Hook {
27 pub name: String,
28 pub priority: i32,
29 callback: Arc<dyn AsyncHook>,
30}
31impl Hook {
32 pub fn new<H: AsyncHook + 'static>(
33 name: impl Into<String>,
34 priority: i32,
35 callback: H,
36 ) -> Self {
37 Self {
38 name: name.into(),
39 priority,
40 callback: Arc::new(callback),
41 }
42 }
43 pub async fn call(&self, ctx: &mut PluginContext) -> Result<()> {
44 self.callback.call(ctx).await
45 }
46}
47
48pub struct FnHook<F> {
49 func: F,
50}
51impl<F> FnHook<F>
52where
53 F: Fn(&mut PluginContext) -> Result<()> + Send + Sync,
54{
55 pub fn new(func: F) -> Self {
56 Self { func }
57 }
58}
59#[async_trait]
60impl<F> AsyncHook for FnHook<F>
61where
62 F: Fn(&mut PluginContext) -> Result<()> + Send + Sync,
63{
64 async fn call(&self, ctx: &mut PluginContext) -> Result<()> {
65 (self.func)(ctx)
66 }
67}
68
69pub struct HookManager {
70 hooks: RwLock<HashMap<HookType, Vec<Hook>>>,
71}
72impl HookManager {
73 pub fn new() -> Self {
74 Self {
75 hooks: RwLock::new(HashMap::new()),
76 }
77 }
78
79 pub fn register(&self, hook_type: HookType, hook: Hook) {
80 let mut hooks = self.hooks.write().unwrap();
81 let entry = hooks.entry(hook_type).or_insert_with(Vec::new);
82 entry.push(hook);
83 entry.sort_by_key(|h| h.priority);
84 }
85
86 pub fn register_fn<F>(
87 &self,
88 hook_type: HookType,
89 name: impl Into<String>,
90 priority: i32,
91 func: F,
92 ) where
93 F: Fn(&mut PluginContext) -> Result<()> + Send + Sync + 'static,
94 {
95 self.register(hook_type, Hook::new(name, priority, FnHook::new(func)));
96 }
97
98 pub fn unregister(&self, hook_type: HookType, name: &str) -> bool {
99 let mut hooks = self.hooks.write().unwrap();
100 if let Some(entry) = hooks.get_mut(&hook_type) {
101 let len = entry.len();
102 entry.retain(|h| h.name != name);
103 return entry.len() < len;
104 }
105 false
106 }
107
108 pub async fn trigger(&self, hook_type: HookType, ctx: &mut PluginContext) -> Result<()> {
109 let callbacks: Vec<Arc<dyn AsyncHook>> = {
110 let hooks = self.hooks.read().unwrap();
111 hooks
112 .get(&hook_type)
113 .map(|v| v.iter().map(|h| h.callback.clone()).collect())
114 .unwrap_or_default()
115 };
116 for cb in callbacks {
117 if ctx.should_skip() {
118 break;
119 }
120 cb.call(ctx).await?;
121 }
122 Ok(())
123 }
124
125 pub fn count(&self, hook_type: HookType) -> usize {
126 self.hooks
127 .read()
128 .unwrap()
129 .get(&hook_type)
130 .map(|v| v.len())
131 .unwrap_or(0)
132 }
133 pub fn clear(&self) {
134 self.hooks.write().unwrap().clear();
135 }
136}
137impl Default for HookManager {
138 fn default() -> Self {
139 Self::new()
140 }
141}