1use super::{Plugin, PluginContext, PluginError};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tracing::{debug, warn};
10
/// Named points in the data lifecycle at which registered plugins are invoked.
///
/// Variants come in `Before*` / `After*` pairs. [`HookPoint::is_before`] classifies
/// them, and [`HookPoint::as_str`] gives each one a stable snake_case name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookPoint {
    BeforeCreateNode,
    AfterCreateNode,
    BeforeCreateSession,
    AfterCreateSession,
    BeforeQuery,
    AfterQuery,
    BeforeCreateEdge,
    AfterCreateEdge,
    BeforeUpdateNode,
    AfterUpdateNode,
    BeforeDeleteNode,
    AfterDeleteNode,
    BeforeDeleteSession,
    AfterDeleteSession,
}

impl HookPoint {
    /// Every hook point in declaration order; each `Before*` is followed by its `After*`.
    pub const ALL: [Self; 14] = [
        Self::BeforeCreateNode,
        Self::AfterCreateNode,
        Self::BeforeCreateSession,
        Self::AfterCreateSession,
        Self::BeforeQuery,
        Self::AfterQuery,
        Self::BeforeCreateEdge,
        Self::AfterCreateEdge,
        Self::BeforeUpdateNode,
        Self::AfterUpdateNode,
        Self::BeforeDeleteNode,
        Self::AfterDeleteNode,
        Self::BeforeDeleteSession,
        Self::AfterDeleteSession,
    ];

    /// Returns the snake_case name of this hook point, e.g. `"before_create_node"`.
    ///
    /// This is the hook name string handed to `Plugin::before_hook` / `after_hook`
    /// by the executor, so these literals must stay stable.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::BeforeCreateNode => "before_create_node",
            Self::AfterCreateNode => "after_create_node",
            Self::BeforeCreateSession => "before_create_session",
            Self::AfterCreateSession => "after_create_session",
            Self::BeforeQuery => "before_query",
            Self::AfterQuery => "after_query",
            Self::BeforeCreateEdge => "before_create_edge",
            Self::AfterCreateEdge => "after_create_edge",
            Self::BeforeUpdateNode => "before_update_node",
            Self::AfterUpdateNode => "after_update_node",
            Self::BeforeDeleteNode => "before_delete_node",
            Self::AfterDeleteNode => "after_delete_node",
            Self::BeforeDeleteSession => "before_delete_session",
            Self::AfterDeleteSession => "after_delete_session",
        }
    }

    /// Returns `true` for the `Before*` variants.
    pub fn is_before(&self) -> bool {
        // Exhaustive on purpose (no `_` arm): adding a variant must force a
        // decision here rather than silently defaulting to "after".
        match self {
            Self::BeforeCreateNode
            | Self::BeforeCreateSession
            | Self::BeforeQuery
            | Self::BeforeCreateEdge
            | Self::BeforeUpdateNode
            | Self::BeforeDeleteNode
            | Self::BeforeDeleteSession => true,
            Self::AfterCreateNode
            | Self::AfterCreateSession
            | Self::AfterQuery
            | Self::AfterCreateEdge
            | Self::AfterUpdateNode
            | Self::AfterDeleteNode
            | Self::AfterDeleteSession => false,
        }
    }

    /// Returns `true` for the `After*` variants (the complement of [`HookPoint::is_before`]).
    pub fn is_after(&self) -> bool {
        !self.is_before()
    }

    /// Returns every hook point as an owned `Vec`, in the order of [`HookPoint::ALL`].
    pub fn all() -> Vec<Self> {
        Self::ALL.to_vec()
    }
}

impl std::fmt::Display for HookPoint {
    /// Writes [`HookPoint::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honours width/fill/alignment flags (e.g. `{:<24}` in log tables);
        // `write!(f, "{}", ..)` would silently ignore them.
        f.pad(self.as_str())
    }
}
112
113pub struct HookRegistry {
118 hooks: HashMap<HookPoint, Vec<Arc<dyn Plugin>>>,
119}
120
121impl HookRegistry {
122 pub fn new() -> Self {
124 Self {
125 hooks: HashMap::new(),
126 }
127 }
128
129 pub fn register_hook(&mut self, hook: HookPoint, plugin: Arc<dyn Plugin>) {
131 self.hooks.entry(hook).or_default().push(plugin);
132 }
133
134 pub fn unregister_hook(&mut self, hook: HookPoint, plugin_name: &str) {
136 if let Some(plugins) = self.hooks.get_mut(&hook) {
137 plugins.retain(|p| p.metadata().name != plugin_name);
138 }
139 }
140
141 pub fn unregister_plugin(&mut self, plugin_name: &str) {
143 for plugins in self.hooks.values_mut() {
144 plugins.retain(|p| p.metadata().name != plugin_name);
145 }
146 }
147
148 pub fn get_plugins(&self, hook: HookPoint) -> Vec<Arc<dyn Plugin>> {
150 self.hooks.get(&hook).cloned().unwrap_or_default()
151 }
152
153 pub fn count_plugins(&self, hook: HookPoint) -> usize {
155 self.hooks.get(&hook).map(Vec::len).unwrap_or(0)
156 }
157
158 pub fn clear(&mut self) {
160 self.hooks.clear();
161 }
162
163 pub fn stats(&self) -> HashMap<HookPoint, usize> {
165 self.hooks
166 .iter()
167 .map(|(hook, plugins)| (*hook, plugins.len()))
168 .collect()
169 }
170}
171
172impl Default for HookRegistry {
173 fn default() -> Self {
174 Self::new()
175 }
176}
177
178pub struct HookExecutor {
183 fail_fast: bool,
185 collect_metrics: bool,
187}
188
189impl HookExecutor {
190 pub fn new() -> Self {
192 Self {
193 fail_fast: true,
194 collect_metrics: false,
195 }
196 }
197
198 pub fn without_fail_fast() -> Self {
204 Self {
205 fail_fast: false,
206 collect_metrics: false,
207 }
208 }
209
210 pub fn with_metrics(mut self) -> Self {
212 self.collect_metrics = true;
213 self
214 }
215
216 pub async fn execute_before(
221 &self,
222 hook: HookPoint,
223 plugins: &[Arc<dyn Plugin>],
224 context: &PluginContext,
225 ) -> Result<(), PluginError> {
226 debug!("Executing {} with {} plugins", hook, plugins.len());
227
228 let mut errors = Vec::new();
229
230 for plugin in plugins {
231 let plugin_name = &plugin.metadata().name;
232 debug!("Executing hook {} for plugin {}", hook, plugin_name);
233
234 let start = std::time::Instant::now();
235
236 match plugin.before_hook(hook.as_str(), context).await {
237 Ok(()) => {
238 if self.collect_metrics {
239 let duration = start.elapsed();
240 debug!(
241 "Plugin {} completed {} in {:?}",
242 plugin_name, hook, duration
243 );
244 }
245 }
246 Err(e) => {
247 warn!("Plugin {} failed on {}: {}", plugin_name, hook, e);
248
249 if self.fail_fast {
250 return Err(e);
251 }
252 errors.push((plugin_name.clone(), e));
253 }
254 }
255 }
256
257 if !errors.is_empty() {
258 let error_msg = errors
259 .iter()
260 .map(|(name, e)| format!("{}: {}", name, e))
261 .collect::<Vec<_>>()
262 .join("; ");
263
264 return Err(PluginError::HookFailed(format!(
265 "Multiple plugins failed: {}",
266 error_msg
267 )));
268 }
269
270 Ok(())
271 }
272
273 pub async fn execute_after(
278 &self,
279 hook: HookPoint,
280 plugins: &[Arc<dyn Plugin>],
281 context: &PluginContext,
282 ) -> Result<(), PluginError> {
283 debug!("Executing {} with {} plugins", hook, plugins.len());
284
285 for plugin in plugins {
286 let plugin_name = &plugin.metadata().name;
287 debug!("Executing hook {} for plugin {}", hook, plugin_name);
288
289 let start = std::time::Instant::now();
290
291 match plugin.after_hook(hook.as_str(), context).await {
292 Ok(()) => {
293 if self.collect_metrics {
294 let duration = start.elapsed();
295 debug!(
296 "Plugin {} completed {} in {:?}",
297 plugin_name, hook, duration
298 );
299 }
300 }
301 Err(e) => {
302 warn!(
304 "Plugin {} failed on after hook {}: {}",
305 plugin_name, hook, e
306 );
307 }
308 }
309 }
310
311 Ok(())
312 }
313
314 pub async fn execute(
319 &self,
320 hook: HookPoint,
321 plugins: &[Arc<dyn Plugin>],
322 context: &PluginContext,
323 ) -> Result<(), PluginError> {
324 if hook.is_before() {
325 self.execute_before(hook, plugins, context).await
326 } else {
327 self.execute_after(hook, plugins, context).await
328 }
329 }
330}
331
332impl Default for HookExecutor {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
338#[derive(Debug)]
340pub struct HookExecutionResult {
341 pub hook: HookPoint,
343 pub plugins_executed: usize,
345 pub total_duration: std::time::Duration,
347 pub plugin_durations: HashMap<String, std::time::Duration>,
349 pub errors: Vec<(String, String)>,
351}
352
353impl HookExecutionResult {
354 pub fn is_success(&self) -> bool {
356 self.errors.is_empty()
357 }
358
359 pub fn average_duration(&self) -> std::time::Duration {
361 if self.plugins_executed == 0 {
362 return std::time::Duration::ZERO;
363 }
364 self.total_duration / self.plugins_executed as u32
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugin::{PluginBuilder, PluginMetadata};
    use async_trait::async_trait;
    use std::collections::HashSet;
    use std::time::Duration;

    /// Test double whose `before_create_node` hook fails when `should_fail` is set.
    struct MockPlugin {
        metadata: PluginMetadata,
        should_fail: bool,
    }

    impl MockPlugin {
        fn new(name: &str, should_fail: bool) -> Self {
            let metadata = PluginBuilder::new(name, "1.0.0")
                .author("Test")
                .description("Test plugin")
                .build();
            Self {
                metadata,
                should_fail,
            }
        }
    }

    #[async_trait]
    impl Plugin for MockPlugin {
        fn metadata(&self) -> &PluginMetadata {
            &self.metadata
        }

        async fn before_create_node(&self, _context: &PluginContext) -> Result<(), PluginError> {
            if self.should_fail {
                Err(PluginError::HookFailed("Test failure".to_string()))
            } else {
                Ok(())
            }
        }
    }

    /// Shorthand for a type-erased mock plugin.
    fn mock(name: &str, should_fail: bool) -> Arc<dyn Plugin> {
        Arc::new(MockPlugin::new(name, should_fail))
    }

    fn test_context() -> PluginContext {
        PluginContext::new("test", serde_json::json!({}))
    }

    #[test]
    fn test_hook_point_as_str() {
        assert_eq!(HookPoint::BeforeCreateNode.as_str(), "before_create_node");
        assert_eq!(HookPoint::AfterCreateNode.as_str(), "after_create_node");
    }

    #[test]
    fn test_hook_point_names_are_unique_and_displayed() {
        let all = HookPoint::all();
        assert_eq!(all.len(), 14);

        let names: HashSet<&str> = all.iter().map(HookPoint::as_str).collect();
        assert_eq!(names.len(), all.len());

        for hook in &all {
            assert_eq!(hook.to_string(), hook.as_str());
        }
    }

    #[test]
    fn test_hook_point_is_before() {
        assert!(HookPoint::BeforeCreateNode.is_before());
        assert!(!HookPoint::AfterCreateNode.is_before());

        let all = HookPoint::all();
        // Every hook is exactly one of before/after, split evenly.
        assert!(all.iter().all(|h| h.is_before() != h.is_after()));
        assert_eq!(all.iter().filter(|h| h.is_before()).count(), 7);
        for h in &all {
            let prefix = if h.is_before() { "before_" } else { "after_" };
            assert!(h.as_str().starts_with(prefix), "{} has wrong prefix", h);
        }
    }

    #[test]
    fn test_hook_registry() {
        let mut registry = HookRegistry::new();
        let plugin = mock("test", false);

        registry.register_hook(HookPoint::BeforeCreateNode, Arc::clone(&plugin));
        assert_eq!(registry.count_plugins(HookPoint::BeforeCreateNode), 1);

        let plugins = registry.get_plugins(HookPoint::BeforeCreateNode);
        assert_eq!(plugins.len(), 1);

        registry.unregister_hook(HookPoint::BeforeCreateNode, "test");
        assert_eq!(registry.count_plugins(HookPoint::BeforeCreateNode), 0);
    }

    #[test]
    fn test_unregister_plugin_removes_from_every_hook() {
        let mut registry = HookRegistry::new();
        let shared = mock("shared", false);
        registry.register_hook(HookPoint::BeforeCreateNode, Arc::clone(&shared));
        registry.register_hook(HookPoint::AfterCreateNode, Arc::clone(&shared));
        registry.register_hook(HookPoint::BeforeCreateNode, mock("other", false));

        registry.unregister_plugin("shared");

        assert_eq!(registry.count_plugins(HookPoint::BeforeCreateNode), 1);
        assert_eq!(registry.count_plugins(HookPoint::AfterCreateNode), 0);
        assert_eq!(registry.stats().get(&HookPoint::BeforeCreateNode), Some(&1));

        registry.clear();
        assert_eq!(registry.count_plugins(HookPoint::BeforeCreateNode), 0);
    }

    #[tokio::test]
    async fn test_hook_executor_success() {
        let executor = HookExecutor::new().with_metrics();
        let plugins = vec![mock("plugin1", false), mock("plugin2", false)];

        let result = executor
            .execute_before(HookPoint::BeforeCreateNode, &plugins, &test_context())
            .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_hook_executor_fail_fast() {
        let executor = HookExecutor::new();
        let plugins = vec![
            mock("plugin1", false),
            mock("plugin2", true),
            mock("plugin3", false),
        ];

        let result = executor
            .execute_before(HookPoint::BeforeCreateNode, &plugins, &test_context())
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_hook_executor_without_fail_fast() {
        let executor = HookExecutor::without_fail_fast();
        let plugins = vec![
            mock("plugin1", false),
            mock("plugin2", true),
            mock("plugin3", false),
        ];

        let result = executor
            .execute_before(HookPoint::BeforeCreateNode, &plugins, &test_context())
            .await;

        // Failures are aggregated into one HookFailed that names the failing plugin.
        match result {
            Err(PluginError::HookFailed(message)) => {
                assert!(message.contains("plugin2"), "unexpected message: {}", message);
            }
            _ => panic!("expected an aggregated HookFailed error"),
        }
    }

    #[tokio::test]
    async fn test_execute_after_hook_returns_ok() {
        let executor = HookExecutor::new();
        let plugins = vec![mock("plugin1", true)];

        let result = executor
            .execute(HookPoint::AfterCreateNode, &plugins, &test_context())
            .await;

        assert!(result.is_ok());
    }

    #[test]
    fn test_execution_result_average_duration() {
        let mut result = HookExecutionResult {
            hook: HookPoint::BeforeQuery,
            plugins_executed: 0,
            total_duration: Duration::from_millis(100),
            plugin_durations: HashMap::new(),
            errors: Vec::new(),
        };
        assert_eq!(result.average_duration(), Duration::ZERO);
        assert!(result.is_success());

        result.plugins_executed = 4;
        assert_eq!(result.average_duration(), Duration::from_millis(25));

        result.errors.push(("plugin1".to_string(), "boom".to_string()));
        assert!(!result.is_success());
    }
}