use std::sync::Arc;
use crate::learn::learn_model::LearnModel;
use crate::learn::lora::{LoraTrainer, LoraTrainerError, TrainedModel};
use crate::learn::offline::OfflineModel;
use crate::learn::snapshot::LearningStore;
use crate::learn::store::{EpisodeStore, StoreError};
/// Selects which learning pipelines a processor runs.
///
/// `OfflineOnly` is the default variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ProcessorMode {
    /// Run only the offline analysis.
    #[default]
    OfflineOnly,
    /// Run only LoRA training.
    LoraOnly,
    /// Run both the offline analysis and LoRA training.
    Full,
}
/// Parses a mode name, ignoring letter case.
///
/// Accepted spellings are `offline`/`offline_only`, `lora`/`lora_only` and `full`/`both`.
/// Any other input yields an `Err` whose message echoes the input exactly as given.
impl std::str::FromStr for ProcessorMode {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Every accepted spelling paired with the variant it selects.
        let aliases = [
            ("offline", Self::OfflineOnly),
            ("offline_only", Self::OfflineOnly),
            ("lora", Self::LoraOnly),
            ("lora_only", Self::LoraOnly),
            ("full", Self::Full),
            ("both", Self::Full),
        ];

        let normalized = s.to_lowercase();
        aliases
            .iter()
            .find(|(alias, _)| *alias == normalized.as_str())
            .map(|&(_, mode)| mode)
            .ok_or_else(|| format!("Unknown processor mode: {}", s))
    }
}
/// Renders the short lowercase name of the mode: `offline`, `lora` or `full`.
impl std::fmt::Display for ProcessorMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::OfflineOnly => "offline",
            Self::LoraOnly => "lora",
            Self::Full => "full",
        };
        // Written verbatim; width/alignment flags on the formatter are not applied.
        f.write_str(label)
    }
}
/// Outcome of a processing run, holding whichever models were produced.
#[derive(Debug)]
pub enum ProcessResult {
    /// Only an offline model was produced.
    Offline(OfflineModel),
    /// Only a LoRA-trained model was produced.
    Lora(TrainedModel),
    /// Both models were produced.
    Full {
        /// Model produced by the offline analysis.
        offline: OfflineModel,
        /// Model produced by LoRA training.
        lora: TrainedModel,
    },
}
impl ProcessResult {
    /// Returns the LoRA-trained model, or `None` when the result holds only an offline model.
    pub fn lora_model(&self) -> Option<&TrainedModel> {
        match self {
            Self::Lora(trained) | Self::Full { lora: trained, .. } => Some(trained),
            Self::Offline(_) => None,
        }
    }

    /// Returns the offline model, or `None` when the result holds only a LoRA model.
    pub fn offline_model(&self) -> Option<&OfflineModel> {
        match self {
            Self::Offline(analyzed) | Self::Full { offline: analyzed, .. } => Some(analyzed),
            Self::Lora(_) => None,
        }
    }
}
/// Errors a processor can report.
#[derive(Debug)]
pub enum ProcessorError {
    /// Wraps a [`StoreError`].
    Store(StoreError),
    /// Wraps a [`LoraTrainerError`].
    LoraTrainer(LoraTrainerError),
    /// Wraps a standard I/O error.
    Io(std::io::Error),
    /// There was not enough data to proceed; the message explains what was missing.
    InsufficientData(String),
    /// Any other failure, described by a free-form message.
    Other(String),
}
/// Renders a human-readable message: a variant-specific prefix followed by the wrapped
/// error or message. `Other` has no prefix and prints its message as-is.
impl std::fmt::Display for ProcessorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, cause): (&str, &dyn std::fmt::Display) = match self {
            Self::Store(err) => ("Store error: ", err),
            Self::LoraTrainer(err) => ("LoRA trainer error: ", err),
            Self::Io(err) => ("IO error: ", err),
            Self::InsufficientData(text) => ("Insufficient data: ", text),
            Self::Other(text) => ("", text),
        };
        write!(f, "{}{}", prefix, cause)
    }
}
// Marks `ProcessorError` as a standard error type. No trait methods are overridden, so
// the provided defaults apply; wrapped errors are surfaced through the `Display` text.
impl std::error::Error for ProcessorError {}
impl From<StoreError> for ProcessorError {
fn from(e: StoreError) -> Self {
Self::Store(e)
}
}
impl From<LoraTrainerError> for ProcessorError {
fn from(e: LoraTrainerError) -> Self {
Self::LoraTrainer(e)
}
}
impl From<std::io::Error> for ProcessorError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
/// Settings that control a processor run.
#[derive(Clone, Debug)]
pub struct ProcessorConfig {
    /// Which pipelines to run.
    pub mode: ProcessorMode,
    /// Name of the scenario to process.
    pub scenario: String,
    /// Session limit handed to the offline analysis.
    pub max_sessions: usize,
}
/// Defaults: offline-only mode, scenario `"default"`, and a limit of 20 sessions.
impl Default for ProcessorConfig {
    fn default() -> Self {
        let mode = ProcessorMode::OfflineOnly;
        let scenario = String::from("default");
        let max_sessions = 20;
        Self {
            mode,
            scenario,
            max_sessions,
        }
    }
}
impl ProcessorConfig {
    /// Creates a configuration for `scenario`, taking every other setting from
    /// [`ProcessorConfig::default`].
    #[must_use]
    pub fn new(scenario: impl Into<String>) -> Self {
        Self {
            scenario: scenario.into(),
            ..Default::default()
        }
    }

    /// Returns the configuration with its mode replaced by `mode`.
    ///
    /// The method consumes `self`, so the returned value must be used; calling it and
    /// dropping the result changes nothing.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn mode(mut self, mode: ProcessorMode) -> Self {
        self.mode = mode;
        self
    }

    /// Returns the configuration with its session limit replaced by `n`.
    ///
    /// The method consumes `self`, so the returned value must be used; calling it and
    /// dropping the result changes nothing.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn max_sessions(mut self, n: usize) -> Self {
        self.max_sessions = n;
        self
    }
}
/// Runs the learning pipelines selected by its configuration.
///
/// The optional components start out unset and are attached through the `with_*`
/// builder methods.
pub struct Processor {
    /// Settings for the run.
    config: ProcessorConfig,
    /// Store used by the offline analysis, if attached.
    learning_store: Option<LearningStore>,
    /// Trainer used for LoRA training, if attached.
    lora_trainer: Option<LoraTrainer>,
    /// Shared learn model handed to the trainer, if attached.
    learn_model: Option<Arc<dyn LearnModel>>,
}
impl Processor {
    /// Creates a processor with the given configuration and no components attached.
    ///
    /// Attach the components the configured mode needs with the `with_*` builders
    /// before calling [`Processor::run`].
    #[must_use]
    pub fn new(config: ProcessorConfig) -> Self {
        Self {
            config,
            learning_store: None,
            lora_trainer: None,
            learn_model: None,
        }
    }

    /// Attaches the learning store used by the offline analysis.
    #[must_use = "builder methods return the updated processor instead of mutating in place"]
    pub fn with_learning_store(mut self, store: LearningStore) -> Self {
        self.learning_store = Some(store);
        self
    }

    /// Attaches the trainer used for LoRA training.
    #[must_use = "builder methods return the updated processor instead of mutating in place"]
    pub fn with_lora_trainer(mut self, trainer: LoraTrainer) -> Self {
        self.lora_trainer = Some(trainer);
        self
    }

    /// Attaches the learn model handed to the trainer during LoRA training.
    #[must_use = "builder methods return the updated processor instead of mutating in place"]
    pub fn with_learn_model(mut self, model: Arc<dyn LearnModel>) -> Self {
        self.learn_model = Some(model);
        self
    }

    /// Returns the configuration this processor was created with.
    #[must_use]
    pub fn config(&self) -> &ProcessorConfig {
        &self.config
    }

    /// Runs the pipelines selected by the configured mode and returns their models.
    ///
    /// In `Full` mode the offline analysis runs first and LoRA training second, so a
    /// missing LoRA component is only reported after the offline analysis has finished.
    ///
    /// # Errors
    ///
    /// - [`ProcessorError::Other`] when a component required by the mode is not attached.
    /// - [`ProcessorError::InsufficientData`] when LoRA training is requested and
    ///   `episode_store` reports zero episodes.
    /// - Any error raised by the learning store, the episode store or the trainer,
    ///   converted into a [`ProcessorError`] by `?`.
    pub async fn run(
        &self,
        episode_store: &dyn EpisodeStore,
    ) -> Result<ProcessResult, ProcessorError> {
        tracing::info!(
            mode = %self.config.mode,
            scenario = %self.config.scenario,
            "Starting learning process"
        );

        match self.config.mode {
            ProcessorMode::OfflineOnly => {
                let model = self.run_offline()?;
                Ok(ProcessResult::Offline(model))
            }
            ProcessorMode::LoraOnly => {
                let model = self.run_lora(episode_store).await?;
                Ok(ProcessResult::Lora(model))
            }
            ProcessorMode::Full => {
                let offline = self.run_offline()?;
                let lora = self.run_lora(episode_store).await?;
                Ok(ProcessResult::Full { offline, lora })
            }
        }
    }

    /// Runs the offline analysis for the configured scenario and session limit.
    ///
    /// Fails with [`ProcessorError::Other`] when no learning store is attached;
    /// errors from the store itself are propagated with `?`.
    fn run_offline(&self) -> Result<OfflineModel, ProcessorError> {
        let store = self.learning_store.as_ref().ok_or_else(|| {
            ProcessorError::Other("LearningStore not configured for offline analysis".into())
        })?;

        tracing::info!(
            scenario = %self.config.scenario,
            max_sessions = self.config.max_sessions,
            "Running offline analysis"
        );

        let model = store.run_offline_learning(&self.config.scenario, self.config.max_sessions)?;

        tracing::info!(
            analyzed_sessions = model.analyzed_sessions,
            ucb1_c = model.parameters.ucb1_c,
            "Offline analysis completed"
        );
        Ok(model)
    }

    /// Runs LoRA training with the attached trainer and learn model.
    ///
    /// Both components are checked before the episode store is queried. Training is
    /// refused with [`ProcessorError::InsufficientData`] when the store reports zero
    /// episodes.
    async fn run_lora(
        &self,
        episode_store: &dyn EpisodeStore,
    ) -> Result<TrainedModel, ProcessorError> {
        let trainer = self
            .lora_trainer
            .as_ref()
            .ok_or_else(|| ProcessorError::Other("LoraTrainer not configured".into()))?;
        let learn_model = self.learn_model.as_ref().ok_or_else(|| {
            ProcessorError::Other("LearnModel not configured for LoRA training".into())
        })?;

        // NOTE(review): `None` presumably means "no filter", i.e. count every
        // episode; confirm against `EpisodeStore::count`.
        let episode_count = episode_store.count(None)?;
        if episode_count == 0 {
            return Err(ProcessorError::InsufficientData(
                "No episodes available for LoRA training".into(),
            ));
        }

        tracing::info!(
            episode_count,
            learn_model = learn_model.name(),
            "Running LoRA training"
        );

        // The episode store is only used for the emptiness check above; it is not
        // passed to the trainer.
        let model = trainer.train(learn_model.as_ref(), None).await?;

        tracing::info!(
            model_id = %model.id,
            sample_count = model.sample_count,
            "LoRA training completed"
        );
        Ok(model)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical mode names parse to their variants; unknown names are rejected.
    #[test]
    fn test_processor_mode_from_str() {
        let expectations = [
            ("offline", ProcessorMode::OfflineOnly),
            ("lora", ProcessorMode::LoraOnly),
            ("full", ProcessorMode::Full),
        ];
        for (text, expected) in expectations {
            assert_eq!(text.parse::<ProcessorMode>().unwrap(), expected);
        }

        assert!("invalid".parse::<ProcessorMode>().is_err());
    }

    /// The builder methods set the scenario, mode and session limit.
    #[test]
    fn test_processor_config_builder() {
        let base = ProcessorConfig::new("test-scenario");
        let with_mode = base.mode(ProcessorMode::Full);
        let config = with_mode.max_sessions(50);

        assert_eq!(config.scenario, "test-scenario");
        assert_eq!(config.mode, ProcessorMode::Full);
        assert_eq!(config.max_sessions, 50);
    }

    /// An offline-only result exposes the offline model and no LoRA model.
    #[test]
    fn test_process_result_accessors() {
        let result = ProcessResult::Offline(OfflineModel::default());

        assert!(result.offline_model().is_some());
        assert!(result.lora_model().is_none());
    }
}