use crate::transformer::QCT;
/// Offset of π/2 that `train` adds to, and subtracts from, one parameter at a
/// time; the loss difference between the two shifted evaluations, halved, is
/// used as that parameter's gradient component.
const SHIFT: f32 = std::f32::consts::FRAC_PI_2;
/// Hyperparameters read by `train` and by the learning-rate schedule.
#[derive(Debug, Clone)]
pub struct TrainConfig {
    /// Base learning rate; the schedule scales it during warmup and cosine decay.
    pub learning_rate: f32,
    /// Number of training iterations (`train` processes one token window per epoch).
    pub num_epochs: usize,
    /// Context size; each training window holds up to `context_length + 1` tokens.
    pub context_length: usize,
    /// Metrics are recorded on epochs that are a multiple of this value, and on
    /// the final epoch.
    pub log_interval: usize,
    /// Upper bound on the gradient's L2 norm; larger gradients are rescaled to it.
    pub grad_clip: f32,
    /// When `true`, the learning rate follows a half-cosine decay after warmup.
    pub use_cosine_decay: bool,
    /// Number of initial epochs over which the learning rate ramps up linearly.
    pub warmup_epochs: usize,
}
impl Default for TrainConfig {
fn default() -> Self {
Self {
learning_rate: 0.03, num_epochs: 100,
context_length: 64, log_interval: 10,
grad_clip: 4.236, use_cosine_decay: false,
warmup_epochs: 0,
}
}
}
/// Snapshot of one training epoch, as recorded by `train`.
#[derive(Debug, Clone)]
pub struct EpochMetrics {
    /// Zero-based index of the epoch.
    pub epoch: usize,
    /// Loss on the epoch's token window, evaluated before the parameter update.
    pub loss: f32,
    /// Second value returned by the model's forward pass on the window without
    /// its final token, evaluated before the parameter update.
    pub free_energy: f32,
    /// L2 norm of the estimated gradient, measured before clipping.
    pub grad_norm: f32,
    /// Wall-clock duration of the epoch in milliseconds.
    pub elapsed_ms: f32,
    /// Learning rate applied in this epoch.
    pub learning_rate: f32,
    /// Number of model parameters for which a gradient was estimated.
    pub params_trained: usize,
}
/// Module-private shorthand for `learning_rate_pub`: returns the learning
/// rate to use at `epoch` under `config`.
fn learning_rate(config: &TrainConfig, epoch: usize) -> f32 {
learning_rate_pub(config, epoch)
}
/// Returns the learning rate to use at `epoch` (zero-based) under `config`.
///
/// The schedule has up to three phases:
///
/// * **Warmup** – while `epoch < config.warmup_epochs` the rate ramps linearly
///   as `learning_rate * (epoch + 1) / warmup_epochs`.
/// * **Cosine decay** – if `config.use_cosine_decay` is set, the rate then
///   follows `0.5 * learning_rate * (1 + cos(π · progress))`, where `progress`
///   is the fraction of the post-warmup epochs already elapsed.
/// * **Constant** – otherwise the base `learning_rate` is returned unchanged.
///
/// `progress` is clamped to `1.0`, so an `epoch` at or beyond the end of the
/// decay span yields the schedule's floor instead of climbing back up the
/// cosine curve.
pub fn learning_rate_pub(config: &TrainConfig, epoch: usize) -> f32 {
    let base_lr = config.learning_rate;

    // `epoch < warmup_epochs` already implies `warmup_epochs > 0`, so the
    // division below cannot be by zero.
    if epoch < config.warmup_epochs {
        return base_lr * (epoch + 1) as f32 / config.warmup_epochs as f32;
    }

    if !config.use_cosine_decay {
        return base_lr;
    }

    // Past the warmup guard `epoch >= warmup_epochs`, so this cannot underflow.
    let elapsed = epoch - config.warmup_epochs;
    // `.max(1)` keeps the denominator non-zero when warmup covers every epoch.
    let decay_epochs = config
        .num_epochs
        .saturating_sub(config.warmup_epochs)
        .max(1);
    let progress = (elapsed as f32 / decay_epochs as f32).min(1.0);

    base_lr * 0.5 * (1.0 + (std::f32::consts::PI * progress).cos())
}
/// Trains `model` on `tokens` by gradient descent, estimating the gradient
/// from shifted loss evaluations, and returns the recorded metrics.
///
/// Each epoch:
///
/// 1. selects a window of up to `config.context_length + 1` tokens; the window
///    start advances by one token per epoch and wraps around once every valid
///    start position (including the last one) has been used;
/// 2. evaluates the loss and the forward pass on that window for reporting;
/// 3. for every parameter, evaluates the loss with that parameter shifted by
///    `+SHIFT` and `-SHIFT` and takes half the difference as its gradient
///    component (evaluated in parallel with rayon);
/// 4. rescales the gradient to `config.grad_clip` if its L2 norm exceeds it;
/// 5. applies `param -= lr * grad` to every parameter.
///
/// # Returns
///
/// One [`EpochMetrics`] for each epoch that is a multiple of
/// `config.log_interval`, plus the final epoch. A `log_interval` of `0`
/// disables the periodic entries, so only the final epoch is recorded.
///
/// If `tokens` is empty there is nothing to train on: a warning is logged,
/// the model is left untouched and an empty vector is returned.
pub fn train(model: &mut QCT, tokens: &[usize], config: &TrainConfig) -> Vec<EpochMetrics> {
    use rayon::prelude::*;

    let mut metrics = Vec::new();
    if tokens.is_empty() {
        log::warn!("train called with an empty token slice; skipping training");
        return metrics;
    }

    let num_params = model.num_params();
    let window_len = config.context_length.saturating_add(1);
    // Largest start index that still leaves room for a full window (0 when
    // `tokens` is shorter than one window).
    let max_start = tokens.len().saturating_sub(window_len);

    for epoch in 0..config.num_epochs {
        let start = std::time::Instant::now();
        let lr = learning_rate(config, epoch);

        // `max_start` is itself a valid start, hence the modulus of
        // `max_start + 1`; this also covers `max_start == 0` without a branch.
        let window_start = epoch % (max_start + 1);
        let window_end = window_start.saturating_add(window_len).min(tokens.len());
        let window = &tokens[window_start..window_end];

        // `tokens` is non-empty, so `window` holds at least one token and the
        // subtraction below cannot underflow.
        let base_loss = model.loss(window);
        let (_, base_free_energy) = model.forward(&window[..window.len() - 1]);
        let all_params = model.all_params();

        let mut gradients: Vec<f32> = (0..num_params)
            .into_par_iter()
            .map_init(
                // One scratch model and parameter buffer per rayon split,
                // rather than a fresh clone of both for every parameter.
                || (model.clone(), all_params.clone()),
                |(local, probe), k| {
                    let original = probe[k];

                    probe[k] = original + SHIFT;
                    local.set_all_params(&*probe);
                    let loss_plus = local.loss(window);

                    probe[k] = original - SHIFT;
                    local.set_all_params(&*probe);
                    let loss_minus = local.loss(window);

                    // Restore the slot so the next index starts from the
                    // unshifted parameters again.
                    probe[k] = original;
                    (loss_plus - loss_minus) / 2.0
                },
            )
            .collect();

        // The reported norm is the pre-clipping value.
        let grad_norm: f32 = gradients.iter().map(|g| g * g).sum::<f32>().sqrt();
        if grad_norm > config.grad_clip && grad_norm > 0.0 {
            let scale = config.grad_clip / grad_norm;
            for g in &mut gradients {
                *g *= scale;
            }
        }

        let mut updated = all_params;
        for (param, grad) in updated.iter_mut().zip(&gradients) {
            *param -= lr * grad;
        }
        model.set_all_params(&updated);

        let elapsed = start.elapsed().as_secs_f32() * 1000.0;

        // Guard the modulus: a zero interval means "no periodic logging".
        let periodic = config.log_interval != 0 && epoch % config.log_interval == 0;
        let is_final = epoch + 1 == config.num_epochs;
        if periodic || is_final {
            let m = EpochMetrics {
                epoch,
                loss: base_loss,
                free_energy: base_free_energy,
                grad_norm,
                elapsed_ms: elapsed,
                learning_rate: lr,
                params_trained: num_params,
            };
            log::info!(
                "Epoch {:4}: loss={:.4} F={:.4} |∇|={:.6} lr={:.5} params={} ({:.0}ms)",
                m.epoch,
                m.loss,
                m.free_energy,
                m.grad_norm,
                m.learning_rate,
                m.params_trained,
                m.elapsed_ms
            );
            metrics.push(m);
        }
    }
    metrics
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::QCTConfig;

    /// Small model configuration shared by the tests in this module.
    fn tiny_config() -> QCTConfig {
        QCTConfig {
            vocab_size: 10,
            dim: 4,
            num_blocks: 1,
            seed: 42,
        }
    }

    /// Runs a few epochs and checks the loss is still finite afterwards.
    /// The initial and final losses are printed, not compared.
    #[test]
    fn training_reduces_loss() {
        let mut model = QCT::new(tiny_config());
        let tokens: Vec<usize> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5];
        let initial_loss = model.loss(&tokens[..8]);
        let train_config = TrainConfig {
            learning_rate: 0.05,
            num_epochs: 5,
            context_length: 8,
            log_interval: 5,
            ..Default::default()
        };
        let _metrics = train(&mut model, &tokens, &train_config);
        let final_loss = model.loss(&tokens[..8]);
        assert!(final_loss.is_finite(), "loss should be finite after training");
        eprintln!("Initial loss: {:.4}, Final loss: {:.4}", initial_loss, final_loss);
    }

    /// A single epoch must report a strictly positive gradient norm.
    #[test]
    fn gradient_is_nonzero() {
        let mut model = QCT::new(tiny_config());
        let tokens: Vec<usize> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let train_config = TrainConfig {
            learning_rate: 0.01,
            num_epochs: 1,
            context_length: 6,
            log_interval: 1,
            ..Default::default()
        };
        let metrics = train(&mut model, &tokens, &train_config);
        assert!(!metrics.is_empty());
        assert!(metrics[0].grad_norm > 0.0, "gradient should be nonzero");
    }

    /// The reported parameter count must equal the model's full parameter count.
    #[test]
    fn all_params_trained() {
        let mut model = QCT::new(tiny_config());
        let tokens: Vec<usize> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let train_config = TrainConfig {
            learning_rate: 0.01,
            num_epochs: 1,
            context_length: 6,
            log_interval: 1,
            ..Default::default()
        };
        let metrics = train(&mut model, &tokens, &train_config);
        assert_eq!(
            metrics[0].params_trained,
            model.num_params(),
            "should train ALL {} params, not a subset",
            model.num_params()
        );
    }

    /// Parameters written with `set_all_params` must read back unchanged.
    #[test]
    fn all_params_roundtrip() {
        let model = QCT::new(tiny_config());
        let params = model.all_params();
        let mut model2 = QCT::new(tiny_config());
        model2.set_all_params(&params);
        let params2 = model2.all_params();
        assert_eq!(params.len(), params2.len());
        for (a, b) in params.iter().zip(params2.iter()) {
            assert!((a - b).abs() < 1e-6, "param roundtrip mismatch");
        }
    }

    /// Cosine decay starts at the base rate, halves mid-way and ends near zero.
    #[test]
    fn cosine_lr_schedule() {
        let config = TrainConfig {
            learning_rate: 0.1,
            num_epochs: 100,
            use_cosine_decay: true,
            ..Default::default()
        };
        let lr_start = learning_rate(&config, 0);
        let lr_mid = learning_rate(&config, 50);
        let lr_end = learning_rate(&config, 99);
        assert!((lr_start - 0.1).abs() < 0.01, "start lr should be ~0.1");
        assert!((lr_mid - 0.05).abs() < 0.01, "mid lr should be ~0.05");
        assert!(lr_end < 0.01, "end lr should be near 0, got {lr_end}");
    }

    /// Warmup increases the rate each epoch until it reaches the base rate.
    #[test]
    fn warmup_lr_schedule() {
        let config = TrainConfig {
            learning_rate: 0.1,
            num_epochs: 100,
            warmup_epochs: 10,
            ..Default::default()
        };
        let lr_0 = learning_rate(&config, 0);
        let lr_5 = learning_rate(&config, 5);
        let lr_10 = learning_rate(&config, 10);
        assert!(lr_0 < lr_5, "lr should increase during warmup");
        assert!(lr_5 < lr_10, "lr should increase during warmup");
        assert!((lr_10 - 0.1).abs() < 0.01, "lr should reach base after warmup");
    }
}