mod dependency_graph;
mod dpo;
mod error;
mod worker_decision;
mod worker_task;
pub use dependency_graph::DependencyGraphLearnModel;
pub use dpo::{DpoConfig, DpoLearnModel, DpoPair};
pub use error::LearnError;
pub use worker_decision::WorkerDecisionSequenceLearn;
pub use worker_task::WorkerTaskLearn;
use crate::events::ActionEvent;
use super::episode::{Episode, EpisodeContext, Outcome};
use super::record::Record;
use super::training::TrainingData;
/// String identifiers for the system events declared by this module.
pub mod system_events {
    /// Identifier of the `tick_start` event.
    pub const TICK_START: &str = "tick_start";

    /// Identifier of the `tick_end` event.
    pub const TICK_END: &str = "tick_end";

    /// Identifier of the `done` event.
    pub const DONE: &str = "done";

    /// Every identifier declared in this module, in declaration order.
    pub const DEFAULT_SYSTEM_EVENTS: &[&str] = &[
        TICK_START,
        TICK_END,
        DONE,
    ];
}
/// Common interface of a learn model: it groups records into episodes,
/// evaluates episode contexts, and converts episodes into training data.
///
/// Implementors are required to be `Send + Sync`.
pub trait LearnModel: Send + Sync {
    /// Returns the name of this model.
    fn name(&self) -> &str;

    /// Returns a string describing the objective of this model.
    fn objective(&self) -> &str;

    /// Builds the episodes this model derives from `records`.
    fn build_episodes(&self, records: &[Record]) -> Vec<Episode>;

    /// Evaluates an episode context and returns its [`Outcome`].
    fn evaluate(&self, context: &EpisodeContext) -> Outcome;

    /// Converts a single episode into training data.
    ///
    /// # Errors
    ///
    /// Returns a [`LearnError`] when the implementor cannot convert `episode`.
    fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError>;

    /// Converts each episode with [`convert`](Self::convert) and returns the
    /// successful results in input order.
    ///
    /// Episodes whose conversion returns `Err` are skipped without being
    /// reported, so the output may be shorter than `episodes`.
    fn convert_batch(&self, episodes: &[Episode]) -> Vec<TrainingData> {
        let mut converted = Vec::new();
        for episode in episodes {
            // Best-effort: a failed conversion is dropped, not propagated.
            if let Ok(data) = self.convert(episode) {
                converted.push(data);
            }
        }
        converted
    }

    /// Wraps every action event in a [`Record`] (via `Record::from`) and
    /// forwards the resulting records to [`build_episodes`](Self::build_episodes).
    fn build_episodes_from_actions(&self, actions: &[ActionEvent]) -> Vec<Episode> {
        let mut wrapped: Vec<Record> = Vec::with_capacity(actions.len());
        wrapped.extend(actions.iter().map(Record::from));
        self.build_episodes(&wrapped)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionContext, ActionEventBuilder, ActionEventResult};
    use crate::learn::record::RecordStream;
    use crate::types::WorkerId;
    use std::time::Duration;

    /// Builds an action event for `worker_id` at `tick` with a fixed 10 ms
    /// duration, a fresh context, and a success or `"error"` failure result.
    fn make_action(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
        let outcome = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("error")
        };
        let builder = ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .result(outcome)
            .duration(Duration::from_millis(10))
            .context(ActionContext::new());
        builder.build()
    }

    /// Wraps each action event in a `Record`, preserving order.
    fn make_records(actions: &[ActionEvent]) -> Vec<Record> {
        let mut records = Vec::with_capacity(actions.len());
        for event in actions {
            records.push(Record::from(event));
        }
        records
    }

    /// A record built from an action event reports itself as an action
    /// (not an LLM record) and exposes the worker id.
    #[test]
    fn test_record_accessors() {
        let event = make_action(1, 5, "CheckStatus", true);
        let record: Record = (&event).into();

        assert!(record.is_action(), "record should be an action");
        assert!(!record.is_llm(), "record should not be an LLM record");
        assert_eq!(record.worker_id(), Some(5));
        assert!(record.as_action().is_some(), "as_action should yield a value");
        assert!(record.as_llm().is_none(), "as_llm should yield nothing");
    }

    /// Four records alternating between two workers are split into two
    /// groups of two.
    #[test]
    fn test_record_stream_group_by_worker() {
        // (tick, worker id, action name) for each successful action.
        let actions: Vec<ActionEvent> = [(1, 0, "A"), (2, 1, "B"), (3, 0, "C"), (4, 1, "D")]
            .iter()
            .map(|&(tick, worker, name)| make_action(tick, worker, name, true))
            .collect();
        let records = make_records(&actions);

        let by_worker = RecordStream::new(&records).group_by_worker();

        assert_eq!(by_worker.len(), 2);
        let first_count = by_worker.get(&0).map(|group| group.len());
        let second_count = by_worker.get(&1).map(|group| group.len());
        assert_eq!(first_count, Some(2));
        assert_eq!(second_count, Some(2));
    }
}