Skip to main content

ai_lib_rust/plugins/
hooks.rs

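//! Hook system: [`AsyncHook`] callbacks wrapped in [`Hook`]s, registered per [`HookType`] on a [`HookManager`] and run in ascending priority order.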
2
use super::base::PluginContext;
use crate::Result;
use async_trait::async_trait;
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};
8
/// Identifies a hook point.
///
/// `HookManager` uses this as the key under which it groups registered hooks.
/// When each point actually fires is decided by whoever calls
/// `HookManager::trigger`, not by this type.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum HookType {
    BeforeRequest,
    AfterResponse,
    OnError,
    OnStreamEvent,
    OnRetry,
    OnFallback,
    OnCacheHit,
    OnCacheMiss,
}
20
21#[async_trait]
22pub trait AsyncHook: Send + Sync {
23    async fn call(&self, ctx: &mut PluginContext) -> Result<()>;
24}
25
26pub struct Hook {
27    pub name: String,
28    pub priority: i32,
29    callback: Arc<dyn AsyncHook>,
30}
31impl Hook {
32    pub fn new<H: AsyncHook + 'static>(
33        name: impl Into<String>,
34        priority: i32,
35        callback: H,
36    ) -> Self {
37        Self {
38            name: name.into(),
39            priority,
40            callback: Arc::new(callback),
41        }
42    }
43    pub async fn call(&self, ctx: &mut PluginContext) -> Result<()> {
44        self.callback.call(ctx).await
45    }
46}
47
48pub struct FnHook<F> {
49    func: F,
50}
51impl<F> FnHook<F>
52where
53    F: Fn(&mut PluginContext) -> Result<()> + Send + Sync,
54{
55    pub fn new(func: F) -> Self {
56        Self { func }
57    }
58}
59#[async_trait]
60impl<F> AsyncHook for FnHook<F>
61where
62    F: Fn(&mut PluginContext) -> Result<()> + Send + Sync,
63{
64    async fn call(&self, ctx: &mut PluginContext) -> Result<()> {
65        (self.func)(ctx)
66    }
67}
68
69pub struct HookManager {
70    hooks: RwLock<HashMap<HookType, Vec<Hook>>>,
71}
72impl HookManager {
73    pub fn new() -> Self {
74        Self {
75            hooks: RwLock::new(HashMap::new()),
76        }
77    }
78
79    pub fn register(&self, hook_type: HookType, hook: Hook) {
80        let mut hooks = self.hooks.write().unwrap();
81        let entry = hooks.entry(hook_type).or_insert_with(Vec::new);
82        entry.push(hook);
83        entry.sort_by_key(|h| h.priority);
84    }
85
86    pub fn register_fn<F>(
87        &self,
88        hook_type: HookType,
89        name: impl Into<String>,
90        priority: i32,
91        func: F,
92    ) where
93        F: Fn(&mut PluginContext) -> Result<()> + Send + Sync + 'static,
94    {
95        self.register(hook_type, Hook::new(name, priority, FnHook::new(func)));
96    }
97
98    pub fn unregister(&self, hook_type: HookType, name: &str) -> bool {
99        let mut hooks = self.hooks.write().unwrap();
100        if let Some(entry) = hooks.get_mut(&hook_type) {
101            let len = entry.len();
102            entry.retain(|h| h.name != name);
103            return entry.len() < len;
104        }
105        false
106    }
107
108    pub async fn trigger(&self, hook_type: HookType, ctx: &mut PluginContext) -> Result<()> {
109        let callbacks: Vec<Arc<dyn AsyncHook>> = {
110            let hooks = self.hooks.read().unwrap();
111            hooks
112                .get(&hook_type)
113                .map(|v| v.iter().map(|h| h.callback.clone()).collect())
114                .unwrap_or_default()
115        };
116        for cb in callbacks {
117            if ctx.should_skip() {
118                break;
119            }
120            cb.call(ctx).await?;
121        }
122        Ok(())
123    }
124
125    pub fn count(&self, hook_type: HookType) -> usize {
126        self.hooks
127            .read()
128            .unwrap()
129            .get(&hook_type)
130            .map(|v| v.len())
131            .unwrap_or(0)
132    }
133    pub fn clear(&self) {
134        self.hooks.write().unwrap().clear();
135    }
136}
137impl Default for HookManager {
138    fn default() -> Self {
139        Self::new()
140    }
141}