// ai_lib_rust/plugins/hooks.rs
use std::collections::HashMap;
use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};

use async_trait::async_trait;

use super::base::PluginContext;
use crate::Result;

/// Lifecycle points under which hooks are registered in a [`HookManager`].
///
/// Each variant is a key into the manager's hook table: hooks registered under
/// one variant run only when that same variant is triggered. The places that
/// trigger each variant live outside this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookType {
    BeforeRequest,
    AfterResponse,
    OnError,
    OnStreamEvent,
    OnRetry,
    OnFallback,
    OnCacheHit,
    OnCacheMiss,
}
11
12#[async_trait]
13pub trait AsyncHook: Send + Sync { async fn call(&self, ctx: &mut PluginContext) -> Result<()>; }
14
15pub struct Hook { pub name: String, pub priority: i32, callback: Arc<dyn AsyncHook> }
16impl Hook {
17 pub fn new<H: AsyncHook + 'static>(name: impl Into<String>, priority: i32, callback: H) -> Self { Self { name: name.into(), priority, callback: Arc::new(callback) } }
18 pub async fn call(&self, ctx: &mut PluginContext) -> Result<()> { self.callback.call(ctx).await }
19}
20
21pub struct FnHook<F> { func: F }
22impl<F> FnHook<F> where F: Fn(&mut PluginContext) -> Result<()> + Send + Sync { pub fn new(func: F) -> Self { Self { func } } }
23#[async_trait]
24impl<F> AsyncHook for FnHook<F> where F: Fn(&mut PluginContext) -> Result<()> + Send + Sync { async fn call(&self, ctx: &mut PluginContext) -> Result<()> { (self.func)(ctx) } }
25
26pub struct HookManager { hooks: RwLock<HashMap<HookType, Vec<Hook>>> }
27impl HookManager {
28 pub fn new() -> Self { Self { hooks: RwLock::new(HashMap::new()) } }
29
30 pub fn register(&self, hook_type: HookType, hook: Hook) {
31 let mut hooks = self.hooks.write().unwrap();
32 let entry = hooks.entry(hook_type).or_insert_with(Vec::new);
33 entry.push(hook);
34 entry.sort_by_key(|h| h.priority);
35 }
36
37 pub fn register_fn<F>(&self, hook_type: HookType, name: impl Into<String>, priority: i32, func: F)
38 where F: Fn(&mut PluginContext) -> Result<()> + Send + Sync + 'static {
39 self.register(hook_type, Hook::new(name, priority, FnHook::new(func)));
40 }
41
42 pub fn unregister(&self, hook_type: HookType, name: &str) -> bool {
43 let mut hooks = self.hooks.write().unwrap();
44 if let Some(entry) = hooks.get_mut(&hook_type) { let len = entry.len(); entry.retain(|h| h.name != name); return entry.len() < len; }
45 false
46 }
47
48 pub async fn trigger(&self, hook_type: HookType, ctx: &mut PluginContext) -> Result<()> {
49 let callbacks: Vec<Arc<dyn AsyncHook>> = { let hooks = self.hooks.read().unwrap(); hooks.get(&hook_type).map(|v| v.iter().map(|h| h.callback.clone()).collect()).unwrap_or_default() };
50 for cb in callbacks { if ctx.should_skip() { break; } cb.call(ctx).await?; }
51 Ok(())
52 }
53
54 pub fn count(&self, hook_type: HookType) -> usize { self.hooks.read().unwrap().get(&hook_type).map(|v| v.len()).unwrap_or(0) }
55 pub fn clear(&self) { self.hooks.write().unwrap().clear(); }
56}
57impl Default for HookManager { fn default() -> Self { Self::new() } }