// sh_layer2/hook_system.rs

//! # Hook System
//!
//! Lifecycle hook system for injecting custom logic at key event points.

5use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::types::{HookEvent, Layer2Result, SessionId};
11
12/// Hook 回调函数类型
13pub type HookCallback = Arc<dyn Fn(&HookContext) -> Layer2Result<()> + Send + Sync>;
14
15/// Hook 上下文
16#[derive(Debug, Clone)]
17pub struct HookContext {
18    pub session_id: SessionId,
19    pub event: HookEvent,
20    pub timestamp: chrono::DateTime<chrono::Utc>,
21    pub data: serde_json::Value,
22    pub metadata: HashMap<String, String>,
23}
24
25impl HookContext {
26    pub fn new(session_id: SessionId, event: HookEvent) -> Self {
27        Self {
28            session_id,
29            event,
30            timestamp: chrono::Utc::now(),
31            data: serde_json::Value::Null,
32            metadata: HashMap::new(),
33        }
34    }
35
36    pub fn with_data(mut self, data: serde_json::Value) -> Self {
37        self.data = data;
38        self
39    }
40
41    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
42        self.metadata.insert(key.to_string(), value.to_string());
43        self
44    }
45}
46
47/// Hook 系统接口
48#[async_trait]
49pub trait HookSystemTrait: Send + Sync {
50    /// 注册前置钩子
51    fn on_before(&self, event: HookEvent, callback: HookCallback);
52
53    /// 注册后置钩子
54    fn on_after(&self, event: HookEvent, callback: HookCallback);
55
56    /// 触发钩子
57    async fn trigger(&self, context: &HookContext) -> Layer2Result<()>;
58
59    /// 移除钩子
60    fn remove(&self, event: HookEvent, is_before: bool);
61
62    /// 清除所有钩子
63    fn clear(&self);
64
65    /// 获取钩子数量
66    fn count(&self, event: HookEvent) -> (usize, usize);
67}
68
69/// Hook 注册表
70type HookRegistry = HashMap<HookEvent, Vec<HookCallback>>;
71
72/// Hook 系统实现
73pub struct HookSystem {
74    before_hooks: RwLock<HookRegistry>,
75    after_hooks: RwLock<HookRegistry>,
76}
77
78impl HookSystem {
79    pub fn new() -> Self {
80        Self {
81            before_hooks: RwLock::new(HashMap::new()),
82            after_hooks: RwLock::new(HashMap::new()),
83        }
84    }
85}
86
87impl Default for HookSystem {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93#[async_trait]
94impl HookSystemTrait for HookSystem {
95    fn on_before(&self, event: HookEvent, callback: HookCallback) {
96        let mut hooks = self.before_hooks.write();
97        hooks.entry(event).or_default().push(callback);
98    }
99
100    fn on_after(&self, event: HookEvent, callback: HookCallback) {
101        let mut hooks = self.after_hooks.write();
102        hooks.entry(event).or_default().push(callback);
103    }
104
105    async fn trigger(&self, context: &HookContext) -> Layer2Result<()> {
106        // 执行前置钩子
107        {
108            let hooks = self.before_hooks.read();
109            if let Some(callbacks) = hooks.get(&context.event) {
110                for callback in callbacks {
111                    callback(context)?;
112                }
113            }
114        }
115
116        // 执行后置钩子
117        {
118            let hooks = self.after_hooks.read();
119            if let Some(callbacks) = hooks.get(&context.event) {
120                for callback in callbacks {
121                    callback(context)?;
122                }
123            }
124        }
125
126        Ok(())
127    }
128
129    fn remove(&self, event: HookEvent, is_before: bool) {
130        let hooks = if is_before {
131            &self.before_hooks
132        } else {
133            &self.after_hooks
134        };
135
136        let mut hooks = hooks.write();
137        hooks.remove(&event);
138    }
139
140    fn clear(&self) {
141        self.before_hooks.write().clear();
142        self.after_hooks.write().clear();
143    }
144
145    fn count(&self, event: HookEvent) -> (usize, usize) {
146        let before = self
147            .before_hooks
148            .read()
149            .get(&event)
150            .map(|v| v.len())
151            .unwrap_or(0);
152        let after = self
153            .after_hooks
154            .read()
155            .get(&event)
156            .map(|v| v.len())
157            .unwrap_or(0);
158        (before, after)
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a callback that always succeeds. These tests only exercise
    /// registration bookkeeping, not callback behavior.
    fn noop_callback() -> HookCallback {
        Arc::new(|_| Ok(()))
    }

    #[test]
    fn test_hook_system_creation() {
        let hooks = HookSystem::new();
        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (0, 0));
    }

    #[test]
    fn test_default_is_empty() {
        let hooks = HookSystem::default();
        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (0, 0));
    }

    #[test]
    fn test_hook_registration() {
        let hooks = HookSystem::new();
        hooks.on_before(HookEvent::BeforeAgentStart, noop_callback());
        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (1, 0));
    }

    #[test]
    fn test_before_and_after_are_counted_separately() {
        let hooks = HookSystem::new();
        hooks.on_before(HookEvent::BeforeAgentStart, noop_callback());
        hooks.on_after(HookEvent::BeforeAgentStart, noop_callback());
        hooks.on_after(HookEvent::BeforeAgentStart, noop_callback());
        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (1, 2));
    }

    #[test]
    fn test_remove_only_affects_requested_phase() {
        let hooks = HookSystem::new();
        hooks.on_before(HookEvent::BeforeAgentStart, noop_callback());
        hooks.on_after(HookEvent::BeforeAgentStart, noop_callback());

        hooks.remove(HookEvent::BeforeAgentStart, true);
        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (0, 1));

        hooks.remove(HookEvent::BeforeAgentStart, false);
        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (0, 0));
    }

    #[test]
    fn test_remove_unregistered_event_is_noop() {
        let hooks = HookSystem::new();
        hooks.remove(HookEvent::BeforeAgentStart, true);
        hooks.remove(HookEvent::BeforeAgentStart, false);
        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (0, 0));
    }

    #[test]
    fn test_clear_removes_all_hooks() {
        let hooks = HookSystem::new();
        hooks.on_before(HookEvent::BeforeAgentStart, noop_callback());
        hooks.on_after(HookEvent::BeforeAgentStart, noop_callback());

        hooks.clear();
        assert_eq!(hooks.count(HookEvent::BeforeAgentStart), (0, 0));
    }
}