use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use crate::base::{AgentHook, HookResult};
use crate::buffer::{BufferData, PersistentBuffer};
use crate::detector::InactivityDetector;
use crate::error::{HookError, Result};
use crate::monitor::{MonitorEvent, SessionMonitor};
use crate::session::SessionContext;
use crate::signal::{SignalEvent, SignalHandler};
use crate::types::{AgentType, ExtractionSource};
/// Running counters describing how extractions were triggered and how many
/// of them failed.
///
/// All fields are public, so the counters are not guaranteed to be mutually
/// consistent; [`ExtractionStats::success_rate`] tolerates that.
#[derive(Debug, Clone, Default)]
pub struct ExtractionStats {
    /// Total number of extractions counted, regardless of trigger.
    pub total_extractions: u64,
    /// Extractions attributed to a native agent hook.
    pub native_extractions: u64,
    /// Extractions counted when the session monitor reports a session end.
    pub monitor_extractions: u64,
    /// Extractions counted when inactivity is detected.
    pub inactivity_extractions: u64,
    /// Number of times a context was rebuilt from the persistent buffer.
    pub buffer_recoveries: u64,
    /// Extractions counted for interrupt/terminate signals.
    pub signal_extractions: u64,
    /// Extractions that ended in an error.
    pub failed_extractions: u64,
}

impl ExtractionStats {
    /// Returns the fraction of extractions that did not fail, in `0.0..=1.0`.
    ///
    /// With no extractions recorded the rate is `1.0`. If `failed_extractions`
    /// exceeds `total_extractions` the rate is `0.0`; the subtraction
    /// saturates instead of underflowing.
    pub fn success_rate(&self) -> f32 {
        if self.total_extractions == 0 {
            return 1.0;
        }
        // Saturate: a plain `-` would panic (debug) or wrap (release) when
        // the public counters are out of sync.
        let successful = self
            .total_extractions
            .saturating_sub(self.failed_extractions);
        successful as f32 / self.total_extractions as f32
    }
}
/// Coordinates several extraction mechanisms (native hooks, session
/// monitoring, inactivity detection, signal handling and a persistent
/// buffer) and publishes [`ExtractionEvent`]s on a broadcast channel.
pub struct MultiLayerExtractor {
// Registered hooks, keyed by the string form of the hook's agent type.
hooks: Arc<RwLock<HashMap<String, Box<dyn AgentHook>>>>,
// Buffer used to store extracted contexts and to recover them later.
buffer: PersistentBuffer,
// Source of `MonitorEvent`s that are forwarded while the extractor runs.
monitor: SessionMonitor,
// Started and stopped together with `monitor`.
inactivity_detector: InactivityDetector,
// Installed on start; its `SignalEvent`s are counted in `stats`.
signal_handler: SignalHandler,
// Sending half of the `ExtractionEvent` broadcast channel.
event_sender: broadcast::Sender<ExtractionEvent>,
// Shared counters, also updated from spawned tasks.
stats: Arc<RwLock<ExtractionStats>>,
// `true` between a successful `start` and the next `stop`.
active: Arc<RwLock<bool>>,
}
/// Events published on the extractor's broadcast channel.
#[derive(Debug, Clone)]
pub enum ExtractionEvent {
/// An extraction was triggered for `agent_type` by `source`.
Started {
agent_type: String,
source: ExtractionSource,
},
/// An extraction finished and produced `context` (boxed).
Completed {
agent_type: String,
source: ExtractionSource,
context: Box<SessionContext>,
},
/// An extraction failed; `error` is the error rendered as a string.
Failed {
agent_type: String,
source: ExtractionSource,
error: String,
},
/// A context was rebuilt from the persistent buffer; `entries` is the
/// number of insights in the rebuilt context.
BufferRecovered { agent_type: String, entries: usize },
}
impl MultiLayerExtractor {
pub fn new() -> Result<Self> {
let buffer = PersistentBuffer::new(None)?;
let (event_sender, _) = broadcast::channel(100);
Ok(Self {
hooks: Arc::new(RwLock::new(HashMap::new())),
buffer,
monitor: SessionMonitor::new(),
inactivity_detector: InactivityDetector::new(),
signal_handler: SignalHandler::new(),
event_sender,
stats: Arc::new(RwLock::new(ExtractionStats::default())),
active: Arc::new(RwLock::new(false)),
})
}
pub async fn with_hook(self, hook: Box<dyn AgentHook>) -> Result<Self> {
let agent_type = hook.agent_type().to_string();
let event_sender = self.event_sender.clone();
let agent_type_clone = agent_type.clone();
let _callback = Arc::new(move |ctx: SessionContext| {
let _ = event_sender.send(ExtractionEvent::Completed {
agent_type: agent_type_clone.clone(),
source: ExtractionSource::NativeHook("session_end".to_string()),
context: Box::new(ctx),
});
});
{
let mut hooks = self.hooks.write().await;
hooks.insert(agent_type.clone(), hook);
}
Ok(self)
}
/// Returns a new receiver for the extractor's [`ExtractionEvent`] channel.
pub fn subscribe(&self) -> broadcast::Receiver<ExtractionEvent> {
    let sender = &self.event_sender;
    sender.subscribe()
}
pub async fn start(&self) -> Result<()> {
let mut active = self.active.write().await;
if *active {
return Ok(());
}
*active = true;
drop(active);
let agent_types: Vec<String> = {
let hooks = self.hooks.read().await;
hooks.keys().cloned().collect()
};
let agent_types_enum: Vec<AgentType> = agent_types
.iter()
.filter_map(|s| AgentType::parse(s))
.collect();
self.monitor.start_monitoring(agent_types_enum).await;
self.inactivity_detector
.start_monitoring(agent_types.clone())
.await;
self.signal_handler.install().await?;
let event_sender = self.event_sender.clone();
let stats = self.stats.clone();
let mut monitor_rx = self.monitor.subscribe();
tokio::spawn(async move {
while let Ok(event) = monitor_rx.recv().await {
match event {
MonitorEvent::SessionEnded {
agent_type,
reason: _,
..
} => {
let _ = event_sender.send(ExtractionEvent::Started {
agent_type: agent_type.clone(),
source: ExtractionSource::ProcessMonitor,
});
let mut stats = stats.write().await;
stats.total_extractions += 1;
stats.monitor_extractions += 1;
}
MonitorEvent::InactivityDetected { agent_type, .. } => {
let _ = event_sender.send(ExtractionEvent::Started {
agent_type: agent_type.clone(),
source: ExtractionSource::InactivityTimeout,
});
let mut stats = stats.write().await;
stats.total_extractions += 1;
stats.inactivity_extractions += 1;
}
_ => {}
}
}
});
let _event_sender = self.event_sender.clone();
let stats = self.stats.clone();
let mut signal_rx = self.signal_handler.subscribe();
tokio::spawn(async move {
while let Ok(signal) = signal_rx.recv().await {
let _source = match signal {
SignalEvent::Interrupt => ExtractionSource::SignalHandler("SIGINT".to_string()),
SignalEvent::Terminate => {
ExtractionSource::SignalHandler("SIGTERM".to_string())
}
_ => continue,
};
let mut stats = stats.write().await;
stats.total_extractions += 1;
stats.signal_extractions += 1;
}
});
for agent_type in &agent_types {
self.buffer.start_buffering(agent_type).await?;
}
tracing::info!("Multi-layer extractor started");
Ok(())
}
/// Marks the extractor inactive, stops both monitors and flushes the
/// buffer.
///
/// The write guard on the active flag is held until this function returns.
///
/// # Errors
///
/// Propagates any error from `flush_all`; the extractor is already marked
/// inactive and the monitors are already stopped at that point.
pub async fn stop(&self) -> Result<()> {
    // Keep the guard alive for the whole shutdown sequence.
    let mut running = self.active.write().await;
    *running = false;

    self.monitor.stop_monitoring().await;
    self.inactivity_detector.stop_monitoring().await;

    let flushed = self.buffer.flush_all().await;
    flushed?;

    tracing::info!("Multi-layer extractor stopped");
    Ok(())
}
/// Produces a session context for `agent_type`, trying each layer in turn.
///
/// 1. Native hook extraction: on success the context is also written to
///    the buffer and `native_extractions` is incremented.
/// 2. Buffer recovery: a `BufferRecovered` event is sent and
///    `buffer_recoveries` is incremented.
/// 3. Otherwise an empty context with source `"fallback"` and reliability
///    `0.5` is returned.
///
/// Any error from the native layer is discarded and the next layer is
/// tried.
///
/// # Errors
///
/// Propagates errors from `buffer_context` and `recover_buffer`.
pub async fn extract(&self, agent_type: &str) -> Result<SessionContext> {
    let native_result = self.try_native_extraction(agent_type).await;
    if let Ok(context) = native_result {
        self.buffer
            .buffer_context(agent_type, context.clone(), "extraction")
            .await?;
        // Count the native success, mirroring the buffer-recovery path.
        self.stats.write().await.native_extractions += 1;
        return Ok(context);
    }

    if let Some(data) = self.buffer.recover_buffer(agent_type).await? {
        let context = self.buffer_data_to_context(data);
        let _ = self.event_sender.send(ExtractionEvent::BufferRecovered {
            agent_type: agent_type.to_string(),
            entries: context.insights.len(),
        });
        self.stats.write().await.buffer_recoveries += 1;
        return Ok(context);
    }

    Ok(SessionContext::new(agent_type)
        .with_source("fallback")
        .with_reliability(0.5))
}
/// Asks the hook registered for `agent_type` for the current session
/// context.
///
/// # Errors
///
/// Returns `HookError::SessionNotActive` when no hook is registered for
/// `agent_type` or the hook reports no active session. Errors from the
/// hook itself are propagated.
async fn try_native_extraction(&self, agent_type: &str) -> Result<SessionContext> {
    // The read guard stays alive while the hook is being queried.
    let registry = self.hooks.read().await;
    let hook = match registry.get(agent_type) {
        Some(found) => found,
        None => return Err(HookError::SessionNotActive),
    };

    let activity = hook.detect_session_activity().await?;
    if !activity.is_active {
        return Err(HookError::SessionNotActive);
    }
    hook.extract_session_context().await
}
/// Converts recovered buffer data into a `SessionContext` with source
/// `"buffer_recovery"` and reliability `0.99`.
///
/// Each buffered entry becomes one insight of the form
/// `[<context_type>] <debug rendering of the entry's memory content>`.
fn buffer_data_to_context(&self, data: BufferData) -> SessionContext {
    let mut recovered = SessionContext::new(&data.agent_type)
        .with_source("buffer_recovery")
        .with_reliability(0.99);

    let rendered = data.entries.into_iter().map(|entry| {
        format!(
            "[{}] {:?}",
            entry.context_type,
            entry.context.to_memory_content()
        )
    });
    recovered.insights.extend(rendered);

    recovered
}
pub async fn stats(&self) -> ExtractionStats {
self.stats.read().await.clone()
}
/// Reports whether the extractor is currently marked active.
pub async fn is_active(&self) -> bool {
    let flag = self.active.read().await;
    *flag
}
/// Runs a manual extraction for `agent_type` and reports it on the event
/// channel.
///
/// A `Started` event is sent first, followed by `Completed` or `Failed`.
/// `total_extractions` is incremented either way; `failed_extractions` is
/// incremented only on failure.
///
/// # Errors
///
/// Never returns `Err`: an extraction failure is reported through
/// `HookResult::failure` inside `Ok`.
pub async fn trigger_extraction(&self, agent_type: &str) -> Result<HookResult> {
    let agent = agent_type.to_string();
    let _ = self.event_sender.send(ExtractionEvent::Started {
        agent_type: agent.clone(),
        source: ExtractionSource::Manual,
    });

    let outcome = self.extract(agent_type).await;
    let hook_result = match outcome {
        Ok(context) => {
            let _ = self.event_sender.send(ExtractionEvent::Completed {
                agent_type: agent,
                source: ExtractionSource::Manual,
                context: Box::new(context.clone()),
            });
            self.stats.write().await.total_extractions += 1;
            HookResult::success_with_context(agent_type, ExtractionSource::Manual, context)
        }
        Err(err) => {
            // Render the error once and reuse it for the event and result.
            let message = err.to_string();
            let _ = self.event_sender.send(ExtractionEvent::Failed {
                agent_type: agent,
                source: ExtractionSource::Manual,
                error: message.clone(),
            });
            let mut counters = self.stats.write().await;
            counters.total_extractions += 1;
            counters.failed_extractions += 1;
            HookResult::failure(agent_type, ExtractionSource::Manual, message)
        }
    };
    Ok(hook_result)
}
/// Tries to recover buffered data for every registered agent type.
///
/// Returns one `(agent_type, data)` pair for each agent type whose buffer
/// yielded data; agent types with nothing to recover are omitted.
///
/// # Errors
///
/// Stops at, and propagates, the first error from `recover_buffer`.
pub async fn check_for_recovery(&self) -> Result<Vec<(String, BufferData)>> {
    // The read guard is held for the whole scan.
    let registry = self.hooks.read().await;
    let mut found = Vec::new();
    for name in registry.keys() {
        match self.buffer.recover_buffer(name).await? {
            Some(data) => found.push((name.clone(), data)),
            None => continue,
        }
    }
    Ok(found)
}
/// Clears the persistent buffer for `agent_type`.
///
/// # Errors
///
/// Propagates any error from the underlying buffer.
pub async fn clear_buffer(&self, agent_type: &str) -> Result<()> {
    let buffer = &self.buffer;
    buffer.clear_buffer(agent_type).await
}
}
impl Default for MultiLayerExtractor {
    /// Equivalent to `MultiLayerExtractor::new()`.
    ///
    /// # Panics
    ///
    /// Panics if `new` returns an error.
    fn default() -> Self {
        let built = Self::new();
        built.expect("Failed to create extractor")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built extractor must not report itself as active.
    #[tokio::test]
    async fn test_extractor_new() {
        let fresh = MultiLayerExtractor::new().unwrap();
        let running = fresh.is_active().await;
        assert!(!running);
    }

    /// A fresh extractor has no recorded extractions and a perfect rate.
    #[tokio::test]
    async fn test_extractor_stats() {
        let fresh = MultiLayerExtractor::new().unwrap();
        let snapshot = fresh.stats().await;
        assert_eq!(snapshot.total_extractions, 0);
        assert_eq!(snapshot.success_rate(), 1.0);
    }

    /// Subscribing hands out a receiver that can simply be dropped.
    #[tokio::test]
    async fn test_extractor_subscribe() {
        let fresh = MultiLayerExtractor::new().unwrap();
        drop(fresh.subscribe());
    }

    /// `success_rate` is 1.0 with no data and 0.8 for 2 failures out of 10.
    #[test]
    fn test_extraction_stats_success_rate() {
        let untouched = ExtractionStats::default();
        assert_eq!(untouched.success_rate(), 1.0);

        let mixed = ExtractionStats {
            total_extractions: 10,
            failed_extractions: 2,
            ..ExtractionStats::default()
        };
        assert!((mixed.success_rate() - 0.8).abs() < 0.001);
    }
}