use crate::error::{Result, RuvLLMError};
use crate::lora::adapters::{AdapterMetadata, LoraConfig};
use crate::lora::micro_lora::{AdaptFeedback, MicroLoRA};
use crate::lora::training::{LearningRateSchedule, TrainingConfig, TrainingPipeline};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// A single training example: an input feature vector paired with a quality
/// score, optionally carrying an explicit target and task/domain tags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
/// Input feature vector fed to the adapter.
pub input: Vec<f32>,
/// Optional explicit target vector; `None` when training from quality alone.
pub target: Option<Vec<f32>>,
/// Quality score for this example (producers in this file keep it in [0, 1]).
pub quality: f32,
/// Optional task label (e.g. "coder", "researcher").
pub task: Option<String>,
/// Optional domain label; set alongside `task` by the synthetic generator.
pub domain: Option<String>,
}
impl TrainingExample {
    /// Creates an example from an input vector and a quality score.
    /// Target, task, and domain start out unset; attach them with the
    /// `with_*` builder methods.
    pub fn new(input: Vec<f32>, quality: f32) -> Self {
        Self {
            input,
            quality,
            target: None,
            task: None,
            domain: None,
        }
    }

    /// Builder: attaches an explicit target vector.
    pub fn with_target(self, target: Vec<f32>) -> Self {
        Self {
            target: Some(target),
            ..self
        }
    }

    /// Builder: tags the example with a task label.
    pub fn with_task(self, task: impl Into<String>) -> Self {
        Self {
            task: Some(task.into()),
            ..self
        }
    }

    /// Builder: tags the example with a domain label.
    pub fn with_domain(self, domain: impl Into<String>) -> Self {
        Self {
            domain: Some(domain.into()),
            ..self
        }
    }
}
/// A named dataset of training and validation examples for adapter training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdapterDataset {
/// Training split.
pub examples: Vec<TrainingExample>,
/// Validation split (filled via `add_validation` or `split`).
pub validation: Vec<TrainingExample>,
/// Human-readable dataset name.
pub name: String,
/// Free-form description (empty by default).
pub description: String,
/// Expected dimensionality of each example's input vector.
pub feature_dim: usize,
}
impl AdapterDataset {
    /// Creates an empty dataset with the given name and feature dimensionality.
    pub fn new(name: impl Into<String>, feature_dim: usize) -> Self {
        Self {
            name: name.into(),
            description: String::new(),
            examples: Vec::new(),
            validation: Vec::new(),
            feature_dim,
        }
    }

    /// Appends an example to the training split.
    pub fn add_example(&mut self, example: TrainingExample) {
        self.examples.push(example);
    }

    /// Appends an example to the validation split.
    pub fn add_validation(&mut self, example: TrainingExample) {
        self.validation.push(example);
    }

    /// Moves the trailing `validation_ratio` fraction of the training split
    /// into the validation split. Does nothing when the requested split would
    /// be empty or would consume the whole training set.
    ///
    /// NOTE: this *replaces* any examples previously added via
    /// `add_validation`; it does not extend them.
    pub fn split(&mut self, validation_ratio: f32) {
        let total = self.examples.len();
        let val_size = (total as f32 * validation_ratio) as usize;
        // Equivalent to: val_size > 0 && val_size < total.
        if (1..total).contains(&val_size) {
            self.validation = self.examples.split_off(total - val_size);
        }
    }

    /// Summarizes the dataset. Quality averages are 0.0 for empty splits
    /// rather than NaN.
    pub fn stats(&self) -> DatasetStats {
        let mean_quality = |split: &[TrainingExample]| -> f32 {
            if split.is_empty() {
                0.0
            } else {
                split.iter().map(|e| e.quality).sum::<f32>() / split.len() as f32
            }
        };
        DatasetStats {
            train_size: self.examples.len(),
            val_size: self.validation.len(),
            feature_dim: self.feature_dim,
            avg_quality: mean_quality(&self.examples),
            val_avg_quality: mean_quality(&self.validation),
        }
    }

    /// Serializes the dataset with bincode and writes it to `path`.
    ///
    /// # Errors
    /// Returns a serialization error if encoding fails, or an I/O error from
    /// the filesystem write.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        let encoded = bincode::serde::encode_to_vec(self, bincode::config::standard())
            .map_err(|e| RuvLLMError::Serialization(e.to_string()))?;
        std::fs::write(path, encoded)?;
        Ok(())
    }

    /// Reads a bincode-encoded dataset back from `path`.
    ///
    /// # Errors
    /// Returns an I/O error if the file cannot be read, or a serialization
    /// error if decoding fails.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let raw = std::fs::read(path)?;
        bincode::serde::decode_from_slice(&raw, bincode::config::standard())
            .map(|(dataset, _len): (Self, usize)| dataset)
            .map_err(|e| RuvLLMError::Serialization(e.to_string()))
    }
}
/// Summary statistics for an `AdapterDataset`, as produced by `stats()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStats {
/// Number of training examples.
pub train_size: usize,
/// Number of validation examples.
pub val_size: usize,
/// Feature dimensionality of the dataset.
pub feature_dim: usize,
/// Mean quality over the training split (0.0 when empty).
pub avg_quality: f32,
/// Mean quality over the validation split (0.0 when empty).
pub val_avg_quality: f32,
}
/// Configuration for a full adapter-training run (epochs, validation cadence,
/// early stopping, and checkpointing) layered on top of the per-step
/// `TrainingConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdapterTrainingConfig {
/// Per-step training pipeline configuration (learning rate, schedule, ...).
pub training: TrainingConfig,
/// Maximum number of passes over the training split.
pub epochs: usize,
/// Run mid-epoch validation every this many global steps.
pub validation_interval: usize,
/// Stop after this many consecutive epochs without validation improvement.
pub early_stopping_patience: usize,
/// Minimum validation-loss decrease counted as an improvement.
pub min_improvement: f32,
// NOTE(review): the two flags below are carried in the config but not read
// anywhere in this file — presumably consumed by the pipeline; confirm.
pub gradient_checkpointing: bool,
pub mixed_precision: bool,
/// When true, save a checkpoint each time mid-epoch validation improves.
pub save_best: bool,
/// Directory where checkpoints are written.
pub output_dir: String,
}
impl Default for AdapterTrainingConfig {
    /// Balanced defaults: 3 epochs, mid-epoch validation every 100 steps,
    /// patience of 3, best-checkpoint saving into `./adapters`.
    fn default() -> Self {
        Self {
            training: TrainingConfig::default(),
            epochs: 3,
            validation_interval: 100,
            early_stopping_patience: 3,
            min_improvement: 0.001,
            gradient_checkpointing: true,
            mixed_precision: false,
            save_best: true,
            output_dir: String::from("./adapters"),
        }
    }
}
impl AdapterTrainingConfig {
    /// Fast single-epoch preset for smoke tests: higher constant learning
    /// rate, one epoch, minimal early-stopping patience.
    pub fn quick() -> Self {
        let mut cfg = Self::default();
        cfg.training.learning_rate = 0.005;
        cfg.training.lr_schedule = LearningRateSchedule::Constant;
        cfg.epochs = 1;
        cfg.early_stopping_patience = 1;
        cfg
    }

    /// Conservative long-run preset: stable pipeline config, more epochs and
    /// patience, and a tighter improvement threshold.
    pub fn stable() -> Self {
        Self {
            training: TrainingConfig::stable(),
            epochs: 5,
            early_stopping_patience: 5,
            min_improvement: 0.0001,
            ..Self::default()
        }
    }
}
/// Drives adapter training over an `AdapterDataset`: epoch loop, periodic
/// validation, early stopping, and best-checkpoint saving.
pub struct AdapterTrainer {
/// Run-level configuration (epochs, patience, output directory, ...).
config: AdapterTrainingConfig,
/// Underlying per-step training pipeline.
pipeline: TrainingPipeline,
/// Best epoch-level validation loss seen so far (f32::MAX until first run).
best_val_loss: f32,
/// Consecutive epochs without validation improvement (drives early stopping).
epochs_without_improvement: usize,
/// Recorded train/validation losses across the run.
history: TrainingHistory,
}
impl AdapterTrainer {
    /// Builds a trainer and its underlying pipeline from `config`.
    pub fn new(config: AdapterTrainingConfig) -> Self {
        let pipeline = TrainingPipeline::new(config.training.clone());
        Self {
            config,
            pipeline,
            best_val_loss: f32::MAX,
            epochs_without_improvement: 0,
            history: TrainingHistory::default(),
        }
    }

    /// Runs up to `config.epochs` epochs of adaptation over `dataset`,
    /// validating every `validation_interval` steps and early-stopping when
    /// the epoch-level validation loss stops improving.
    ///
    /// # Errors
    /// Propagates failures from the training pipeline and from checkpoint
    /// I/O.
    pub fn train(&mut self, lora: &MicroLoRA, dataset: &AdapterDataset) -> Result<TrainingResult> {
        self.pipeline.init_for_lora(lora);
        // Best loss at mid-epoch validation checkpoints, used to gate
        // checkpoint saving.
        // NOTE(review): tracked separately from `self.best_val_loss` (which
        // gates epoch-level early stopping), yet an improvement here resets
        // the epoch-level patience counter — confirm this split is intended.
        let mut best_loss = f32::MAX;
        let mut global_step = 0;
        for epoch in 0..self.config.epochs {
            eprintln!("Epoch {}/{}", epoch + 1, self.config.epochs);
            let mut epoch_loss = 0.0;
            let mut num_batches = 0;
            for example in &dataset.examples {
                let feedback = AdaptFeedback::from_quality(example.quality);
                self.pipeline.train_step(lora, &example.input, feedback)?;
                // Proxy loss: qualities are produced in [0, 1], so loss = 1 - quality.
                epoch_loss += 1.0 - example.quality;
                num_batches += 1;
                global_step += 1;
                if global_step % self.config.validation_interval == 0
                    && !dataset.validation.is_empty()
                {
                    let val_loss = self.validate(lora, &dataset.validation)?;
                    eprintln!(" Step {}: val_loss = {:.4}", global_step, val_loss);
                    self.history.val_losses.push(val_loss);
                    if val_loss < best_loss - self.config.min_improvement {
                        best_loss = val_loss;
                        self.epochs_without_improvement = 0;
                        if self.config.save_best {
                            self.save_checkpoint(lora, epoch, val_loss)?;
                        }
                    }
                }
            }
            // Guard: an empty training split would otherwise produce 0.0 / 0 = NaN.
            let avg_loss = if num_batches > 0 {
                epoch_loss / num_batches as f32
            } else {
                0.0
            };
            self.history.train_losses.push(avg_loss);
            eprintln!(" Avg train loss: {:.4}", avg_loss);
            if !dataset.validation.is_empty() {
                let val_loss = self.validate(lora, &dataset.validation)?;
                eprintln!(" Validation loss: {:.4}", val_loss);
                if val_loss < self.best_val_loss - self.config.min_improvement {
                    self.best_val_loss = val_loss;
                    self.epochs_without_improvement = 0;
                } else {
                    self.epochs_without_improvement += 1;
                }
                if self.epochs_without_improvement >= self.config.early_stopping_patience {
                    eprintln!("Early stopping triggered after {} epochs", epoch + 1);
                    break;
                }
            }
            self.pipeline.start_new_task(lora);
        }
        Ok(TrainingResult {
            final_loss: self.history.train_losses.last().copied().unwrap_or(0.0),
            best_val_loss: self.best_val_loss,
            epochs_completed: self.history.train_losses.len(),
            total_steps: global_step,
            history: self.history.clone(),
        })
    }

    /// Proxy validation loss: mean of (1 - quality) over `validation`.
    /// Returns 0.0 for an empty slice instead of dividing by zero.
    /// The adapter itself is not evaluated here (hence `_lora`).
    fn validate(&self, _lora: &MicroLoRA, validation: &[TrainingExample]) -> Result<f32> {
        if validation.is_empty() {
            return Ok(0.0);
        }
        let total_loss: f32 = validation.iter().map(|e| 1.0 - e.quality).sum();
        Ok(total_loss / validation.len() as f32)
    }

    /// Persists `lora` under `output_dir`, naming the file by epoch and loss.
    fn save_checkpoint(&self, lora: &MicroLoRA, epoch: usize, val_loss: f32) -> Result<()> {
        std::fs::create_dir_all(&self.config.output_dir)?;
        let path = format!(
            "{}/adapter_epoch{}_loss{:.4}.bin",
            self.config.output_dir, epoch, val_loss
        );
        lora.save(&path)?;
        eprintln!(" Saved checkpoint: {}", path);
        Ok(())
    }

    /// Read-only access to the recorded loss history.
    pub fn history(&self) -> &TrainingHistory {
        &self.history
    }

    /// Clears best-loss tracking, history, and pipeline state so the trainer
    /// can be reused for a fresh run.
    pub fn reset(&mut self) {
        self.best_val_loss = f32::MAX;
        self.epochs_without_improvement = 0;
        self.history = TrainingHistory::default();
        self.pipeline.reset();
    }
}
/// Per-run record of losses collected by `AdapterTrainer`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingHistory {
/// Average proxy train loss per completed epoch.
pub train_losses: Vec<f32>,
/// Mid-epoch validation losses (pushed every `validation_interval` steps).
pub val_losses: Vec<f32>,
// NOTE(review): never populated anywhere in this file — presumably meant to
// record the LR schedule per step; confirm or remove.
pub learning_rates: Vec<f32>,
}
/// Outcome of an `AdapterTrainer::train` run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingResult {
/// Average train loss of the last completed epoch (0.0 if none completed).
pub final_loss: f32,
/// Best epoch-level validation loss seen (f32::MAX if never validated).
pub best_val_loss: f32,
/// Number of epochs actually completed (may be fewer than configured due to
/// early stopping).
pub epochs_completed: usize,
/// Total training steps taken across all epochs.
pub total_steps: usize,
/// Full loss history for the run.
pub history: TrainingHistory,
}
/// Deterministic generator of synthetic adapter-training datasets, seeded so
/// repeated runs produce identical data.
pub struct SyntheticDataGenerator {
/// Dimensionality of each generated input vector.
feature_dim: usize,
/// RNG seed; the same seed yields the same dataset.
seed: u64,
}
impl SyntheticDataGenerator {
    /// Creates a generator producing `feature_dim`-wide vectors from a fixed `seed`.
    pub fn new(feature_dim: usize, seed: u64) -> Self {
        Self { feature_dim, seed }
    }

    /// Generates a deterministic synthetic dataset for `task_type` and carves
    /// off 20% of it as the validation split.
    ///
    /// Each known task type maps a uniformly random input vector in
    /// (-1, 1)^dim to a quality score via a task-specific heuristic; unknown
    /// task types fall back to a uniform random quality in [0.5, 1).
    pub fn generate(&self, task_type: &str, num_examples: usize) -> AdapterDataset {
        use rand::rngs::StdRng;
        use rand::{Rng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(self.seed);
        let dim = self.feature_dim;
        let mut dataset = AdapterDataset::new(format!("{}_synthetic", task_type), dim);
        for _ in 0..num_examples {
            let features: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
            let quality = match task_type {
                // Mean |x| over the first quarter of the vector.
                "coder" => {
                    let quarter = dim / 4;
                    let structure_score =
                        features.iter().take(quarter).map(|x| x.abs()).sum::<f32>()
                            / quarter as f32;
                    (0.6 + structure_score * 0.4).min(1.0)
                }
                // Mean |x| over the whole vector.
                "researcher" => {
                    let density = features.iter().map(|x| x.abs()).sum::<f32>() / dim as f32;
                    (0.5 + density * 0.5).min(1.0)
                }
                // Mean |x| over every other dimension.
                "security" => {
                    let critical_score =
                        features.iter().step_by(2).map(|x| x.abs()).sum::<f32>()
                            / (dim / 2) as f32;
                    (0.7 + critical_score * 0.3).min(1.0)
                }
                // Less adjacent-element variation => higher quality.
                "architect" => {
                    let coherence = features.windows(2).map(|w| (w[0] - w[1]).abs()).sum::<f32>()
                        / (dim - 1) as f32;
                    (0.6 + (1.0 - coherence) * 0.4).min(1.0)
                }
                // Mean closer to zero => higher quality.
                "reviewer" => {
                    let balance =
                        1.0 - (features.iter().sum::<f32>() / dim as f32).abs();
                    (0.5 + balance * 0.5).min(1.0)
                }
                _ => rng.gen_range(0.5..1.0),
            };
            dataset.add_example(
                TrainingExample::new(features, quality)
                    .with_task(task_type)
                    .with_domain(task_type),
            );
        }
        dataset.split(0.2);
        dataset
    }

    /// Generates one dataset per built-in task type, paired with its name.
    pub fn generate_all(&self, examples_per_task: usize) -> Vec<(String, AdapterDataset)> {
        ["coder", "researcher", "security", "architect", "reviewer"]
            .iter()
            .map(|task| (task.to_string(), self.generate(task, examples_per_task)))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lora::adapters::RuvLtraAdapters;

    /// Builder methods should set every tagged field.
    #[test]
    fn test_training_example() {
        let example = TrainingExample::new(vec![0.1; 64], 0.8)
            .with_task("test")
            .with_domain("testing");
        assert_eq!(example.input.len(), 64);
        assert_eq!(example.quality, 0.8);
        assert_eq!(example.task, Some("test".to_string()));
        assert_eq!(example.domain, Some("testing".to_string()));
    }

    /// Added examples accumulate in the training split.
    #[test]
    fn test_dataset_creation() {
        let mut dataset = AdapterDataset::new("test", 64);
        for i in 0..100 {
            let example = TrainingExample::new(vec![i as f32; 64], 0.5 + i as f32 * 0.005);
            dataset.add_example(example);
        }
        assert_eq!(dataset.examples.len(), 100);
    }

    /// A 20% split moves exactly the trailing fifth into validation.
    #[test]
    fn test_dataset_split() {
        let mut dataset = AdapterDataset::new("test", 64);
        for i in 0..100 {
            let example = TrainingExample::new(vec![i as f32; 64], 0.8);
            dataset.add_example(example);
        }
        dataset.split(0.2);
        assert_eq!(dataset.examples.len(), 80);
        assert_eq!(dataset.validation.len(), 20);
    }

    /// Synthetic generation fills both splits with in-range qualities.
    #[test]
    fn test_synthetic_data_generator() {
        let generator = SyntheticDataGenerator::new(64, 42);
        let dataset = generator.generate("coder", 100);
        assert_eq!(dataset.feature_dim, 64);
        assert!(!dataset.examples.is_empty());
        assert!(!dataset.validation.is_empty());
        for example in &dataset.examples {
            assert!((0.0..=1.0).contains(&example.quality));
        }
    }

    /// A quick training run completes at least one epoch and takes steps.
    #[test]
    fn test_adapter_trainer() {
        let adapters = RuvLtraAdapters::new();
        let lora = adapters.create_lora("coder", 64).unwrap();
        let generator = SyntheticDataGenerator::new(64, 42);
        let dataset = generator.generate("coder", 50);
        let config = AdapterTrainingConfig::quick();
        let mut trainer = AdapterTrainer::new(config);
        let result = trainer.train(&lora, &dataset).unwrap();
        assert!(result.epochs_completed > 0);
        assert!(result.total_steps > 0);
    }

    /// All five built-in task types produce non-empty datasets.
    #[test]
    fn test_generate_all_datasets() {
        let generator = SyntheticDataGenerator::new(64, 42);
        let datasets = generator.generate_all(100);
        assert_eq!(datasets.len(), 5);
        for (name, dataset) in datasets {
            assert!(!dataset.examples.is_empty());
            println!(
                "{}: {} train, {} val",
                name,
                dataset.examples.len(),
                dataset.validation.len()
            );
        }
    }
}