1use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::{broadcast, RwLock};
9
10use crate::base::{AgentHook, HookResult};
11use crate::buffer::{BufferData, PersistentBuffer};
12use crate::detector::InactivityDetector;
13use crate::error::{HookError, Result};
14use crate::monitor::{MonitorEvent, SessionMonitor};
15use crate::session::SessionContext;
16use crate::signal::{SignalEvent, SignalHandler};
17use crate::types::{AgentType, ExtractionSource};
18
/// Running counters describing how session contexts were extracted.
///
/// All fields are plain public counters; nothing in this type enforces a
/// relationship between them (e.g. that `failed_extractions <= total_extractions`).
#[derive(Debug, Clone, Default)]
pub struct ExtractionStats {
    /// Overall number of extractions counted.
    pub total_extractions: u64,
    /// Extractions attributed to a native agent hook.
    pub native_extractions: u64,
    /// Extractions attributed to the session monitor.
    pub monitor_extractions: u64,
    /// Extractions attributed to an inactivity timeout.
    pub inactivity_extractions: u64,
    /// Contexts recovered from the persistent buffer.
    pub buffer_recoveries: u64,
    /// Extractions attributed to a process signal.
    pub signal_extractions: u64,
    /// Extractions that ended in failure.
    pub failed_extractions: u64,
}

impl ExtractionStats {
    /// Fraction of counted extractions that did not fail, in `[0.0, 1.0]`.
    ///
    /// Returns `1.0` when nothing has been counted yet. Because the fields are
    /// public and can be set independently, `failed_extractions` may exceed
    /// `total_extractions`; in that case the result saturates at `0.0` instead
    /// of underflowing (which would panic in debug builds and wrap in release).
    pub fn success_rate(&self) -> f32 {
        if self.total_extractions == 0 {
            return 1.0;
        }

        let successful = self
            .total_extractions
            .saturating_sub(self.failed_extractions);

        // Divide in f64: casting large u64 counters straight to f32 loses
        // precision above 2^24.
        (successful as f64 / self.total_extractions as f64) as f32
    }
}
41
42pub struct MultiLayerExtractor {
79 hooks: Arc<RwLock<HashMap<String, Box<dyn AgentHook>>>>,
81
82 buffer: PersistentBuffer,
84
85 monitor: SessionMonitor,
87
88 inactivity_detector: InactivityDetector,
90
91 signal_handler: SignalHandler,
93
94 event_sender: broadcast::Sender<ExtractionEvent>,
96
97 stats: Arc<RwLock<ExtractionStats>>,
99
100 active: Arc<RwLock<bool>>,
102}
103
104#[derive(Debug, Clone)]
106pub enum ExtractionEvent {
107 Started {
109 agent_type: String,
110 source: ExtractionSource,
111 },
112
113 Completed {
115 agent_type: String,
116 source: ExtractionSource,
117 context: Box<SessionContext>,
118 },
119
120 Failed {
122 agent_type: String,
123 source: ExtractionSource,
124 error: String,
125 },
126
127 BufferRecovered { agent_type: String, entries: usize },
129}
130
131impl MultiLayerExtractor {
132 pub fn new() -> Result<Self> {
134 let buffer = PersistentBuffer::new(None)?;
135 let (event_sender, _) = broadcast::channel(100);
136
137 Ok(Self {
138 hooks: Arc::new(RwLock::new(HashMap::new())),
139 buffer,
140 monitor: SessionMonitor::new(),
141 inactivity_detector: InactivityDetector::new(),
142 signal_handler: SignalHandler::new(),
143 event_sender,
144 stats: Arc::new(RwLock::new(ExtractionStats::default())),
145 active: Arc::new(RwLock::new(false)),
146 })
147 }
148
149 pub async fn with_hook(self, hook: Box<dyn AgentHook>) -> Result<Self> {
151 let agent_type = hook.agent_type().to_string();
152
153 let event_sender = self.event_sender.clone();
155 let agent_type_clone = agent_type.clone();
156
157 let _callback = Arc::new(move |ctx: SessionContext| {
158 let _ = event_sender.send(ExtractionEvent::Completed {
159 agent_type: agent_type_clone.clone(),
160 source: ExtractionSource::NativeHook("session_end".to_string()),
161 context: Box::new(ctx),
162 });
163 });
164
165 {
167 let mut hooks = self.hooks.write().await;
168 hooks.insert(agent_type.clone(), hook);
169 }
170
171 Ok(self)
172 }
173
174 pub fn subscribe(&self) -> broadcast::Receiver<ExtractionEvent> {
176 self.event_sender.subscribe()
177 }
178
179 pub async fn start(&self) -> Result<()> {
181 let mut active = self.active.write().await;
182 if *active {
183 return Ok(());
184 }
185 *active = true;
186 drop(active);
187
188 let agent_types: Vec<String> = {
190 let hooks = self.hooks.read().await;
191 hooks.keys().cloned().collect()
192 };
193
194 let agent_types_enum: Vec<AgentType> = agent_types
196 .iter()
197 .filter_map(|s| AgentType::parse(s))
198 .collect();
199
200 self.monitor.start_monitoring(agent_types_enum).await;
202
203 self.inactivity_detector
205 .start_monitoring(agent_types.clone())
206 .await;
207
208 self.signal_handler.install().await?;
210
211 let event_sender = self.event_sender.clone();
213 let stats = self.stats.clone();
214 let mut monitor_rx = self.monitor.subscribe();
215
216 tokio::spawn(async move {
217 while let Ok(event) = monitor_rx.recv().await {
218 match event {
219 MonitorEvent::SessionEnded {
220 agent_type,
221 reason: _,
222 ..
223 } => {
224 let _ = event_sender.send(ExtractionEvent::Started {
225 agent_type: agent_type.clone(),
226 source: ExtractionSource::ProcessMonitor,
227 });
228
229 let mut stats = stats.write().await;
230 stats.total_extractions += 1;
231 stats.monitor_extractions += 1;
232 }
233 MonitorEvent::InactivityDetected { agent_type, .. } => {
234 let _ = event_sender.send(ExtractionEvent::Started {
235 agent_type: agent_type.clone(),
236 source: ExtractionSource::InactivityTimeout,
237 });
238
239 let mut stats = stats.write().await;
240 stats.total_extractions += 1;
241 stats.inactivity_extractions += 1;
242 }
243 _ => {}
244 }
245 }
246 });
247
248 let _event_sender = self.event_sender.clone();
250 let stats = self.stats.clone();
251 let mut signal_rx = self.signal_handler.subscribe();
252
253 tokio::spawn(async move {
254 while let Ok(signal) = signal_rx.recv().await {
255 let _source = match signal {
256 SignalEvent::Interrupt => ExtractionSource::SignalHandler("SIGINT".to_string()),
257 SignalEvent::Terminate => {
258 ExtractionSource::SignalHandler("SIGTERM".to_string())
259 }
260 _ => continue,
261 };
262
263 let mut stats = stats.write().await;
264 stats.total_extractions += 1;
265 stats.signal_extractions += 1;
266 }
267 });
268
269 for agent_type in &agent_types {
271 self.buffer.start_buffering(agent_type).await?;
272 }
273
274 tracing::info!("Multi-layer extractor started");
275
276 Ok(())
277 }
278
279 pub async fn stop(&self) -> Result<()> {
281 let mut active = self.active.write().await;
282 *active = false;
283
284 self.monitor.stop_monitoring().await;
285 self.inactivity_detector.stop_monitoring().await;
286
287 self.buffer.flush_all().await?;
289
290 tracing::info!("Multi-layer extractor stopped");
291
292 Ok(())
293 }
294
295 pub async fn extract(&self, agent_type: &str) -> Result<SessionContext> {
297 let native_result = self.try_native_extraction(agent_type).await;
299
300 if let Ok(context) = native_result {
301 self.buffer
303 .buffer_context(agent_type, context.clone(), "extraction")
304 .await?;
305 return Ok(context);
306 }
307
308 if let Some(data) = self.buffer.recover_buffer(agent_type).await? {
310 let context = self.buffer_data_to_context(data);
311 let _ = self.event_sender.send(ExtractionEvent::BufferRecovered {
312 agent_type: agent_type.to_string(),
313 entries: context.insights.len(), });
315
316 let mut stats = self.stats.write().await;
317 stats.buffer_recoveries += 1;
318
319 return Ok(context);
320 }
321
322 Ok(SessionContext::new(agent_type)
324 .with_source("fallback")
325 .with_reliability(0.5))
326 }
327
328 async fn try_native_extraction(&self, agent_type: &str) -> Result<SessionContext> {
330 let hooks = self.hooks.read().await;
331
332 if let Some(hook) = hooks.get(agent_type) {
333 let activity = hook.detect_session_activity().await?;
335
336 if activity.is_active {
337 return hook.extract_session_context().await;
338 }
339 }
340
341 Err(HookError::SessionNotActive)
342 }
343
344 fn buffer_data_to_context(&self, data: BufferData) -> SessionContext {
346 let mut context = SessionContext::new(&data.agent_type)
347 .with_source("buffer_recovery")
348 .with_reliability(0.99);
349
350 for entry in data.entries {
351 context.insights.push(format!(
352 "[{}] {:?}",
353 entry.context_type,
354 entry.context.to_memory_content()
355 ));
356 }
357
358 context
359 }
360
361 pub async fn stats(&self) -> ExtractionStats {
363 self.stats.read().await.clone()
364 }
365
366 pub async fn is_active(&self) -> bool {
368 *self.active.read().await
369 }
370
371 pub async fn trigger_extraction(&self, agent_type: &str) -> Result<HookResult> {
373 let _ = self.event_sender.send(ExtractionEvent::Started {
374 agent_type: agent_type.to_string(),
375 source: ExtractionSource::Manual,
376 });
377
378 match self.extract(agent_type).await {
379 Ok(context) => {
380 let _ = self.event_sender.send(ExtractionEvent::Completed {
381 agent_type: agent_type.to_string(),
382 source: ExtractionSource::Manual,
383 context: Box::new(context.clone()),
384 });
385
386 let mut stats = self.stats.write().await;
387 stats.total_extractions += 1;
388
389 Ok(HookResult::success_with_context(
390 agent_type,
391 ExtractionSource::Manual,
392 context,
393 ))
394 }
395 Err(e) => {
396 let _ = self.event_sender.send(ExtractionEvent::Failed {
397 agent_type: agent_type.to_string(),
398 source: ExtractionSource::Manual,
399 error: e.to_string(),
400 });
401
402 let mut stats = self.stats.write().await;
403 stats.total_extractions += 1;
404 stats.failed_extractions += 1;
405
406 Ok(HookResult::failure(
407 agent_type,
408 ExtractionSource::Manual,
409 e.to_string(),
410 ))
411 }
412 }
413 }
414
415 pub async fn check_for_recovery(&self) -> Result<Vec<(String, BufferData)>> {
417 let hooks = self.hooks.read().await;
418 let mut recovered = Vec::new();
419
420 for agent_type in hooks.keys() {
421 if let Some(data) = self.buffer.recover_buffer(agent_type).await? {
422 recovered.push((agent_type.clone(), data));
423 }
424 }
425
426 Ok(recovered)
427 }
428
429 pub async fn clear_buffer(&self, agent_type: &str) -> Result<()> {
431 self.buffer.clear_buffer(agent_type).await
432 }
433}
434
435impl Default for MultiLayerExtractor {
436 fn default() -> Self {
437 Self::new().expect("Failed to create extractor")
438 }
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed extractor has not been started.
    #[tokio::test]
    async fn test_extractor_new() {
        let fresh = MultiLayerExtractor::new().unwrap();
        let running = fresh.is_active().await;
        assert!(!running);
    }

    /// A fresh extractor reports empty statistics with a perfect success rate.
    #[tokio::test]
    async fn test_extractor_stats() {
        let extractor = MultiLayerExtractor::new().unwrap();
        let snapshot = extractor.stats().await;

        assert_eq!(snapshot.total_extractions, 0);
        assert_eq!(snapshot.success_rate(), 1.0);
    }

    /// Subscribing and then dropping the receiver must not panic.
    #[tokio::test]
    async fn test_extractor_subscribe() {
        let extractor = MultiLayerExtractor::new().unwrap();
        let _receiver = extractor.subscribe();
    }

    /// Success rate is 1.0 when empty and (total - failed) / total otherwise.
    #[test]
    fn test_extraction_stats_success_rate() {
        let empty = ExtractionStats::default();
        assert_eq!(empty.success_rate(), 1.0);

        let populated = ExtractionStats {
            total_extractions: 10,
            failed_extractions: 2,
            ..ExtractionStats::default()
        };
        assert!((populated.success_rate() - 0.8).abs() < 0.001);
    }
}