Skip to main content

nexus_memory_hooks/
extractor.rs

//! Multi-layer extractor combining all detection methods
//!
//! This module provides the main extraction orchestrator that combines
//! all four detection layers for maximum reliability.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::{broadcast, RwLock};
9
10use crate::base::{AgentHook, HookResult};
11use crate::buffer::{BufferData, PersistentBuffer};
12use crate::detector::InactivityDetector;
13use crate::error::{HookError, Result};
14use crate::monitor::{MonitorEvent, SessionMonitor};
15use crate::session::SessionContext;
16use crate::signal::{SignalEvent, SignalHandler};
17use crate::types::{AgentType, ExtractionSource};
18
/// Counters describing how many extractions were recorded and which
/// detection layer each one is attributed to.
#[derive(Debug, Clone, Default)]
pub struct ExtractionStats {
    /// Every recorded extraction attempt, successful or not.
    pub total_extractions: u64,
    /// Extractions attributed to native agent hooks.
    pub native_extractions: u64,
    /// Extractions attributed to the session (process) monitor.
    pub monitor_extractions: u64,
    /// Extractions attributed to inactivity timeouts.
    pub inactivity_extractions: u64,
    /// Contexts rebuilt from the persistent buffer.
    pub buffer_recoveries: u64,
    /// Extractions attributed to received signals.
    pub signal_extractions: u64,
    /// Recorded attempts that ended in an error.
    pub failed_extractions: u64,
}

impl ExtractionStats {
    /// Fraction of recorded extractions that did not fail, in `0.0..=1.0`.
    ///
    /// Returns `1.0` when nothing has been recorded yet. The fields are public
    /// and can therefore be set inconsistently; if `failed_extractions` exceeds
    /// `total_extractions` the rate is reported as `0.0` instead of overflowing
    /// the subtraction.
    pub fn success_rate(&self) -> f32 {
        if self.total_extractions == 0 {
            return 1.0;
        }

        // Saturate rather than underflow when the counters disagree.
        let successful = self
            .total_extractions
            .saturating_sub(self.failed_extractions);
        successful as f32 / self.total_extractions as f32
    }
}
41
42/// Multi-layer extractor for maximum reliability
43///
44/// This combines all four detection layers:
45/// 1. **Native Hooks** (100%): Direct agent integration
46/// 2. **Session Monitor** (95%): Process monitoring
47/// 3. **Inactivity Detector** (90%): Timeout detection
48/// 4. **Persistent Buffer** (99%): Crash recovery
49///
50/// # Example
51///
52/// ```ignore
53/// use nexus_memory_hooks::{HookFactory, MultiLayerExtractor};
54/// use std::sync::Arc;
55///
56/// #[tokio::main]
57/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
58///     let factory = HookFactory::new();
59///     let hook = factory.create_hook("claude-code")?;
60///
61///     let extractor = MultiLayerExtractor::new()
62///         .with_hook(hook)
63///         .await?;
64///
65///     // Start monitoring
66///     extractor.start().await?;
67///
68///     // Extract context
69///     let context = extractor.extract().await?;
70///     println!("Extracted: {:?}", context);
71///
72///     // Stop and flush
73///     extractor.stop().await?;
74///
75///     Ok(())
76/// }
77/// ```
78pub struct MultiLayerExtractor {
79    /// Agent hooks by type
80    hooks: Arc<RwLock<HashMap<String, Box<dyn AgentHook>>>>,
81
82    /// Persistent buffer
83    buffer: PersistentBuffer,
84
85    /// Session monitor
86    monitor: SessionMonitor,
87
88    /// Inactivity detector
89    inactivity_detector: InactivityDetector,
90
91    /// Signal handler
92    signal_handler: SignalHandler,
93
94    /// Event sender
95    event_sender: broadcast::Sender<ExtractionEvent>,
96
97    /// Statistics
98    stats: Arc<RwLock<ExtractionStats>>,
99
100    /// Whether extraction is active
101    active: Arc<RwLock<bool>>,
102}
103
104/// Events emitted by the extractor
105#[derive(Debug, Clone)]
106pub enum ExtractionEvent {
107    /// Extraction started
108    Started {
109        agent_type: String,
110        source: ExtractionSource,
111    },
112
113    /// Extraction completed
114    Completed {
115        agent_type: String,
116        source: ExtractionSource,
117        context: Box<SessionContext>,
118    },
119
120    /// Extraction failed
121    Failed {
122        agent_type: String,
123        source: ExtractionSource,
124        error: String,
125    },
126
127    /// Buffer recovered
128    BufferRecovered { agent_type: String, entries: usize },
129}
130
131impl MultiLayerExtractor {
132    /// Create a new multi-layer extractor
133    pub fn new() -> Result<Self> {
134        let buffer = PersistentBuffer::new(None)?;
135        let (event_sender, _) = broadcast::channel(100);
136
137        Ok(Self {
138            hooks: Arc::new(RwLock::new(HashMap::new())),
139            buffer,
140            monitor: SessionMonitor::new(),
141            inactivity_detector: InactivityDetector::new(),
142            signal_handler: SignalHandler::new(),
143            event_sender,
144            stats: Arc::new(RwLock::new(ExtractionStats::default())),
145            active: Arc::new(RwLock::new(false)),
146        })
147    }
148
149    /// Add a hook to the extractor
150    pub async fn with_hook(self, hook: Box<dyn AgentHook>) -> Result<Self> {
151        let agent_type = hook.agent_type().to_string();
152
153        // Install native hook callback
154        let event_sender = self.event_sender.clone();
155        let agent_type_clone = agent_type.clone();
156
157        let _callback = Arc::new(move |ctx: SessionContext| {
158            let _ = event_sender.send(ExtractionEvent::Completed {
159                agent_type: agent_type_clone.clone(),
160                source: ExtractionSource::NativeHook("session_end".to_string()),
161                context: Box::new(ctx),
162            });
163        });
164
165        // We need mutable access to install the hook
166        {
167            let mut hooks = self.hooks.write().await;
168            hooks.insert(agent_type.clone(), hook);
169        }
170
171        Ok(self)
172    }
173
174    /// Subscribe to extraction events
175    pub fn subscribe(&self) -> broadcast::Receiver<ExtractionEvent> {
176        self.event_sender.subscribe()
177    }
178
179    /// Start all detection layers
180    pub async fn start(&self) -> Result<()> {
181        let mut active = self.active.write().await;
182        if *active {
183            return Ok(());
184        }
185        *active = true;
186        drop(active);
187
188        // Get agent types to monitor
189        let agent_types: Vec<String> = {
190            let hooks = self.hooks.read().await;
191            hooks.keys().cloned().collect()
192        };
193
194        // Convert to AgentType for monitor
195        let agent_types_enum: Vec<AgentType> = agent_types
196            .iter()
197            .filter_map(|s| AgentType::parse(s))
198            .collect();
199
200        // Start session monitor
201        self.monitor.start_monitoring(agent_types_enum).await;
202
203        // Start inactivity detector
204        self.inactivity_detector
205            .start_monitoring(agent_types.clone())
206            .await;
207
208        // Install signal handlers
209        self.signal_handler.install().await?;
210
211        // Subscribe to monitor events
212        let event_sender = self.event_sender.clone();
213        let stats = self.stats.clone();
214        let mut monitor_rx = self.monitor.subscribe();
215
216        tokio::spawn(async move {
217            while let Ok(event) = monitor_rx.recv().await {
218                match event {
219                    MonitorEvent::SessionEnded {
220                        agent_type,
221                        reason: _,
222                        ..
223                    } => {
224                        let _ = event_sender.send(ExtractionEvent::Started {
225                            agent_type: agent_type.clone(),
226                            source: ExtractionSource::ProcessMonitor,
227                        });
228
229                        let mut stats = stats.write().await;
230                        stats.total_extractions += 1;
231                        stats.monitor_extractions += 1;
232                    }
233                    MonitorEvent::InactivityDetected { agent_type, .. } => {
234                        let _ = event_sender.send(ExtractionEvent::Started {
235                            agent_type: agent_type.clone(),
236                            source: ExtractionSource::InactivityTimeout,
237                        });
238
239                        let mut stats = stats.write().await;
240                        stats.total_extractions += 1;
241                        stats.inactivity_extractions += 1;
242                    }
243                    _ => {}
244                }
245            }
246        });
247
248        // Subscribe to signal events
249        let _event_sender = self.event_sender.clone();
250        let stats = self.stats.clone();
251        let mut signal_rx = self.signal_handler.subscribe();
252
253        tokio::spawn(async move {
254            while let Ok(signal) = signal_rx.recv().await {
255                let _source = match signal {
256                    SignalEvent::Interrupt => ExtractionSource::SignalHandler("SIGINT".to_string()),
257                    SignalEvent::Terminate => {
258                        ExtractionSource::SignalHandler("SIGTERM".to_string())
259                    }
260                    _ => continue,
261                };
262
263                let mut stats = stats.write().await;
264                stats.total_extractions += 1;
265                stats.signal_extractions += 1;
266            }
267        });
268
269        // Start buffering for all agents
270        for agent_type in &agent_types {
271            self.buffer.start_buffering(agent_type).await?;
272        }
273
274        tracing::info!("Multi-layer extractor started");
275
276        Ok(())
277    }
278
279    /// Stop all detection layers
280    pub async fn stop(&self) -> Result<()> {
281        let mut active = self.active.write().await;
282        *active = false;
283
284        self.monitor.stop_monitoring().await;
285        self.inactivity_detector.stop_monitoring().await;
286
287        // Flush all buffers
288        self.buffer.flush_all().await?;
289
290        tracing::info!("Multi-layer extractor stopped");
291
292        Ok(())
293    }
294
295    /// Extract context for an agent
296    pub async fn extract(&self, agent_type: &str) -> Result<SessionContext> {
297        // Try native extraction first
298        let native_result = self.try_native_extraction(agent_type).await;
299
300        if let Ok(context) = native_result {
301            // Store to buffer
302            self.buffer
303                .buffer_context(agent_type, context.clone(), "extraction")
304                .await?;
305            return Ok(context);
306        }
307
308        // Try buffer recovery
309        if let Some(data) = self.buffer.recover_buffer(agent_type).await? {
310            let context = self.buffer_data_to_context(data);
311            let _ = self.event_sender.send(ExtractionEvent::BufferRecovered {
312                agent_type: agent_type.to_string(),
313                entries: context.insights.len(), // Use insights count as entry count
314            });
315
316            let mut stats = self.stats.write().await;
317            stats.buffer_recoveries += 1;
318
319            return Ok(context);
320        }
321
322        // Fallback to minimal context
323        Ok(SessionContext::new(agent_type)
324            .with_source("fallback")
325            .with_reliability(0.5))
326    }
327
328    /// Try native extraction
329    async fn try_native_extraction(&self, agent_type: &str) -> Result<SessionContext> {
330        let hooks = self.hooks.read().await;
331
332        if let Some(hook) = hooks.get(agent_type) {
333            // Check activity first
334            let activity = hook.detect_session_activity().await?;
335
336            if activity.is_active {
337                return hook.extract_session_context().await;
338            }
339        }
340
341        Err(HookError::SessionNotActive)
342    }
343
344    /// Convert buffer data to session context
345    fn buffer_data_to_context(&self, data: BufferData) -> SessionContext {
346        let mut context = SessionContext::new(&data.agent_type)
347            .with_source("buffer_recovery")
348            .with_reliability(0.99);
349
350        for entry in data.entries {
351            context.insights.push(format!(
352                "[{}] {:?}",
353                entry.context_type,
354                entry.context.to_memory_content()
355            ));
356        }
357
358        context
359    }
360
361    /// Get extraction statistics
362    pub async fn stats(&self) -> ExtractionStats {
363        self.stats.read().await.clone()
364    }
365
366    /// Check if extraction is active
367    pub async fn is_active(&self) -> bool {
368        *self.active.read().await
369    }
370
371    /// Manually trigger extraction for an agent
372    pub async fn trigger_extraction(&self, agent_type: &str) -> Result<HookResult> {
373        let _ = self.event_sender.send(ExtractionEvent::Started {
374            agent_type: agent_type.to_string(),
375            source: ExtractionSource::Manual,
376        });
377
378        match self.extract(agent_type).await {
379            Ok(context) => {
380                let _ = self.event_sender.send(ExtractionEvent::Completed {
381                    agent_type: agent_type.to_string(),
382                    source: ExtractionSource::Manual,
383                    context: Box::new(context.clone()),
384                });
385
386                let mut stats = self.stats.write().await;
387                stats.total_extractions += 1;
388
389                Ok(HookResult::success_with_context(
390                    agent_type,
391                    ExtractionSource::Manual,
392                    context,
393                ))
394            }
395            Err(e) => {
396                let _ = self.event_sender.send(ExtractionEvent::Failed {
397                    agent_type: agent_type.to_string(),
398                    source: ExtractionSource::Manual,
399                    error: e.to_string(),
400                });
401
402                let mut stats = self.stats.write().await;
403                stats.total_extractions += 1;
404                stats.failed_extractions += 1;
405
406                Ok(HookResult::failure(
407                    agent_type,
408                    ExtractionSource::Manual,
409                    e.to_string(),
410                ))
411            }
412        }
413    }
414
415    /// Check for buffered data to recover
416    pub async fn check_for_recovery(&self) -> Result<Vec<(String, BufferData)>> {
417        let hooks = self.hooks.read().await;
418        let mut recovered = Vec::new();
419
420        for agent_type in hooks.keys() {
421            if let Some(data) = self.buffer.recover_buffer(agent_type).await? {
422                recovered.push((agent_type.clone(), data));
423            }
424        }
425
426        Ok(recovered)
427    }
428
429    /// Clear buffer for an agent
430    pub async fn clear_buffer(&self, agent_type: &str) -> Result<()> {
431        self.buffer.clear_buffer(agent_type).await
432    }
433}
434
435impl Default for MultiLayerExtractor {
436    fn default() -> Self {
437        Self::new().expect("Failed to create extractor")
438    }
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed extractor has not been started.
    #[tokio::test]
    async fn test_extractor_new() {
        let extractor = MultiLayerExtractor::new().unwrap();
        let active = extractor.is_active().await;
        assert!(!active);
    }

    /// Statistics start at zero, which reports a perfect success rate.
    #[tokio::test]
    async fn test_extractor_stats() {
        let snapshot = MultiLayerExtractor::new().unwrap().stats().await;

        assert_eq!(snapshot.total_extractions, 0);
        assert_eq!(snapshot.success_rate(), 1.0);
    }

    /// Subscribing yields a receiver that can simply be dropped again.
    #[tokio::test]
    async fn test_extractor_subscribe() {
        let extractor = MultiLayerExtractor::new().unwrap();
        let events = extractor.subscribe();

        // Should be able to subscribe without error
        drop(events);
    }

    /// The success rate is the share of recorded extractions that did not fail.
    #[test]
    fn test_extraction_stats_success_rate() {
        let mut counters = ExtractionStats::default();
        assert_eq!(counters.success_rate(), 1.0);

        counters.total_extractions = 10;
        counters.failed_extractions = 2;

        let deviation = (counters.success_rate() - 0.8).abs();
        assert!(deviation < 0.001);
    }
}