/// Custom learning routine: receives the scenario name and the final swarm state.
///
/// The `'static` bound is the default for a boxed trait object in this position;
/// it is written out only to make the ownership requirement explicit.
type LearnCallback = Box<dyn Fn(&str, &SwarmState) + Send + Sync + 'static>;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::learn::{AlwaysTrigger, TrainTrigger, TriggerContext};
use crate::orchestrator::SwarmResult;
use crate::state::SwarmState;
/// Notable points in a swarm run's lifecycle.
#[derive(Debug, Clone)]
pub enum LifecycleEvent {
    /// Swarm start; carries the worker count.
    Started { worker_count: usize },
    /// Swarm termination; carries the run result and summary statistics.
    Terminated {
        result: SwarmResult,
        stats: TerminationStats,
    },
}
/// Summary counters captured when a swarm terminates.
///
/// `Default` yields all-zero counters with no scenario or group attached.
#[derive(Debug, Clone, Default)]
pub struct TerminationStats {
    /// Tick counter of the shared state at termination.
    pub total_ticks: u64,
    /// Total actions taken (taken from the stats' visit count in `from_state`).
    pub total_actions: u64,
    /// Number of actions that succeeded.
    pub successful_actions: u64,
    /// Number of actions that failed.
    pub failed_actions: u64,
    /// Optional scenario label.
    pub scenario: Option<String>,
    /// Optional group identifier.
    pub group_id: Option<String>,
}
impl TerminationStats {
    /// Builds stats from the shared tick counter and action statistics in `state`.
    ///
    /// `scenario` and `group_id` are left unset; attach them with
    /// [`with_scenario`](Self::with_scenario) and [`with_group_id`](Self::with_group_id).
    pub fn from_state(state: &SwarmState) -> Self {
        let shared = &state.shared;
        let counters = &shared.stats;
        Self {
            total_ticks: shared.tick,
            total_actions: counters.total_visits() as u64,
            successful_actions: counters.total_successes() as u64,
            failed_actions: counters.total_failures() as u64,
            scenario: None,
            group_id: None,
        }
    }

    /// Returns these stats labelled with `scenario`.
    pub fn with_scenario(self, scenario: impl Into<String>) -> Self {
        Self {
            scenario: Some(scenario.into()),
            ..self
        }
    }

    /// Returns these stats tagged with `group_id`.
    pub fn with_group_id(self, group_id: impl Into<String>) -> Self {
        Self {
            group_id: Some(group_id.into()),
            ..self
        }
    }
}
/// Callbacks for the start and termination of a swarm run.
///
/// Only `on_terminate` must be implemented. `on_start` defaults to a no-op and
/// `name` defaults to `"lifecycle_hook"`.
pub trait LifecycleHook: Send + Sync {
    /// Hook for swarm start, given the worker count. The default ignores it.
    fn on_start(&mut self, worker_count: usize) {
        let _ = worker_count;
    }

    /// Hook for swarm termination, given the final state and the run result.
    fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult);

    /// Identifier for this hook.
    fn name(&self) -> &str {
        "lifecycle_hook"
    }
}
/// Lifecycle hook that counts terminated evaluations and runs learning
/// whenever its [`TrainTrigger`] says so.
pub struct LearningLifecycleHook {
    /// Scenario name used for learning; `"unknown"` is substituted when unset.
    scenario: Option<String>,
    /// Location handed to `LearningStore::new` by the built-in learning routine.
    learning_path: PathBuf,
    /// Decides, from the current and last-learn counts, whether to learn.
    trigger: Arc<dyn TrainTrigger>,
    /// Evaluations counted so far; may be shared via `with_shared_eval_count`.
    eval_count: Arc<AtomicUsize>,
    /// Evaluation count recorded when learning last ran.
    last_learn_count: usize,
    /// Custom learning routine; when `None`, built-in offline learning is used.
    learn_callback: Option<LearnCallback>,
}
impl LearningLifecycleHook {
    /// Creates a hook that learns from data stored under `learning_path`.
    ///
    /// Defaults: an [`AlwaysTrigger`] (presumably fires on every check — confirm),
    /// a private evaluation counter starting at 0, no scenario name, and no
    /// custom callback, so the built-in offline learning is used.
    pub fn new(learning_path: impl Into<PathBuf>) -> Self {
        Self {
            learning_path: learning_path.into(),
            trigger: Arc::new(AlwaysTrigger),
            eval_count: Arc::new(AtomicUsize::new(0)),
            last_learn_count: 0,
            scenario: None,
            learn_callback: None,
        }
    }

    /// Replaces the trigger that decides when learning runs.
    pub fn with_trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
        self.trigger = trigger;
        self
    }

    /// Sets the scenario name passed to learning (defaults to `"unknown"`).
    pub fn with_scenario(mut self, scenario: impl Into<String>) -> Self {
        self.scenario = Some(scenario.into());
        self
    }

    /// Replaces the built-in offline learning with `callback`.
    ///
    /// The callback receives the scenario name (or `"unknown"`) and the final
    /// swarm state.
    pub fn with_learn_callback<F>(mut self, callback: F) -> Self
    where
        F: Fn(&str, &SwarmState) + Send + Sync + 'static,
    {
        self.learn_callback = Some(Box::new(callback));
        self
    }

    /// Returns a handle to the evaluation counter. Increments made through it
    /// are visible to this hook.
    pub fn eval_count_handle(&self) -> Arc<AtomicUsize> {
        Arc::clone(&self.eval_count)
    }

    /// Uses `count` as the evaluation counter instead of a private one.
    pub fn with_shared_eval_count(mut self, count: Arc<AtomicUsize>) -> Self {
        self.eval_count = count;
        self
    }

    /// Current value of the evaluation counter.
    pub fn current_eval_count(&self) -> usize {
        self.eval_count.load(Ordering::SeqCst)
    }

    /// Location handed to `LearningStore::new` by the built-in learning routine.
    pub fn learning_path(&self) -> &PathBuf {
        &self.learning_path
    }

    /// Asks the trigger whether learning should run now.
    ///
    /// The trigger sees the current evaluation count and the count recorded at
    /// the last learning run. If the trigger returns an error, the error is
    /// logged and treated as "do not learn" instead of being silently discarded.
    fn should_learn(&self) -> bool {
        let current = self.eval_count.load(Ordering::SeqCst);
        let ctx = TriggerContext::with_count(current).last_train_count(self.last_learn_count);
        match self.trigger.should_train(&ctx) {
            Ok(fire) => fire,
            Err(e) => {
                tracing::warn!(
                    trigger = self.trigger.name(),
                    eval_count = current,
                    error = ?e,
                    "Train trigger evaluation failed; skipping learning"
                );
                false
            }
        }
    }

    /// Runs learning, using the custom callback if one is set and the built-in
    /// offline learning otherwise, then records the last-learn count.
    fn run_learn(&mut self, state: &SwarmState) {
        // Snapshot the count before learning runs. With a shared counter
        // (`with_shared_eval_count`), other hooks may keep incrementing it while
        // a slow learn is in progress. Recording the post-learn value would drop
        // those evaluations from the next trigger interval.
        let learned_at = self.eval_count.load(Ordering::SeqCst);
        let scenario = self.scenario.as_deref().unwrap_or("unknown");
        tracing::info!(
            scenario = scenario,
            eval_count = learned_at,
            trigger = self.trigger.name(),
            "Running learning after trigger condition met"
        );
        match &self.learn_callback {
            Some(callback) => callback(scenario, state),
            None => self.run_default_learn(scenario),
        }
        self.last_learn_count = learned_at;
    }

    /// Built-in learning: opens a `LearningStore` at `learning_path` and runs
    /// offline learning for `scenario`.
    ///
    /// Errors are logged, not propagated: a store-creation failure at error
    /// level and a learning failure at warn level.
    fn run_default_learn(&self, scenario: &str) {
        use crate::learn::LearningStore;
        match LearningStore::new(&self.learning_path) {
            // NOTE(review): `20` is passed straight through; presumably it caps
            // the number of sessions analyzed — confirm against `run_offline_learning`.
            Ok(store) => match store.run_offline_learning(scenario, 20) {
                Ok(model) => {
                    tracing::info!(
                        scenario = scenario,
                        sessions = model.analyzed_sessions,
                        "Offline learning completed"
                    );
                }
                Err(e) => {
                    tracing::warn!(
                        scenario = scenario,
                        error = %e,
                        "Offline learning failed"
                    );
                }
            },
            Err(e) => {
                tracing::error!(
                    path = %self.learning_path.display(),
                    error = %e,
                    "Failed to create LearningStore"
                );
            }
        }
    }
}
impl LifecycleHook for LearningLifecycleHook {
    /// Logs the swarm start together with the evaluation count so far.
    fn on_start(&mut self, worker_count: usize) {
        tracing::debug!(
            worker_count = worker_count,
            eval_count = self.current_eval_count(),
            "LearningLifecycleHook: Swarm started"
        );
    }

    /// Counts one more evaluation, then asks the trigger whether to learn.
    ///
    /// If the trigger fires, learning runs synchronously inside this call.
    /// Otherwise the skip is logged at debug level.
    fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult) {
        let previous = self.eval_count.fetch_add(1, Ordering::SeqCst);
        let evaluations = previous + 1;
        tracing::debug!(
            eval_count = evaluations,
            total_ticks = result.total_ticks,
            trigger = self.trigger.name(),
            "LearningLifecycleHook: Swarm terminated"
        );

        if !self.should_learn() {
            tracing::debug!(
                eval_count = evaluations,
                last_learn = self.last_learn_count,
                trigger = self.trigger.name(),
                "Trigger not met, skipping learning"
            );
            return;
        }

        self.run_learn(state);
    }

    fn name(&self) -> &str {
        "learning_lifecycle_hook"
    }
}
/// Lifecycle hook that forwards every callback to a list of child hooks.
pub struct CompositeLifecycleHook {
    /// Child hooks, invoked in the order they were added.
    hooks: Vec<Box<dyn LifecycleHook>>,
}
impl CompositeLifecycleHook {
    /// Creates a composite with no child hooks.
    pub fn new() -> Self {
        Self { hooks: vec![] }
    }

    /// Appends `hook` after all previously added hooks.
    pub fn with_hook(self, hook: Box<dyn LifecycleHook>) -> Self {
        let mut hooks = self.hooks;
        hooks.push(hook);
        Self { hooks }
    }
}
impl Default for CompositeLifecycleHook {
    /// An empty composite, the same as `CompositeLifecycleHook::new()`.
    fn default() -> Self {
        Self { hooks: Vec::new() }
    }
}
impl LifecycleHook for CompositeLifecycleHook {
    /// Forwards `on_start` to each child hook in insertion order.
    fn on_start(&mut self, worker_count: usize) {
        self.hooks
            .iter_mut()
            .for_each(|child| child.on_start(worker_count));
    }

    /// Forwards `on_terminate` to each child hook in insertion order.
    fn on_terminate(&mut self, state: &SwarmState, result: &SwarmResult) {
        self.hooks
            .iter_mut()
            .for_each(|child| child.on_terminate(state, result));
    }

    fn name(&self) -> &str {
        "composite_lifecycle_hook"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicBool;

    /// Hook that raises a flag for each lifecycle callback it receives.
    struct TestHook {
        started: Arc<AtomicBool>,
        terminated: Arc<AtomicBool>,
    }

    impl TestHook {
        /// Returns the hook plus observer handles for its `started` and
        /// `terminated` flags.
        fn new() -> (Self, Arc<AtomicBool>, Arc<AtomicBool>) {
            let started_flag = Arc::new(AtomicBool::new(false));
            let terminated_flag = Arc::new(AtomicBool::new(false));
            let hook = Self {
                started: Arc::clone(&started_flag),
                terminated: Arc::clone(&terminated_flag),
            };
            (hook, started_flag, terminated_flag)
        }
    }

    impl LifecycleHook for TestHook {
        fn on_start(&mut self, _worker_count: usize) {
            self.started.store(true, Ordering::SeqCst);
        }

        fn on_terminate(&mut self, _state: &SwarmState, _result: &SwarmResult) {
            self.terminated.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_termination_stats_from_state() {
        let stats = TerminationStats::from_state(&SwarmState::new(4));
        assert_eq!(stats.total_ticks, 0);
        assert!(stats.scenario.is_none());
    }

    #[test]
    fn test_learning_lifecycle_hook_eval_count() {
        let hook = LearningLifecycleHook::new("/tmp/test");
        assert_eq!(hook.current_eval_count(), 0);

        // Increments made through a cloned handle must be visible to the hook.
        hook.eval_count_handle().fetch_add(5, Ordering::SeqCst);
        assert_eq!(hook.current_eval_count(), 5);
    }

    #[test]
    fn test_composite_hook() {
        let (first, first_started, first_terminated) = TestHook::new();
        let (second, second_started, second_terminated) = TestHook::new();
        let mut composite = CompositeLifecycleHook::new()
            .with_hook(Box::new(first))
            .with_hook(Box::new(second));

        composite.on_start(4);
        for flag in [&first_started, &second_started] {
            assert!(flag.load(Ordering::SeqCst));
        }

        let state = SwarmState::new(4);
        let result = SwarmResult {
            total_ticks: 10,
            total_duration: std::time::Duration::from_secs(1),
            completed: true,
        };
        composite.on_terminate(&state, &result);
        for flag in [&first_terminated, &second_terminated] {
            assert!(flag.load(Ordering::SeqCst));
        }
    }
}