llm_memory_graph/plugin/
hooks.rs

//! Hook execution framework
//!
//! This module provides infrastructure for registering plugins against hook
//! points and executing them at specific points in the system's execution flow.
5
6use super::{Plugin, PluginContext, PluginError};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tracing::{debug, warn};
10
/// Hook point identifier
///
/// Represents specific points in the execution flow where plugins can be
/// invoked. Every operation has a paired `Before*` / `After*` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HookPoint {
    /// Before creating a node
    BeforeCreateNode,
    /// After creating a node
    AfterCreateNode,
    /// Before creating a session
    BeforeCreateSession,
    /// After creating a session
    AfterCreateSession,
    /// Before executing a query
    BeforeQuery,
    /// After executing a query
    AfterQuery,
    /// Before creating an edge
    BeforeCreateEdge,
    /// After creating an edge
    AfterCreateEdge,
    /// Before updating a node
    BeforeUpdateNode,
    /// After updating a node
    AfterUpdateNode,
    /// Before deleting a node
    BeforeDeleteNode,
    /// After deleting a node
    AfterDeleteNode,
    /// Before deleting a session
    BeforeDeleteSession,
    /// After deleting a session
    AfterDeleteSession,
}

impl HookPoint {
    /// Every hook point, in declaration order.
    ///
    /// Keep this in sync with the enum when adding variants.
    pub const ALL: [Self; 14] = [
        Self::BeforeCreateNode,
        Self::AfterCreateNode,
        Self::BeforeCreateSession,
        Self::AfterCreateSession,
        Self::BeforeQuery,
        Self::AfterQuery,
        Self::BeforeCreateEdge,
        Self::AfterCreateEdge,
        Self::BeforeUpdateNode,
        Self::AfterUpdateNode,
        Self::BeforeDeleteNode,
        Self::AfterDeleteNode,
        Self::BeforeDeleteSession,
        Self::AfterDeleteSession,
    ];

    /// Get the hook name as a `snake_case` string (e.g. `"before_create_node"`).
    ///
    /// This is the name passed to `Plugin::before_hook` / `Plugin::after_hook`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::BeforeCreateNode => "before_create_node",
            Self::AfterCreateNode => "after_create_node",
            Self::BeforeCreateSession => "before_create_session",
            Self::AfterCreateSession => "after_create_session",
            Self::BeforeQuery => "before_query",
            Self::AfterQuery => "after_query",
            Self::BeforeCreateEdge => "before_create_edge",
            Self::AfterCreateEdge => "after_create_edge",
            Self::BeforeUpdateNode => "before_update_node",
            Self::AfterUpdateNode => "after_update_node",
            Self::BeforeDeleteNode => "before_delete_node",
            Self::AfterDeleteNode => "after_delete_node",
            Self::BeforeDeleteSession => "before_delete_session",
            Self::AfterDeleteSession => "after_delete_session",
        }
    }

    /// Check if this is a "before" hook.
    ///
    /// The match is deliberately exhaustive (no `_` arm): adding a variant
    /// forces a decision here instead of silently classifying it as an
    /// "after" hook, whose errors the executor would swallow.
    pub fn is_before(&self) -> bool {
        match self {
            Self::BeforeCreateNode
            | Self::BeforeCreateSession
            | Self::BeforeQuery
            | Self::BeforeCreateEdge
            | Self::BeforeUpdateNode
            | Self::BeforeDeleteNode
            | Self::BeforeDeleteSession => true,
            Self::AfterCreateNode
            | Self::AfterCreateSession
            | Self::AfterQuery
            | Self::AfterCreateEdge
            | Self::AfterUpdateNode
            | Self::AfterDeleteNode
            | Self::AfterDeleteSession => false,
        }
    }

    /// Check if this is an "after" hook (every hook point is exactly one of
    /// "before" or "after").
    pub fn is_after(&self) -> bool {
        !self.is_before()
    }

    /// Get all hook points, in declaration order, as an owned `Vec`.
    pub fn all() -> Vec<Self> {
        Self::ALL.to_vec()
    }
}

impl std::fmt::Display for HookPoint {
    /// Formats the hook point as its `snake_case` name (same as [`HookPoint::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
112
113/// Hook registry
114///
115/// Maintains mappings between hook points and the plugins that should be
116/// invoked at those points. This allows for efficient hook execution.
117pub struct HookRegistry {
118    hooks: HashMap<HookPoint, Vec<Arc<dyn Plugin>>>,
119}
120
121impl HookRegistry {
122    /// Create a new hook registry
123    pub fn new() -> Self {
124        Self {
125            hooks: HashMap::new(),
126        }
127    }
128
129    /// Register a plugin for a specific hook point
130    pub fn register_hook(&mut self, hook: HookPoint, plugin: Arc<dyn Plugin>) {
131        self.hooks.entry(hook).or_default().push(plugin);
132    }
133
134    /// Unregister a plugin from a specific hook point
135    pub fn unregister_hook(&mut self, hook: HookPoint, plugin_name: &str) {
136        if let Some(plugins) = self.hooks.get_mut(&hook) {
137            plugins.retain(|p| p.metadata().name != plugin_name);
138        }
139    }
140
141    /// Unregister a plugin from all hook points
142    pub fn unregister_plugin(&mut self, plugin_name: &str) {
143        for plugins in self.hooks.values_mut() {
144            plugins.retain(|p| p.metadata().name != plugin_name);
145        }
146    }
147
148    /// Get all plugins registered for a hook point
149    pub fn get_plugins(&self, hook: HookPoint) -> Vec<Arc<dyn Plugin>> {
150        self.hooks.get(&hook).cloned().unwrap_or_default()
151    }
152
153    /// Get the number of plugins registered for a hook point
154    pub fn count_plugins(&self, hook: HookPoint) -> usize {
155        self.hooks.get(&hook).map(Vec::len).unwrap_or(0)
156    }
157
158    /// Clear all hook registrations
159    pub fn clear(&mut self) {
160        self.hooks.clear();
161    }
162
163    /// Get statistics about hook registrations
164    pub fn stats(&self) -> HashMap<HookPoint, usize> {
165        self.hooks
166            .iter()
167            .map(|(hook, plugins)| (*hook, plugins.len()))
168            .collect()
169    }
170}
171
172impl Default for HookRegistry {
173    fn default() -> Self {
174        Self::new()
175    }
176}
177
178/// Hook executor
179///
180/// Responsible for executing plugin hooks at specific points in the system.
181/// Handles error propagation, logging, and execution order.
182pub struct HookExecutor {
183    /// Whether to stop execution on first error (before hooks only)
184    fail_fast: bool,
185    /// Whether to collect timing metrics
186    collect_metrics: bool,
187}
188
189impl HookExecutor {
190    /// Create a new hook executor
191    pub fn new() -> Self {
192        Self {
193            fail_fast: true,
194            collect_metrics: false,
195        }
196    }
197
198    /// Create a hook executor with fail-fast disabled
199    ///
200    /// When fail-fast is disabled, all plugins will be executed even if
201    /// some fail. Useful for after-hooks where you want to ensure all
202    /// plugins get a chance to run.
203    pub fn without_fail_fast() -> Self {
204        Self {
205            fail_fast: false,
206            collect_metrics: false,
207        }
208    }
209
210    /// Enable metrics collection
211    pub fn with_metrics(mut self) -> Self {
212        self.collect_metrics = true;
213        self
214    }
215
216    /// Execute before hooks
217    ///
218    /// Executes all plugins at the specified hook point. If fail_fast is enabled,
219    /// stops at the first error. Otherwise, collects all errors and returns them.
220    pub async fn execute_before(
221        &self,
222        hook: HookPoint,
223        plugins: &[Arc<dyn Plugin>],
224        context: &PluginContext,
225    ) -> Result<(), PluginError> {
226        debug!("Executing {} with {} plugins", hook, plugins.len());
227
228        let mut errors = Vec::new();
229
230        for plugin in plugins {
231            let plugin_name = &plugin.metadata().name;
232            debug!("Executing hook {} for plugin {}", hook, plugin_name);
233
234            let start = std::time::Instant::now();
235
236            match plugin.before_hook(hook.as_str(), context).await {
237                Ok(()) => {
238                    if self.collect_metrics {
239                        let duration = start.elapsed();
240                        debug!(
241                            "Plugin {} completed {} in {:?}",
242                            plugin_name, hook, duration
243                        );
244                    }
245                }
246                Err(e) => {
247                    warn!("Plugin {} failed on {}: {}", plugin_name, hook, e);
248
249                    if self.fail_fast {
250                        return Err(e);
251                    }
252                    errors.push((plugin_name.clone(), e));
253                }
254            }
255        }
256
257        if !errors.is_empty() {
258            let error_msg = errors
259                .iter()
260                .map(|(name, e)| format!("{}: {}", name, e))
261                .collect::<Vec<_>>()
262                .join("; ");
263
264            return Err(PluginError::HookFailed(format!(
265                "Multiple plugins failed: {}",
266                error_msg
267            )));
268        }
269
270        Ok(())
271    }
272
273    /// Execute after hooks
274    ///
275    /// Executes all plugins at the specified hook point. After hooks never
276    /// fail the operation - errors are logged but execution continues.
277    pub async fn execute_after(
278        &self,
279        hook: HookPoint,
280        plugins: &[Arc<dyn Plugin>],
281        context: &PluginContext,
282    ) -> Result<(), PluginError> {
283        debug!("Executing {} with {} plugins", hook, plugins.len());
284
285        for plugin in plugins {
286            let plugin_name = &plugin.metadata().name;
287            debug!("Executing hook {} for plugin {}", hook, plugin_name);
288
289            let start = std::time::Instant::now();
290
291            match plugin.after_hook(hook.as_str(), context).await {
292                Ok(()) => {
293                    if self.collect_metrics {
294                        let duration = start.elapsed();
295                        debug!(
296                            "Plugin {} completed {} in {:?}",
297                            plugin_name, hook, duration
298                        );
299                    }
300                }
301                Err(e) => {
302                    // After hooks should not fail the operation
303                    warn!(
304                        "Plugin {} failed on after hook {}: {}",
305                        plugin_name, hook, e
306                    );
307                }
308            }
309        }
310
311        Ok(())
312    }
313
314    /// Execute a hook point
315    ///
316    /// Automatically determines whether to use before or after hook semantics
317    /// based on the hook point.
318    pub async fn execute(
319        &self,
320        hook: HookPoint,
321        plugins: &[Arc<dyn Plugin>],
322        context: &PluginContext,
323    ) -> Result<(), PluginError> {
324        if hook.is_before() {
325            self.execute_before(hook, plugins, context).await
326        } else {
327            self.execute_after(hook, plugins, context).await
328        }
329    }
330}
331
332impl Default for HookExecutor {
333    fn default() -> Self {
334        Self::new()
335    }
336}
337
338/// Hook execution result with timing information
339#[derive(Debug)]
340pub struct HookExecutionResult {
341    /// Hook point that was executed
342    pub hook: HookPoint,
343    /// Number of plugins executed
344    pub plugins_executed: usize,
345    /// Total execution time
346    pub total_duration: std::time::Duration,
347    /// Individual plugin execution times
348    pub plugin_durations: HashMap<String, std::time::Duration>,
349    /// Any errors that occurred
350    pub errors: Vec<(String, String)>,
351}
352
353impl HookExecutionResult {
354    /// Check if the execution was successful
355    pub fn is_success(&self) -> bool {
356        self.errors.is_empty()
357    }
358
359    /// Get the average execution time per plugin
360    pub fn average_duration(&self) -> std::time::Duration {
361        if self.plugins_executed == 0 {
362            return std::time::Duration::ZERO;
363        }
364        self.total_duration / self.plugins_executed as u32
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugin::{PluginBuilder, PluginMetadata};
    use async_trait::async_trait;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double whose `before_create_node` hook optionally fails and
    /// records each invocation in a (possibly shared) counter.
    struct MockPlugin {
        metadata: PluginMetadata,
        should_fail: bool,
        calls: Arc<AtomicUsize>,
    }

    impl MockPlugin {
        fn new(name: &str, should_fail: bool) -> Self {
            Self::with_counter(name, should_fail, Arc::new(AtomicUsize::new(0)))
        }

        /// Like `new`, but counts invocations in `calls` so a test can
        /// observe how many plugins actually ran.
        fn with_counter(name: &str, should_fail: bool, calls: Arc<AtomicUsize>) -> Self {
            let metadata = PluginBuilder::new(name, "1.0.0")
                .author("Test")
                .description("Test plugin")
                .build();
            Self {
                metadata,
                should_fail,
                calls,
            }
        }
    }

    #[async_trait]
    impl Plugin for MockPlugin {
        fn metadata(&self) -> &PluginMetadata {
            &self.metadata
        }

        async fn before_create_node(&self, _context: &PluginContext) -> Result<(), PluginError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            if self.should_fail {
                Err(PluginError::HookFailed("Test failure".to_string()))
            } else {
                Ok(())
            }
        }
    }

    /// Build plugins named `plugin1..` that share one invocation counter;
    /// `failures[i]` controls whether plugin `i + 1` fails.
    fn counted_plugins(calls: &Arc<AtomicUsize>, failures: &[bool]) -> Vec<Arc<dyn Plugin>> {
        failures
            .iter()
            .enumerate()
            .map(|(i, &fail)| {
                Arc::new(MockPlugin::with_counter(
                    &format!("plugin{}", i + 1),
                    fail,
                    Arc::clone(calls),
                )) as Arc<dyn Plugin>
            })
            .collect()
    }

    fn context() -> PluginContext {
        PluginContext::new("test", serde_json::json!({}))
    }

    #[test]
    fn test_hook_point_as_str() {
        assert_eq!(HookPoint::BeforeCreateNode.as_str(), "before_create_node");
        assert_eq!(HookPoint::AfterCreateNode.as_str(), "after_create_node");
    }

    #[test]
    fn test_hook_point_is_before() {
        assert!(HookPoint::BeforeCreateNode.is_before());
        assert!(!HookPoint::AfterCreateNode.is_before());
    }

    #[test]
    fn test_hook_point_all_is_complete_and_consistent() {
        let all = HookPoint::all();
        let unique: HashSet<HookPoint> = all.iter().copied().collect();
        assert_eq!(unique.len(), all.len());

        for hook in all {
            assert_ne!(hook.is_before(), hook.is_after());
            assert_eq!(hook.is_before(), hook.as_str().starts_with("before_"));
        }
    }

    #[test]
    fn test_hook_point_display() {
        assert_eq!(HookPoint::AfterQuery.to_string(), "after_query");
    }

    #[test]
    fn test_hook_registry() {
        let mut registry = HookRegistry::new();
        let plugin: Arc<dyn Plugin> = Arc::new(MockPlugin::new("test", false));

        registry.register_hook(HookPoint::BeforeCreateNode, Arc::clone(&plugin));
        assert_eq!(registry.count_plugins(HookPoint::BeforeCreateNode), 1);

        let plugins = registry.get_plugins(HookPoint::BeforeCreateNode);
        assert_eq!(plugins.len(), 1);

        registry.unregister_hook(HookPoint::BeforeCreateNode, "test");
        assert_eq!(registry.count_plugins(HookPoint::BeforeCreateNode), 0);
    }

    #[test]
    fn test_unregister_plugin_removes_from_every_hook() {
        let mut registry = HookRegistry::new();
        let plugin: Arc<dyn Plugin> = Arc::new(MockPlugin::new("test", false));
        let other: Arc<dyn Plugin> = Arc::new(MockPlugin::new("other", false));

        registry.register_hook(HookPoint::BeforeCreateNode, Arc::clone(&plugin));
        registry.register_hook(HookPoint::AfterCreateNode, Arc::clone(&plugin));
        registry.register_hook(HookPoint::AfterCreateNode, Arc::clone(&other));

        registry.unregister_plugin("test");
        assert_eq!(registry.count_plugins(HookPoint::BeforeCreateNode), 0);
        assert_eq!(registry.count_plugins(HookPoint::AfterCreateNode), 1);
    }

    #[tokio::test]
    async fn test_hook_executor_success() {
        let executor = HookExecutor::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let plugins = counted_plugins(&calls, &[false, false]);

        let result = executor
            .execute_before(HookPoint::BeforeCreateNode, &plugins, &context())
            .await;

        assert!(result.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_hook_executor_fail_fast() {
        let executor = HookExecutor::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let plugins = counted_plugins(&calls, &[false, true, false]);

        let result = executor
            .execute_before(HookPoint::BeforeCreateNode, &plugins, &context())
            .await;

        assert!(result.is_err());
        // plugin3 must be skipped after plugin2 fails.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_hook_executor_without_fail_fast() {
        let executor = HookExecutor::without_fail_fast();
        let calls = Arc::new(AtomicUsize::new(0));
        let plugins = counted_plugins(&calls, &[false, true, false]);

        let result = executor
            .execute_before(HookPoint::BeforeCreateNode, &plugins, &context())
            .await;

        // Should fail but only after executing all plugins
        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_execute_dispatches_on_hook_kind() {
        let executor = HookExecutor::new();
        let plugins: Vec<Arc<dyn Plugin>> = vec![Arc::new(MockPlugin::new("failing", true))];

        let before = executor
            .execute(HookPoint::BeforeCreateNode, &plugins, &context())
            .await;
        assert!(before.is_err());

        let after = executor
            .execute(HookPoint::AfterCreateNode, &plugins, &context())
            .await;
        assert!(after.is_ok());
    }
}