//! swarm-engine-core 0.1.6
//!
//! Core types and orchestration for SwarmEngine.
//! Event sink trait and implementations.

use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;

use super::WatchEvent;
use crate::error::SwarmError;
use crate::learn::{AlwaysTrigger, TrainTrigger, TriggerContext};

/// Trait for event sinks (terminal processors).
/// Trait for event sinks (terminal processors).
///
/// A sink is the terminal consumer of a pipeline: it receives events and
/// performs some final action (e.g. [`LearningSink`] triggers offline
/// learning). The `Send` bounds let sinks and their futures be driven from
/// a multi-threaded async runtime.
pub trait EventSink: Send {
    /// Event type consumed by this sink.
    type Event: Send;

    /// Process a single event.
    ///
    /// Declared as a return-position `impl Future` (rather than `async fn`)
    /// so the `+ Send` bound on the future can be stated explicitly;
    /// implementors may still write `async fn process(..)`.
    ///
    /// # Errors
    /// Returns a [`SwarmError`] if the sink fails to handle the event.
    fn process(
        &mut self,
        event: Self::Event,
    ) -> impl std::future::Future<Output = Result<(), SwarmError>> + Send;
}

/// Learning sink - triggers offline learning when events arrive.
///
/// Uses `spawn_blocking` internally to run synchronous `LearningStore::run_offline_learning`
/// without blocking the async runtime.
///
/// ## Trigger Integration
///
/// By default, learning runs on every event (AlwaysTrigger).
/// Use `with_trigger()` to customize when learning runs:
///
/// ```ignore
/// let sink = LearningSink::new(path, 20)
///     .with_trigger(TriggerBuilder::every_n_episodes(10));
/// ```
pub struct LearningSink {
    /// Learning data directory; wrapped in `Arc` so it can be cheaply
    /// cloned into the `spawn_blocking` closure on each training run.
    learning_path: Arc<PathBuf>,
    /// Maximum number of sessions passed to `run_offline_learning`.
    max_sessions: usize,
    /// Trigger for deciding when to run learning
    trigger: Arc<dyn TrainTrigger>,
    /// Total number of events received (monotonically increasing; never
    /// reset — the delta from `last_train_count` is what count-based
    /// triggers effectively see)
    event_count: usize,
    /// Timestamp of last successful training.
    /// NOTE(review): written by `mark_trained` but never read by
    /// `should_train` — time-based triggers are not fully supported here.
    last_train_at: Option<Instant>,
    /// Value of `event_count` at the last successful training
    last_train_count: usize,
}

impl LearningSink {
    /// Create a new learning sink.
    ///
    /// The default trigger is `AlwaysTrigger`, i.e. learning runs on
    /// every event.
    ///
    /// # Arguments
    /// * `learning_path` - Path to learning data directory
    /// * `max_sessions` - Maximum sessions to analyze
    pub fn new(learning_path: PathBuf, max_sessions: usize) -> Self {
        Self {
            event_count: 0,
            last_train_count: 0,
            last_train_at: None,
            trigger: Arc::new(AlwaysTrigger),
            max_sessions,
            learning_path: Arc::new(learning_path),
        }
    }

    /// Replace the trigger that decides when learning runs.
    ///
    /// # Example
    /// ```ignore
    /// use swarm_engine_core::learn::TriggerBuilder;
    ///
    /// // Run learning every 10 events
    /// let sink = LearningSink::new(path, 20)
    ///     .with_trigger(TriggerBuilder::every_n_episodes(10));
    ///
    /// // Run learning every 5 minutes OR when 50 events accumulated
    /// let sink = LearningSink::new(path, 20)
    ///     .with_trigger(Arc::new(OrTrigger::new(vec![
    ///         TriggerBuilder::every_minutes(5),
    ///         TriggerBuilder::every_n_episodes(50),
    ///     ])));
    /// ```
    pub fn with_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
        self.trigger = trigger;
        self
    }

    /// Path to the learning data directory.
    pub fn learning_path(&self) -> &PathBuf {
        &self.learning_path
    }

    /// Total number of events received so far.
    pub fn event_count(&self) -> usize {
        self.event_count
    }

    /// Ask the configured trigger whether training should run now.
    ///
    /// Note: `TimeTrigger` uses Unix timestamps (ms), but only an `Instant`
    /// is tracked internally, so time-based triggers are not fully supported
    /// via `LearningSink` (full support requires `LearnProcess` with
    /// `EpisodeStore`). `CountTrigger` and `AlwaysTrigger` work correctly.
    fn should_train(&self) -> bool {
        let ctx = TriggerContext::with_count(self.event_count)
            .last_train_count(self.last_train_count);

        // A trigger error is treated as "don't train".
        matches!(self.trigger.should_train(&ctx), Ok(true))
    }

    /// Record that a training run just completed successfully.
    fn mark_trained(&mut self) {
        self.last_train_count = self.event_count;
        self.last_train_at = Some(Instant::now());
    }
}

impl EventSink for LearningSink {
    type Event = WatchEvent;

    /// Count the event, consult the trigger, and — when it fires — run
    /// offline learning on a blocking thread.
    ///
    /// Learning failures and panics inside the blocking task are logged
    /// but deliberately not propagated, so one bad run cannot stall the
    /// event pipeline.
    async fn process(&mut self, event: Self::Event) -> Result<(), SwarmError> {
        // Every incoming event counts toward the trigger condition.
        self.event_count += 1;

        if !self.should_train() {
            tracing::debug!(
                scenario = %event.scenario,
                event_count = self.event_count,
                trigger = self.trigger.name(),
                "Trigger not met, skipping learning"
            );
            return Ok(());
        }

        tracing::info!(
            scenario = %event.scenario,
            event_count = self.event_count,
            trigger = self.trigger.name(),
            "Trigger condition met, running offline learning"
        );

        let path = Arc::clone(&self.learning_path);
        let scenario = event.scenario.clone();
        let session_limit = self.max_sessions;

        // LearningStore is synchronous; keep it off the async worker threads.
        let outcome = tokio::task::spawn_blocking(move || {
            use crate::learn::LearningStore;

            LearningStore::new(&*path)?.run_offline_learning(&scenario, session_limit)
        })
        .await;

        match outcome {
            Ok(Ok(model)) => {
                tracing::info!(
                    scenario = %event.scenario,
                    sessions = model.analyzed_sessions,
                    "Offline learning completed"
                );
                // Only a successful run resets the trigger bookkeeping.
                self.mark_trained();
            }
            Ok(Err(e)) => {
                // Best-effort: log and keep processing later events.
                tracing::warn!(
                    scenario = %event.scenario,
                    error = %e,
                    "Offline learning failed"
                );
            }
            Err(e) => {
                tracing::error!(
                    scenario = %event.scenario,
                    error = %e,
                    "Blocking task panicked"
                );
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Counting sink for testing.
    ///
    /// `EventSink::process` takes `&mut self`, so the sink is exclusively
    /// borrowed whenever the counter changes; a plain `usize` suffices and
    /// the previous `Arc<AtomicUsize>` indirection was unnecessary.
    pub struct CountingSink {
        /// Number of events processed so far.
        count: usize,
    }

    impl CountingSink {
        /// Create a sink with a zeroed counter.
        pub fn new() -> Self {
            Self { count: 0 }
        }

        /// Number of events processed so far.
        pub fn count(&self) -> usize {
            self.count
        }
    }

    impl EventSink for CountingSink {
        type Event = WatchEvent;

        async fn process(&mut self, _event: Self::Event) -> Result<(), SwarmError> {
            self.count += 1;
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_counting_sink() {
        let mut sink = CountingSink::new();
        assert_eq!(sink.count(), 0);

        sink.process(WatchEvent::new("test".into())).await.unwrap();
        assert_eq!(sink.count(), 1);

        sink.process(WatchEvent::new("test2".into())).await.unwrap();
        assert_eq!(sink.count(), 2);
    }

    #[test]
    fn test_learning_sink_creation() {
        let sink = LearningSink::new(PathBuf::from("/tmp/test"), 20);
        assert_eq!(sink.learning_path().to_str().unwrap(), "/tmp/test");
    }
}