use std::io::Write as IoWrite;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;
use tokio::process::Command;
use crate::learn::episode::{EpisodeId, Outcome};
use crate::learn::learn_model::LearnModel;
use crate::learn::store::{EpisodeDto, EpisodeFilter, EpisodeStore, StoreError};
use crate::learn::training::TrainingData;
use crate::util::{epoch_millis, epoch_millis_for_ordering};
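/// Configuration for a LoRA fine-tuning run.
///
/// Defaults target the `LiquidAI/LFM2.5-1.2B-Instruct` base model with a
/// rank-16 adapter; each field can be overridden via the builder methods
/// below.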
#[derive(Debug, Clone)]
pub struct LoraTrainerConfig {
pub base_model: String,
pub lora_rank: u32,
pub lora_alpha: f32,
pub lora_dropout: f32,
pub epochs: u32,
pub batch_size: u32,
pub gradient_accumulation: u32,
pub learning_rate: f32,
pub max_seq_length: u32,
pub train_script: PathBuf,
pub output_dir: PathBuf,
pub data_dir: PathBuf,
pub python_path: PathBuf,
}
impl Default for LoraTrainerConfig {
fn default() -> Self {
Self {
base_model: "LiquidAI/LFM2.5-1.2B-Instruct".to_string(),
lora_rank: 16,
lora_alpha: 32.0,
lora_dropout: 0.05,
epochs: 3,
batch_size: 4,
gradient_accumulation: 4,
learning_rate: 2e-4,
max_seq_length: 2048,
train_script: PathBuf::from("lora/train.py"),
output_dir: PathBuf::from("lora/adapters"),
data_dir: PathBuf::from("lora/data"),
python_path: PathBuf::from("python3"),
}
}
}
impl LoraTrainerConfig {
pub fn base_model(mut self, model: impl Into<String>) -> Self {
self.base_model = model.into();
self
}
pub fn lora_rank(mut self, rank: u32) -> Self {
self.lora_rank = rank;
self
}
pub fn lora_alpha(mut self, alpha: f32) -> Self {
self.lora_alpha = alpha;
self
}
pub fn epochs(mut self, epochs: u32) -> Self {
self.epochs = epochs;
self
}
pub fn batch_size(mut self, size: u32) -> Self {
self.batch_size = size;
self
}
pub fn learning_rate(mut self, lr: f32) -> Self {
self.learning_rate = lr;
self
}
pub fn train_script(mut self, path: impl Into<PathBuf>) -> Self {
self.train_script = path.into();
self
}
pub fn output_dir(mut self, path: impl Into<PathBuf>) -> Self {
self.output_dir = path.into();
self
}
pub fn python_path(mut self, path: impl Into<PathBuf>) -> Self {
self.python_path = path.into();
self
}
}
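/// Metadata for a completed training run: where the adapter lives on disk,
/// which episodes it was trained on, and optional runtime metrics.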
#[derive(Debug, Clone)]
pub struct TrainedModel {
pub id: LoraModelId,
pub base_model: String,
pub adapter_path: PathBuf,
pub learn_model_name: String,
pub episode_ids: Vec<EpisodeId>,
pub sample_count: usize,
pub created_at: u64,
pub metrics: Option<TrainingMetrics>,
}
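/// Unique identifier for a trained adapter, built from a millisecond
/// timestamp plus an atomic counter so ids stay unique within a process.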
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LoraModelId(String);
impl LoraModelId {
pub fn new() -> Self {
use std::sync::atomic::{AtomicU32, Ordering};
static COUNTER: AtomicU32 = AtomicU32::new(0);
let ts = epoch_millis_for_ordering();
let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
Self(format!("lora-{}-{:08x}", ts, counter))
}
pub fn parse(s: &str) -> Self {
Self(s.to_string())
}
pub fn as_str(&self) -> &str {
&self.0
}
}
impl Default for LoraModelId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for LoraModelId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
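/// Optional metrics for a training run. Currently only wall-clock training
/// time is populated by [`LoraTrainer::train`]; final loss and GPU memory
/// are always `None`.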
#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
pub final_loss: Option<f64>,
pub training_time_secs: Option<u64>,
pub gpu_memory_mb: Option<u64>,
}
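/// Errors produced while preparing training data or running the training
/// subprocess.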
#[derive(Debug)]
pub enum LoraTrainerError {
Store(StoreError),
EmptyData(String),
Io(std::io::Error),
ScriptNotFound(PathBuf),
ProcessFailed { exit_code: i32, stderr: String },
Other(String),
}
impl std::fmt::Display for LoraTrainerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Store(e) => write!(f, "Store error: {}", e),
Self::EmptyData(msg) => write!(f, "Empty data: {}", msg),
Self::Io(e) => write!(f, "IO error: {}", e),
Self::ScriptNotFound(p) => write!(f, "Script not found: {}", p.display()),
Self::ProcessFailed { exit_code, stderr } => {
write!(f, "Training failed (exit {}): {}", exit_code, stderr)
}
Self::Other(msg) => write!(f, "{}", msg),
}
}
}
impl std::error::Error for LoraTrainerError {}
impl From<StoreError> for LoraTrainerError {
fn from(e: StoreError) -> Self {
Self::Store(e)
}
}
impl From<std::io::Error> for LoraTrainerError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
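/// Orchestrates LoRA fine-tuning: fetches episodes from an [`EpisodeStore`],
/// converts them to SFT conversations, writes a JSONL dataset, and shells
/// out to a Python training script.
///
/// A minimal usage sketch, assuming the in-memory store and the
/// `WorkerTaskLearn` learn model that this module's tests use:
///
/// ```ignore
/// let trainer = LoraTrainer::new(
///     LoraTrainerConfig::default().epochs(5),
///     Arc::new(InMemoryEpisodeStore::new()),
/// );
/// let trained = trainer.train(&WorkerTaskLearn::new(), None).await?;
/// println!("adapter at {}", trained.adapter_path.display());
/// ```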
pub struct LoraTrainer {
config: LoraTrainerConfig,
episode_store: Arc<dyn EpisodeStore>,
}
impl LoraTrainer {
pub fn new(config: LoraTrainerConfig, episode_store: Arc<dyn EpisodeStore>) -> Self {
Self {
config,
episode_store,
}
}
pub fn config(&self) -> &LoraTrainerConfig {
&self.config
}
pub fn episode_store(&self) -> &Arc<dyn EpisodeStore> {
&self.episode_store
}
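    /// Runs the full training pipeline: query episodes (optionally
    /// filtered), convert them into SFT samples, write a JSONL dataset,
    /// and invoke the external training script.
    ///
    /// Returns [`LoraTrainerError::EmptyData`] if the store yields no
    /// episodes or if no episode converts into a usable sample.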
pub async fn train(
&self,
learn_model: &dyn LearnModel,
filter: Option<EpisodeFilter>,
) -> Result<TrainedModel, LoraTrainerError> {
let started_at = std::time::Instant::now();
tracing::info!(
learn_model = learn_model.name(),
"Fetching episodes for training"
);
let filter = filter.unwrap_or_default();
let episodes = self.episode_store.query(&filter)?;
if episodes.is_empty() {
return Err(LoraTrainerError::EmptyData(
"No episodes found for training".into(),
));
}
let episode_ids: Vec<_> = episodes.iter().map(|e| e.id.clone()).collect();
tracing::info!(episode_count = episodes.len(), "Episodes fetched");
tracing::info!("Converting episodes to training data");
let training_data: Vec<TrainingData> = episodes
.iter()
.filter_map(|ep| episode_dto_to_training_data(ep, learn_model.name()).ok())
.collect();
if training_data.is_empty() {
return Err(LoraTrainerError::EmptyData(
"No training data generated from episodes".into(),
));
}
let sample_count = training_data.len();
tracing::info!(sample_count, "Training data prepared");
let data_path = self.write_training_data(&training_data, learn_model.name())?;
tracing::info!(path = %data_path.display(), "Training data written");
        let timestamp = epoch_millis() / 1000;
        let adapter_name = format!("{}-{}", learn_model.name(), timestamp);
let adapter_path = self.run_lora_training(&data_path, &adapter_name).await?;
let elapsed = started_at.elapsed();
tracing::info!(
elapsed_secs = elapsed.as_secs(),
adapter = %adapter_path.display(),
"Training completed"
);
let model = TrainedModel {
id: LoraModelId::new(),
base_model: self.config.base_model.clone(),
adapter_path,
learn_model_name: learn_model.name().to_string(),
episode_ids,
sample_count,
created_at: epoch_millis(),
metrics: Some(TrainingMetrics {
                final_loss: None,
                training_time_secs: Some(elapsed.as_secs()),
gpu_memory_mb: None,
}),
};
Ok(model)
}
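    /// Writes one JSON object per line (JSONL) to
    /// `<data_dir>/<learn_model_name>.jsonl`, creating the data directory
    /// if needed, and returns the path to the written file.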
fn write_training_data(
&self,
data: &[TrainingData],
learn_model_name: &str,
) -> Result<PathBuf, LoraTrainerError> {
std::fs::create_dir_all(&self.config.data_dir)?;
let filename = format!("{}.jsonl", learn_model_name);
let path = self.config.data_dir.join(filename);
let mut file = std::fs::File::create(&path)?;
for td in data {
let json_str = training_data_to_json(td)?;
writeln!(file, "{}", json_str)?;
}
Ok(path)
}
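    /// Spawns `<python_path> <train_script>` with the configured
    /// hyperparameters passed as CLI flags, capturing stdout and stderr.
    /// Fails with [`LoraTrainerError::ScriptNotFound`] if the script is
    /// missing and [`LoraTrainerError::ProcessFailed`] on non-zero exit.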
async fn run_lora_training(
&self,
data_path: &Path,
adapter_name: &str,
) -> Result<PathBuf, LoraTrainerError> {
if !self.config.train_script.exists() {
return Err(LoraTrainerError::ScriptNotFound(
self.config.train_script.clone(),
));
}
let output_path = self.config.output_dir.join(adapter_name);
let mut cmd = Command::new(&self.config.python_path);
cmd.arg(&self.config.train_script)
.arg("--data")
.arg(data_path)
.arg("--output")
.arg(&output_path)
.arg("--model")
.arg(&self.config.base_model)
.arg("--rank")
.arg(self.config.lora_rank.to_string())
.arg("--alpha")
.arg(self.config.lora_alpha.to_string())
.arg("--dropout")
.arg(self.config.lora_dropout.to_string())
.arg("--epochs")
.arg(self.config.epochs.to_string())
.arg("--batch-size")
.arg(self.config.batch_size.to_string())
.arg("--grad-accum")
.arg(self.config.gradient_accumulation.to_string())
.arg("--lr")
.arg(self.config.learning_rate.to_string())
.arg("--max-seq-length")
.arg(self.config.max_seq_length.to_string())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
tracing::info!(
script = %self.config.train_script.display(),
data = %data_path.display(),
output = %output_path.display(),
"Starting LoRA training"
);
let output = cmd.output().await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(LoraTrainerError::ProcessFailed {
exit_code: output.status.code().unwrap_or(-1),
stderr: stderr.to_string(),
});
}
let stdout = String::from_utf8_lossy(&output.stdout);
for line in stdout.lines() {
tracing::debug!(line, "train.py output");
}
Ok(output_path)
}
}
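/// Converts a stored episode into a single SFT sample: a fixed system
/// prompt, a user prompt built from the episode's id and metadata, and an
/// assistant response summarizing the outcome. Successful episodes also
/// carry their score via `with_outcome_score`.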
fn episode_dto_to_training_data(
dto: &EpisodeDto,
learn_model_name: &str,
) -> Result<TrainingData, LoraTrainerError> {
let system_prompt = format!(
"You are an intelligent agent using the {} strategy. Your task is to make optimal decisions.",
learn_model_name
);
let user_prompt = format!(
"Episode ID: {}\nLearn Model: {}\nMetadata: {:?}",
dto.id, dto.learn_model, dto.metadata
);
let response = match &dto.outcome {
Outcome::Success { score } => {
format!("Decision successful with score {:.2}", score)
}
Outcome::Failure { reason } => {
format!("Decision failed: {}", reason)
}
Outcome::Timeout { partial_score } => match partial_score {
Some(score) => format!("Timeout with partial score {:.2}", score),
None => "Timeout without progress".to_string(),
},
Outcome::Unknown => "Outcome unknown".to_string(),
};
let training_data = TrainingData::sft(&system_prompt, &user_prompt, &response)
.with_episode_id(dto.id.to_string())
.with_model(learn_model_name);
let training_data = if let Outcome::Success { score } = &dto.outcome {
training_data.with_outcome_score(*score)
} else {
training_data
};
Ok(training_data)
}
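/// Serializes one training sample as a single JSON line of the form
/// `{"conversations": [{"role": "system" | "user" | "assistant", "content": "..."}]}`.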
fn training_data_to_json(td: &TrainingData) -> Result<String, LoraTrainerError> {
let conversation = td.to_conversation();
let turns: Vec<serde_json::Value> = conversation
.conversations
.iter()
.map(|turn| {
serde_json::json!({
"role": match turn.role {
crate::learn::training::ConversationRole::System => "system",
crate::learn::training::ConversationRole::User => "user",
crate::learn::training::ConversationRole::Assistant => "assistant",
},
"content": turn.content,
})
})
.collect();
let json_value = serde_json::json!({
"conversations": turns
});
serde_json::to_string(&json_value)
.map_err(|e| LoraTrainerError::Other(format!("JSON serialization error: {}", e)))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::learn::store::InMemoryEpisodeStore;
#[test]
fn test_trainer_config_builder() {
let config = LoraTrainerConfig::default()
.base_model("test-model")
.lora_rank(32)
.lora_alpha(64.0)
.epochs(5)
.batch_size(8)
.learning_rate(1e-4);
assert_eq!(config.base_model, "test-model");
assert_eq!(config.lora_rank, 32);
assert_eq!(config.lora_alpha, 64.0);
assert_eq!(config.epochs, 5);
assert_eq!(config.batch_size, 8);
assert!((config.learning_rate - 1e-4).abs() < 1e-10);
}
#[test]
fn test_model_id() {
let id1 = LoraModelId::new();
let id2 = LoraModelId::new();
assert!(!id1.as_str().is_empty());
        assert!(!id2.as_str().is_empty());
        assert_ne!(id1, id2, "timestamp + counter should make ids unique");
}
#[test]
fn test_trainer_creation() {
let config = LoraTrainerConfig::default();
let store = Arc::new(InMemoryEpisodeStore::new());
let trainer = LoraTrainer::new(config, store);
assert_eq!(trainer.config().base_model, "LiquidAI/LFM2.5-1.2B-Instruct");
assert_eq!(trainer.config().lora_rank, 16);
}
#[tokio::test]
async fn test_train_empty_store() {
use crate::learn::learn_model::WorkerTaskLearn;
let config = LoraTrainerConfig::default();
let store = Arc::new(InMemoryEpisodeStore::new());
let trainer = LoraTrainer::new(config, store);
let learn_model = WorkerTaskLearn::new();
let result = trainer.train(&learn_model, None).await;
assert!(result.is_err());
match result {
Err(LoraTrainerError::EmptyData(_)) => {}
_ => panic!("Expected EmptyData error"),
}
}
#[test]
fn test_episode_dto_to_training_data() {
use crate::learn::episode::EpisodeMetadata;
let dto = EpisodeDto {
id: EpisodeId::new(),
learn_model: "test".to_string(),
outcome: Outcome::success(0.95),
metadata: EpisodeMetadata::new(),
record_ids: vec![],
};
let td = episode_dto_to_training_data(&dto, "test-model").unwrap();
assert!(td.is_sft());
}
#[test]
fn test_training_data_to_json() {
let td = TrainingData::sft(
"You are a helpful assistant.",
"What is 2+2?",
"2+2 equals 4.",
);
let json = training_data_to_json(&td).unwrap();
assert!(json.contains("conversations"));
assert!(json.contains("system"));
assert!(json.contains("user"));
assert!(json.contains("assistant"));
}
}