use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use super::WatchEvent;
use crate::error::SwarmError;
use crate::learn::{AlwaysTrigger, TrainTrigger, TriggerContext};
/// An asynchronous consumer of events.
///
/// Events are delivered one at a time through [`EventSink::process`]. That method borrows the
/// sink mutably, so implementors may keep state between events.
pub trait EventSink: Send {
    /// The type of event this sink accepts. It is required to be `Send`.
    type Event: Send;

    /// Handles a single event.
    ///
    /// The returned future is required to be `Send`.
    ///
    /// # Errors
    ///
    /// Implementations report failures as a [`SwarmError`].
    fn process(
        &mut self,
        event: Self::Event,
    ) -> impl core::future::Future<Output = Result<(), SwarmError>> + Send;
}
/// Event sink that runs offline learning whenever its [`TrainTrigger`] decides it should.
///
/// The sink counts every event it processes. It also remembers when, and at which event
/// count, the last successful training run was recorded, so that this history can be fed
/// back to the trigger.
pub struct LearningSink {
    /// Path handed to `LearningStore::new` for each training run. It is wrapped in an `Arc`
    /// so a cheap handle can be moved into the blocking task.
    learning_path: Arc<PathBuf>,
    /// Session limit forwarded to `run_offline_learning`.
    max_sessions: usize,
    /// Policy consulted on every event to decide whether to train.
    trigger: Arc<dyn TrainTrigger>,
    /// Total number of events processed so far.
    event_count: usize,
    /// Moment the last successful training run was recorded; `None` until one succeeds.
    last_train_at: Option<Instant>,
    /// Value of `event_count` at the last successful training run.
    last_train_count: usize,
}
impl LearningSink {
    /// Creates a sink that uses the default [`AlwaysTrigger`] policy and starts with zeroed
    /// counters.
    ///
    /// `learning_path` is handed to `LearningStore::new` on every training run.
    /// `max_sessions` is forwarded to `run_offline_learning`.
    pub fn new(learning_path: PathBuf, max_sessions: usize) -> Self {
        Self {
            learning_path: Arc::new(learning_path),
            max_sessions,
            trigger: Arc::new(AlwaysTrigger),
            event_count: 0,
            last_train_at: None,
            last_train_count: 0,
        }
    }

    /// Replaces the training trigger (builder style).
    #[must_use]
    pub fn with_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
        self.trigger = trigger;
        self
    }

    /// Returns the path used to open the learning store.
    pub fn learning_path(&self) -> &PathBuf {
        &self.learning_path
    }

    /// Returns the number of events processed so far, including those that did not trigger
    /// training.
    pub fn event_count(&self) -> usize {
        self.event_count
    }

    /// Returns the moment the last successful training run was recorded, or `None` if no run
    /// has succeeded yet.
    pub fn last_train_at(&self) -> Option<Instant> {
        self.last_train_at
    }

    /// Returns the value of [`event_count`](Self::event_count) at the last successful
    /// training run. This is `0` until a run succeeds.
    pub fn last_train_count(&self) -> usize {
        self.last_train_count
    }

    /// Asks the trigger whether the current event should start a training run.
    ///
    /// An unsuccessful trigger evaluation is treated as "do not train" (via
    /// `unwrap_or(false)`). A misbehaving trigger therefore skips learning instead of
    /// aborting event processing.
    fn should_train(&self) -> bool {
        let ctx =
            TriggerContext::with_count(self.event_count).last_train_count(self.last_train_count);
        self.trigger.should_train(&ctx).unwrap_or(false)
    }

    /// Records that a training run has just succeeded at the current event count.
    fn mark_trained(&mut self) {
        self.last_train_at = Some(Instant::now());
        self.last_train_count = self.event_count;
    }
}
impl EventSink for LearningSink {
    type Event = WatchEvent;

    /// Counts the event and, if the trigger fires, runs offline learning for the event's
    /// scenario on tokio's blocking thread pool.
    ///
    /// This is best-effort. Learning errors, panics and cancellations are logged, and the
    /// method always returns `Ok(())`. Training bookkeeping (`mark_trained`) is only updated
    /// after a successful run, so the trigger still sees the previous `last_train_count`
    /// after a failure.
    async fn process(&mut self, event: Self::Event) -> Result<(), SwarmError> {
        self.event_count += 1;

        if !self.should_train() {
            tracing::debug!(
                scenario = %event.scenario,
                event_count = self.event_count,
                trigger = self.trigger.name(),
                "Trigger not met, skipping learning"
            );
            return Ok(());
        }

        tracing::info!(
            scenario = %event.scenario,
            event_count = self.event_count,
            trigger = self.trigger.name(),
            "Trigger condition met, running offline learning"
        );

        // `LearningStore` is driven on the blocking pool, presumably because it performs
        // synchronous I/O. The closure needs owned (`'static`) inputs, hence the clones.
        let learning_path = Arc::clone(&self.learning_path);
        let scenario = event.scenario.clone();
        let max_sessions = self.max_sessions;
        let result = tokio::task::spawn_blocking(move || {
            use crate::learn::LearningStore;
            let store = LearningStore::new(&*learning_path)?;
            store.run_offline_learning(&scenario, max_sessions)
        })
        .await;

        match result {
            Ok(Ok(model)) => {
                tracing::info!(
                    scenario = %event.scenario,
                    sessions = model.analyzed_sessions,
                    "Offline learning completed"
                );
                self.mark_trained();
            }
            Ok(Err(e)) => {
                tracing::warn!(
                    scenario = %event.scenario,
                    error = %e,
                    "Offline learning failed"
                );
            }
            Err(e) if e.is_panic() => {
                tracing::error!(
                    scenario = %event.scenario,
                    error = %e,
                    "Blocking task panicked"
                );
            }
            // A JoinError that is not a panic is a cancellation, e.g. the runtime shutting
            // down before the blocking task ran. That is not a panic, so report it separately.
            Err(e) => {
                tracing::warn!(
                    scenario = %event.scenario,
                    error = %e,
                    "Blocking task was cancelled"
                );
            }
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test sink that only counts how many events it has received.
    #[derive(Default)]
    pub struct CountingSink {
        count: Arc<AtomicUsize>,
    }

    impl CountingSink {
        pub fn new() -> Self {
            Self::default()
        }

        pub fn count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    impl EventSink for CountingSink {
        type Event = WatchEvent;

        async fn process(&mut self, _event: Self::Event) -> Result<(), SwarmError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_counting_sink() {
        let mut sink = CountingSink::new();
        assert_eq!(sink.count(), 0);
        sink.process(WatchEvent::new("test".into())).await.unwrap();
        assert_eq!(sink.count(), 1);
        sink.process(WatchEvent::new("test2".into())).await.unwrap();
        assert_eq!(sink.count(), 2);
    }

    #[test]
    fn test_learning_sink_creation() {
        let sink = LearningSink::new(PathBuf::from("/tmp/test"), 20);
        assert_eq!(sink.learning_path().to_str().unwrap(), "/tmp/test");
        assert_eq!(sink.max_sessions, 20);
        // A fresh sink has seen no events and has never trained.
        assert_eq!(sink.event_count(), 0);
        assert!(sink.last_train_at.is_none());
        assert_eq!(sink.last_train_count, 0);
    }

    #[test]
    fn test_with_trigger_keeps_configuration() {
        let sink = LearningSink::new(PathBuf::from("/tmp/test"), 20)
            .with_trigger(Arc::new(AlwaysTrigger));
        assert_eq!(sink.learning_path(), &PathBuf::from("/tmp/test"));
        assert_eq!(sink.max_sessions, 20);
        assert_eq!(sink.event_count(), 0);
    }
}