mod applier;
mod processor;
mod sink;
mod subscriber;
pub use applier::{Applier, ApplierConfig, ApplierError, ApplyMode, ApplyResult};
pub use processor::{ProcessResult, Processor, ProcessorConfig, ProcessorError, ProcessorMode};
pub use sink::{DataSink, DataSinkError, DataSinkStats};
pub use subscriber::{ActionEventSubscriber, EventSubscriberConfig, LearningEventSubscriber};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::interval;
use crate::learn::learn_model::{LearnModel, WorkerDecisionSequenceLearn};
use crate::learn::lora::{
LoraTrainer, LoraTrainerConfig, ModelApplicator, NoOpApplicator, TrainedModel,
};
use crate::learn::record::{DependencyGraphRecord, LearnStatsRecord, Record};
use crate::learn::snapshot::LearningStore;
use crate::learn::store::{
EpisodeStore, FileEpisodeStore, FileRecordStore, InMemoryEpisodeStore, InMemoryRecordStore,
RecordStore, RecordStoreError, StoreError,
};
use crate::learn::trigger::{TrainTrigger, TriggerBuilder, TriggerContext};
use crate::learn::LearnStats;
use crate::util::epoch_millis;
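/// Errors surfaced by the learning daemon and the components it coordinates.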
#[derive(Debug)]
pub enum DaemonError {
Sink(DataSinkError),
Processor(ProcessorError),
Applier(ApplierError),
Io(std::io::Error),
Config(String),
Shutdown,
}
impl std::fmt::Display for DaemonError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Sink(e) => write!(f, "Sink error: {}", e),
Self::Processor(e) => write!(f, "Processor error: {}", e),
Self::Applier(e) => write!(f, "Applier error: {}", e),
Self::Io(e) => write!(f, "IO error: {}", e),
Self::Config(msg) => write!(f, "Config error: {}", msg),
Self::Shutdown => write!(f, "Daemon shutdown"),
}
}
}
impl std::error::Error for DaemonError {}
impl From<DataSinkError> for DaemonError {
fn from(e: DataSinkError) -> Self {
Self::Sink(e)
}
}
impl From<ProcessorError> for DaemonError {
fn from(e: ProcessorError) -> Self {
Self::Processor(e)
}
}
impl From<ApplierError> for DaemonError {
fn from(e: ApplierError) -> Self {
Self::Applier(e)
}
}
impl From<std::io::Error> for DaemonError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
impl From<RecordStoreError> for DaemonError {
fn from(e: RecordStoreError) -> Self {
Self::Sink(DataSinkError::RecordStore(e))
}
}
impl From<StoreError> for DaemonError {
fn from(e: StoreError) -> Self {
Self::Sink(DataSinkError::EpisodeStore(e))
}
}
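/// Configuration for `LearningDaemon`; constructed with `DaemonConfig::new`
/// and refined via the chainable setters.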
#[derive(Debug, Clone)]
pub struct DaemonConfig {
pub scenario: String,
pub data_dir: PathBuf,
pub check_interval: Duration,
pub processor_mode: ProcessorMode,
pub max_sessions: usize,
pub auto_apply: bool,
pub lora_config: Option<LoraTrainerConfig>,
}
impl DaemonConfig {
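/// Creates a configuration for `scenario` with defaults: the platform data
/// directory, a 10-second check interval, offline-only processing, at most
/// 20 sessions, and auto-apply disabled.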
pub fn new(scenario: impl Into<String>) -> Self {
Self {
scenario: scenario.into(),
data_dir: default_data_dir(),
check_interval: Duration::from_secs(10),
processor_mode: ProcessorMode::OfflineOnly,
max_sessions: 20,
auto_apply: false,
lora_config: None,
}
}
pub fn data_dir(mut self, path: impl Into<PathBuf>) -> Self {
self.data_dir = path.into();
self
}
pub fn check_interval(mut self, interval: Duration) -> Self {
self.check_interval = interval;
self
}
pub fn processor_mode(mut self, mode: ProcessorMode) -> Self {
self.processor_mode = mode;
self
}
pub fn max_sessions(mut self, n: usize) -> Self {
self.max_sessions = n;
self
}
pub fn auto_apply(mut self, enabled: bool) -> Self {
self.auto_apply = enabled;
self
}
pub fn with_lora(mut self, config: LoraTrainerConfig) -> Self {
self.lora_config = Some(config);
self
}
}
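/// Default storage location: the platform data directory (falling back to
/// the current directory) under `swarm-engine/learning`.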
fn default_data_dir() -> PathBuf {
dirs::data_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("swarm-engine")
.join("learning")
}
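/// Runtime counters reported by `LearningDaemon::stats`; timestamps are
/// epoch milliseconds.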
#[derive(Debug, Clone, Default)]
pub struct DaemonStats {
pub records_received: usize,
pub episodes_created: usize,
pub trainings_completed: usize,
pub models_applied: usize,
pub last_train_at: Option<u64>,
pub started_at: u64,
}
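/// Long-running learning loop: receives `Record` batches over an mpsc
/// channel, ingests them into the `DataSink`, and runs the `Processor`
/// whenever the configured `TrainTrigger` fires.
///
/// A sketch of the intended wiring (illustrative; `config`, `trigger`, and
/// `records` are assumed to exist):
///
/// ```ignore
/// let mut daemon = LearningDaemon::new(config, trigger)?;
/// let sender = daemon.record_sender();
/// tokio::spawn(async move { daemon.run().await });
/// sender.send(records).await?;
/// ```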
pub struct LearningDaemon {
config: DaemonConfig,
sink: DataSink,
trigger: Arc<dyn TrainTrigger>,
processor: Processor,
applier: Option<Applier>,
learning_store: LearningStore,
stats: DaemonStats,
last_train_count: usize,
record_rx: mpsc::Receiver<Vec<Record>>,
record_tx: mpsc::Sender<Vec<Record>>,
shutdown_rx: mpsc::Receiver<()>,
shutdown_tx: mpsc::Sender<()>,
}
impl LearningDaemon {
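/// Creates a daemon backed by in-memory stores; learned state is lost when
/// the daemon is dropped.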
pub fn new(config: DaemonConfig, trigger: Arc<dyn TrainTrigger>) -> Result<Self, DaemonError> {
let record_store: Arc<dyn RecordStore> = Arc::new(InMemoryRecordStore::new());
let episode_store: Arc<dyn EpisodeStore> = Arc::new(InMemoryEpisodeStore::new());
let learn_model: Arc<dyn LearnModel> = Arc::new(WorkerDecisionSequenceLearn::new());
Self::with_stores(config, trigger, record_store, episode_store, learn_model)
}
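/// Creates a daemon with file-backed stores under `config.data_dir`
/// (`records/` and `episodes/` subdirectories), creating the data directory
/// if it does not exist.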
pub fn with_file_stores(
config: DaemonConfig,
trigger: Arc<dyn TrainTrigger>,
) -> Result<Self, DaemonError> {
std::fs::create_dir_all(&config.data_dir)?;
let record_store: Arc<dyn RecordStore> =
Arc::new(FileRecordStore::new(config.data_dir.join("records"))?);
let episode_store: Arc<dyn EpisodeStore> =
Arc::new(FileEpisodeStore::new(config.data_dir.join("episodes"))?);
let learn_model: Arc<dyn LearnModel> = Arc::new(WorkerDecisionSequenceLearn::new());
Self::with_stores(config, trigger, record_store, episode_store, learn_model)
}
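/// Creates a daemon from explicitly supplied stores and learn model; this is
/// the constructor `DaemonBuilder::build` delegates to.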
pub fn with_stores(
config: DaemonConfig,
trigger: Arc<dyn TrainTrigger>,
record_store: Arc<dyn RecordStore>,
episode_store: Arc<dyn EpisodeStore>,
learn_model: Arc<dyn LearnModel>,
) -> Result<Self, DaemonError> {
let sink = DataSink::new(
record_store,
Arc::clone(&episode_store),
Arc::clone(&learn_model),
);
let processor_config = ProcessorConfig::new(&config.scenario)
.mode(config.processor_mode)
.max_sessions(config.max_sessions);
let mut processor = Processor::new(processor_config);
let learning_store = LearningStore::new(&config.data_dir)?;
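// The processor gets its own handle over the same learning-store directory.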
let learning_store_for_processor = LearningStore::new(&config.data_dir)?;
processor = processor.with_learning_store(learning_store_for_processor);
if let Some(lora_config) = &config.lora_config {
let trainer = LoraTrainer::new(lora_config.clone(), episode_store);
processor = processor
.with_lora_trainer(trainer)
.with_learn_model(learn_model);
}
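// auto_apply wires a NoOpApplicator as a placeholder here; a real applicator
// can be substituted via DaemonBuilder::applicator.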
let applier = if config.auto_apply {
let applier_config = ApplierConfig::default().auto_apply();
let applicator: Arc<dyn ModelApplicator> = Arc::new(NoOpApplicator::new());
Some(Applier::new(applier_config, applicator))
} else {
None
};
let (record_tx, record_rx) = mpsc::channel(1000);
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
Ok(Self {
config,
sink,
trigger,
processor,
applier,
learning_store,
stats: DaemonStats {
started_at: epoch_millis(),
..Default::default()
},
last_train_count: 0,
record_rx,
record_tx,
shutdown_rx,
shutdown_tx,
})
}
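/// Returns a clone of the channel sender used to feed record batches into
/// the daemon.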
pub fn record_sender(&self) -> mpsc::Sender<Vec<Record>> {
self.record_tx.clone()
}
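/// Returns a sender that requests a graceful shutdown when signaled.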
pub fn shutdown_sender(&self) -> mpsc::Sender<()> {
self.shutdown_tx.clone()
}
pub fn config(&self) -> &DaemonConfig {
&self.config
}
pub fn stats(&self) -> &DaemonStats {
&self.stats
}
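/// Runs the daemon until a shutdown signal arrives, ingesting record batches
/// and checking the trigger every `check_interval`. Queued records are
/// drained before returning.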
pub async fn run(&mut self) -> Result<(), DaemonError> {
tracing::info!(
scenario = %self.config.scenario,
data_dir = %self.config.data_dir.display(),
trigger = self.trigger.name(),
"Learning daemon started"
);
let mut check_interval = interval(self.config.check_interval);
loop {
tokio::select! {
_ = self.shutdown_rx.recv() => {
tracing::info!("Shutdown signal received, draining remaining records...");
tokio::time::sleep(Duration::from_millis(100)).await;
while let Ok(records) = self.record_rx.try_recv() {
if let Err(e) = self.handle_records(records).await {
tracing::warn!("Error processing records during shutdown: {}", e);
}
}
tracing::info!(
records_received = self.stats.records_received,
episodes_created = self.stats.episodes_created,
"Shutdown complete"
);
return Ok(());
}
Some(records) = self.record_rx.recv() => {
self.handle_records(records).await?;
}
_ = check_interval.tick() => {
self.check_and_train().await?;
}
}
}
}
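/// Ingests a batch of records: `LearnStats` and dependency-graph records are
/// persisted to the `LearningStore` before the whole batch is handed to the
/// sink.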
async fn handle_records(&mut self, records: Vec<Record>) -> Result<(), DaemonError> {
if records.is_empty() {
return Ok(());
}
let count = records.len();
for record in &records {
match record {
Record::LearnStats(stats_record) => {
self.save_stats_to_learning_store(stats_record);
}
Record::DependencyGraph(dep_graph_record) => {
self.save_dependency_graph_to_learning_store(dep_graph_record);
}
_ => {}
}
}
let episode_ids = self.sink.ingest(records)?;
self.stats.records_received += count;
self.stats.episodes_created += episode_ids.len();
tracing::debug!(
records = count,
episodes = episode_ids.len(),
"Processed records"
);
Ok(())
}
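/// Converts a `LearnStatsRecord` into a `LearningSnapshot` and persists it
/// as a new session in the `LearningStore`. Failures are logged, not
/// propagated.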
fn save_stats_to_learning_store(&self, stats_record: &LearnStatsRecord) {
use crate::learn::snapshot::{LearningSnapshot, SnapshotMetadata, SNAPSHOT_VERSION};
use crate::learn::{EpisodeTransitions, NgramStats, SelectionPerformance};
use crate::online_stats::ActionStats;
use std::collections::HashMap;
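// Malformed stats JSON is tolerated; the snapshot falls back to empty
// defaults in the else branch below.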
let learn_stats: Option<LearnStats> = serde_json::from_str(&stats_record.stats_json).ok();
let metadata = SnapshotMetadata {
scenario_name: Some(stats_record.scenario.clone()),
task_description: None,
created_at: stats_record.timestamp_ms / 1000,
session_count: 1,
total_episodes: 1,
total_actions: stats_record.total_actions as u32,
phase: None,
group_id: None,
};
let (
episode_transitions,
action_stats,
ngram_stats,
selection_performance,
contextual_stats,
) = if let Some(ref stats) = learn_stats {
let transitions = stats.episode_transitions.clone();
let ngram = stats.ngram_stats.clone();
let selection = stats.selection_performance.clone();
let mut ctx_stats: HashMap<(String, String), ActionStats> = HashMap::new();
for ((prev, action), ctx) in &stats.contextual_stats {
ctx_stats.insert(
(prev.clone(), action.clone()),
ActionStats {
visits: ctx.visits,
successes: ctx.successes,
failures: ctx.failures,
..Default::default()
},
);
}
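// Per-action stats are not reconstructed from the record; only contextual
// stats are rebuilt above.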
let action_stats: HashMap<String, ActionStats> = HashMap::new();
(transitions, action_stats, ngram, selection, ctx_stats)
} else {
(
EpisodeTransitions::default(),
HashMap::new(),
NgramStats::default(),
SelectionPerformance::default(),
HashMap::new(),
)
};
let snapshot = LearningSnapshot {
version: SNAPSHOT_VERSION,
metadata,
episode_transitions,
ngram_stats,
selection_performance,
contextual_stats,
action_stats,
};
match self
.learning_store
.save_session(&stats_record.scenario, &snapshot)
{
Ok(session_id) => {
tracing::info!(
scenario = %stats_record.scenario,
session_id = %session_id.0,
success = stats_record.is_success(),
"Saved session to LearningStore"
);
}
Err(e) => {
tracing::warn!(
scenario = %stats_record.scenario,
error = %e,
"Failed to save session to LearningStore"
);
}
}
}
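/// Merges an LLM-derived action order into the scenario's `OfflineModel`,
/// creating a fresh model if none exists yet.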
fn save_dependency_graph_to_learning_store(&self, record: &DependencyGraphRecord) {
use crate::learn::{ActionOrderSource, LearnedActionOrder};
let all_actions: Vec<String> = record
.discover_order
.iter()
.chain(record.not_discover_order.iter())
.cloned()
.collect();
let action_set_hash = LearnedActionOrder::compute_hash(&all_actions);
let action_order = LearnedActionOrder {
discover: record.discover_order.clone(),
not_discover: record.not_discover_order.clone(),
action_set_hash,
source: ActionOrderSource::Llm,
lora: None,
validated_accuracy: None,
};
let scenario = &self.config.scenario;
let model_result = self.learning_store.load_offline_model(scenario);
let updated_model = match model_result {
Ok(mut model) => {
model.action_order = Some(action_order.clone());
model
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
crate::learn::OfflineModel {
action_order: Some(action_order.clone()),
..Default::default()
}
}
Err(e) => {
tracing::warn!(
scenario = %scenario,
error = %e,
"Failed to load OfflineModel for action_order update"
);
return;
}
};
match self
.learning_store
.save_offline_model(scenario, &updated_model)
{
Ok(()) => {
tracing::info!(
scenario = %scenario,
discover = ?action_order.discover,
not_discover = ?action_order.not_discover,
action_set_hash = action_order.action_set_hash,
"Saved action_order to OfflineModel"
);
}
Err(e) => {
tracing::warn!(
scenario = %scenario,
error = %e,
"Failed to save action_order to OfflineModel"
);
}
}
}
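/// Checks the trigger against the current episode count and runs one
/// training cycle if it fires; trigger errors are treated as "do not train".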
async fn check_and_train(&mut self) -> Result<(), DaemonError> {
let current_count = self.sink.episode_count();
let ctx = TriggerContext::with_count(current_count)
.last_train_at(self.stats.last_train_at.unwrap_or(0))
.last_train_count(self.last_train_count);
if !self.trigger.should_train(&ctx).unwrap_or(false) {
return Ok(());
}
tracing::info!(
episode_count = current_count,
trigger = self.trigger.name(),
"Trigger fired, starting learning"
);
let result = self
.processor
.run(self.sink.episode_store().as_ref())
.await?;
self.stats.trainings_completed += 1;
self.stats.last_train_at = Some(epoch_millis());
self.last_train_count = current_count;
if let Some(applier) = &mut self.applier {
if let Some(model) = result.lora_model() {
let apply_result = applier.apply(model).await?;
if apply_result.is_applied() {
self.stats.models_applied += 1;
}
}
}
tracing::info!(
trainings = self.stats.trainings_completed,
models_applied = self.stats.models_applied,
"Learning cycle completed"
);
Ok(())
}
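/// Runs a training cycle immediately, bypassing the trigger.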
pub async fn train_now(&mut self) -> Result<ProcessResult, DaemonError> {
tracing::info!("Manual training triggered");
let result = self
.processor
.run(self.sink.episode_store().as_ref())
.await?;
self.stats.trainings_completed += 1;
self.stats.last_train_at = Some(epoch_millis());
self.last_train_count = self.sink.episode_count();
Ok(result)
}
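/// Applies a trained model through the configured applier; fails with
/// `DaemonError::Config` if no applier was set up.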
pub async fn apply_model(&mut self, model: &TrainedModel) -> Result<ApplyResult, DaemonError> {
let applier = self
.applier
.as_mut()
.ok_or_else(|| DaemonError::Config("Applier not configured".into()))?;
let result = applier.apply_now(model).await?;
if result.is_applied() {
self.stats.models_applied += 1;
}
Ok(result)
}
}
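/// Fluent builder for `LearningDaemon`. Components left unset fall back to
/// the default watch trigger, in-memory stores, and
/// `WorkerDecisionSequenceLearn`.
///
/// A minimal usage sketch (the scenario name and path are illustrative):
///
/// ```ignore
/// let daemon = DaemonBuilder::new("my-scenario")
///     .data_dir("/tmp/learning")
///     .processor_mode(ProcessorMode::OfflineOnly)
///     .build()?;
/// ```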
pub struct DaemonBuilder {
config: DaemonConfig,
trigger: Option<Arc<dyn TrainTrigger>>,
record_store: Option<Arc<dyn RecordStore>>,
episode_store: Option<Arc<dyn EpisodeStore>>,
learn_model: Option<Arc<dyn LearnModel>>,
applicator: Option<Arc<dyn ModelApplicator>>,
}
impl DaemonBuilder {
pub fn new(scenario: impl Into<String>) -> Self {
Self {
config: DaemonConfig::new(scenario),
trigger: None,
record_store: None,
episode_store: None,
learn_model: None,
applicator: None,
}
}
pub fn data_dir(mut self, path: impl Into<PathBuf>) -> Self {
self.config.data_dir = path.into();
self
}
pub fn trigger(mut self, trigger: Arc<dyn TrainTrigger>) -> Self {
self.trigger = Some(trigger);
self
}
pub fn processor_mode(mut self, mode: ProcessorMode) -> Self {
self.config.processor_mode = mode;
self
}
pub fn auto_apply(mut self) -> Self {
self.config.auto_apply = true;
self
}
pub fn record_store(mut self, store: Arc<dyn RecordStore>) -> Self {
self.record_store = Some(store);
self
}
pub fn episode_store(mut self, store: Arc<dyn EpisodeStore>) -> Self {
self.episode_store = Some(store);
self
}
pub fn learn_model(mut self, model: Arc<dyn LearnModel>) -> Self {
self.learn_model = Some(model);
self
}
pub fn applicator(mut self, applicator: Arc<dyn ModelApplicator>) -> Self {
self.applicator = Some(applicator);
self
}
pub fn with_lora(mut self, config: LoraTrainerConfig) -> Self {
self.config.lora_config = Some(config);
self
}
pub fn build(self) -> Result<LearningDaemon, DaemonError> {
let trigger = self
.trigger
.unwrap_or_else(|| TriggerBuilder::default_watch());
let record_store = self
.record_store
.unwrap_or_else(|| Arc::new(InMemoryRecordStore::new()));
let episode_store = self
.episode_store
.unwrap_or_else(|| Arc::new(InMemoryEpisodeStore::new()));
let learn_model = self
.learn_model
.unwrap_or_else(|| Arc::new(WorkerDecisionSequenceLearn::new()));
let mut daemon = LearningDaemon::with_stores(
self.config,
trigger,
record_store,
episode_store,
learn_model,
)?;
// Install a caller-supplied applicator; with_stores only wires up a
// NoOpApplicator when auto_apply is set, so the `applicator()` setter
// would otherwise have no effect.
if let Some(applicator) = self.applicator {
let applier_config = if daemon.config.auto_apply {
ApplierConfig::default().auto_apply()
} else {
ApplierConfig::default()
};
daemon.applier = Some(Applier::new(applier_config, applicator));
}
Ok(daemon)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::events::{ActionContext, ActionEventBuilder, ActionEventResult};
use crate::learn::trigger::AlwaysTrigger;
use crate::types::WorkerId;
fn make_test_records(count: usize) -> Vec<Record> {
(0..count)
.map(|i| {
let event = ActionEventBuilder::new(i as u64, WorkerId(0), format!("Action{}", i))
.result(ActionEventResult::success())
.duration(Duration::from_millis(10))
.context(ActionContext::new())
.build();
Record::from(&event)
})
.collect()
}
#[test]
fn test_daemon_config_builder() {
let config = DaemonConfig::new("test")
.data_dir("/tmp/test")
.check_interval(Duration::from_secs(30))
.processor_mode(ProcessorMode::Full)
.auto_apply(true);
assert_eq!(config.scenario, "test");
assert_eq!(config.data_dir, PathBuf::from("/tmp/test"));
assert_eq!(config.check_interval, Duration::from_secs(30));
assert_eq!(config.processor_mode, ProcessorMode::Full);
assert!(config.auto_apply);
}
#[tokio::test]
async fn test_daemon_creation() {
let config = DaemonConfig::new("test");
let trigger = TriggerBuilder::never();
let daemon = LearningDaemon::new(config, trigger).unwrap();
assert_eq!(daemon.config().scenario, "test");
assert_eq!(daemon.stats().records_received, 0);
}
#[tokio::test]
async fn test_daemon_record_ingestion() {
let config = DaemonConfig::new("test");
let trigger = TriggerBuilder::never();
let mut daemon = LearningDaemon::new(config, trigger).unwrap();
let sender = daemon.record_sender();
let records = make_test_records(5);
sender.send(records).await.unwrap();
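// The queued batch above is never consumed because run() is not started;
// ingestion is exercised directly via handle_records instead.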
daemon.handle_records(make_test_records(3)).await.unwrap();
assert_eq!(daemon.stats().records_received, 3);
}
#[tokio::test]
async fn test_daemon_builder() {
let daemon = DaemonBuilder::new("test-scenario")
.data_dir("/tmp/test")
.trigger(Arc::new(AlwaysTrigger))
.processor_mode(ProcessorMode::OfflineOnly)
.build()
.unwrap();
assert_eq!(daemon.config().scenario, "test-scenario");
}
}