use crate::error::{Error, Result};
use crate::model::{
DefaultMultiscreenModel, ModelTrainingConfig, ModelTrainingReport, MultiscreenModelConfig,
};
use crate::runtime::{Device, default_device};
use std::fs;
use std::path::Path;
pub use crate::model::MultiscreenParameterBudget as ParameterBudget;
#[derive(Clone, Debug)]
pub struct TrainingReport {
pub steps: usize,
pub final_loss: f32,
pub best_loss: f32,
pub best_loss_step: usize,
pub parameter_count: usize,
pub checkpoint_path: Option<String>,
}
impl TrainingReport {
fn from_model_report(report: &ModelTrainingReport, checkpoint_path: Option<String>) -> Self {
Self {
steps: report.steps,
final_loss: report.final_loss,
best_loss: report.best_loss,
best_loss_step: report.best_loss_step,
parameter_count: report.parameter_count,
checkpoint_path,
}
}
}
pub struct Trainer {
model: DefaultMultiscreenModel,
training_config: ModelTrainingConfig,
checkpoint_dir: Option<String>,
checkpoint_interval: usize,
#[allow(dead_code)]
run_dir: Option<String>,
}
impl std::fmt::Debug for Trainer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Trainer")
.field("training_config", &self.training_config)
.field("checkpoint_dir", &self.checkpoint_dir)
.field("checkpoint_interval", &self.checkpoint_interval)
.field("run_dir", &self.run_dir)
.finish_non_exhaustive()
}
}
impl Trainer {
pub fn builder() -> TrainerBuilder {
TrainerBuilder::new()
}
pub fn train_on_token_sequences_with_callback(
&mut self,
sequences: &[Vec<u32>],
on_step: impl FnMut(usize, f32),
) -> Result<TrainingReport> {
if sequences.is_empty() {
return Err(Error::Training("no training sequences provided".into()));
}
let mut config = self.training_config.clone();
config.checkpoint_dir = self.checkpoint_dir.clone();
config.checkpoint_interval = self.checkpoint_interval;
let device = self.model_device();
let report = self
.model
.train_token_sequences(sequences, &config, &device, on_step)?;
let checkpoint_path = match &self.checkpoint_dir {
Some(dir) => {
let dir_path = Path::new(dir);
fs::create_dir_all(dir_path).map_err(|e| {
Error::Io(format!(
"failed to create checkpoint directory {:?}: {}",
dir, e
))
})?;
let path = dir_path.join("checkpoint.mpk");
self.model.save_parameters(&path)?;
Some(path.to_string_lossy().into_owned())
}
None => None,
};
Ok(TrainingReport::from_model_report(&report, checkpoint_path))
}
pub fn train_on_token_sequences(&mut self, sequences: &[Vec<u32>]) -> Result<TrainingReport> {
self.train_on_token_sequences_with_callback(sequences, |_, _| {})
}
pub fn train_on_chat_sequences_with_callback(
&mut self,
chat_pairs: &[(Vec<u32>, Vec<u32>)],
on_step: impl FnMut(usize, f32),
) -> Result<TrainingReport> {
if chat_pairs.is_empty() {
return Err(Error::Training("no training chat pairs provided".into()));
}
let mut config = self.training_config.clone();
config.checkpoint_dir = self.checkpoint_dir.clone();
config.checkpoint_interval = self.checkpoint_interval;
let device = self.model_device();
let report = self
.model
.train_chat_sequences(chat_pairs, &config, &device, on_step)?;
let checkpoint_path = match &self.checkpoint_dir {
Some(dir) => {
let dir_path = Path::new(dir);
fs::create_dir_all(dir_path).map_err(|e| {
Error::Io(format!(
"failed to create checkpoint directory {:?}: {}",
dir, e
))
})?;
let path = dir_path.join("checkpoint.mpk");
self.model.save_parameters(&path)?;
Some(path.to_string_lossy().into_owned())
}
None => None,
};
Ok(TrainingReport::from_model_report(&report, checkpoint_path))
}
pub fn train_on_chat_sequences(
&mut self,
chat_pairs: &[(Vec<u32>, Vec<u32>)],
) -> Result<TrainingReport> {
self.train_on_chat_sequences_with_callback(chat_pairs, |_, _| {})
}
pub fn save_checkpoint(&self, path: &str) -> Result<()> {
if let Some(parent) = Path::new(path).parent() {
fs::create_dir_all(parent).map_err(|e| {
Error::Io(format!(
"failed to create checkpoint directory {:?}: {}",
parent, e
))
})?;
}
self.model.save_parameters(path)
}
pub fn model(&self) -> &DefaultMultiscreenModel {
&self.model
}
pub fn model_mut(&mut self) -> &mut DefaultMultiscreenModel {
&mut self.model
}
pub fn training_config(&self) -> &ModelTrainingConfig {
&self.training_config
}
fn model_device(&self) -> Device {
Device::default()
}
}
pub struct TrainerBuilder {
vocab_size: Option<usize>,
budget: ParameterBudget,
device: Option<Device>,
batch_size: usize,
seq_len: usize,
steps: usize,
learning_rate: f64,
weight_decay: f64,
grad_clip_norm: Option<f64>,
checkpoint_dir: Option<String>,
checkpoint_interval: usize,
run_dir: Option<String>,
}
impl TrainerBuilder {
fn new() -> Self {
Self {
vocab_size: None,
budget: ParameterBudget::Params10M,
device: None,
batch_size: 4,
seq_len: 128,
steps: 1000,
learning_rate: 2e-4,
weight_decay: 0.01,
grad_clip_norm: Some(1.0),
checkpoint_dir: None,
checkpoint_interval: 1000,
run_dir: None,
}
}
pub fn vocab_size(mut self, size: usize) -> Self {
self.vocab_size = Some(size);
self
}
pub fn budget(mut self, budget: ParameterBudget) -> Self {
self.budget = budget;
self
}
pub fn device(mut self, device: Device) -> Self {
self.device = Some(device);
self
}
pub fn batch_size(mut self, size: usize) -> Self {
self.batch_size = size;
self
}
pub fn seq_len(mut self, len: usize) -> Self {
self.seq_len = len;
self
}
pub fn steps(mut self, steps: usize) -> Self {
self.steps = steps;
self
}
pub fn learning_rate(mut self, lr: f64) -> Self {
self.learning_rate = lr;
self
}
pub fn weight_decay(mut self, wd: f64) -> Self {
self.weight_decay = wd;
self
}
pub fn grad_clip_norm(mut self, norm: Option<f64>) -> Self {
self.grad_clip_norm = norm;
self
}
pub fn checkpoint_dir(mut self, dir: impl Into<String>) -> Self {
self.checkpoint_dir = Some(dir.into());
self
}
pub fn checkpoint_interval(mut self, steps: usize) -> Self {
self.checkpoint_interval = steps;
self
}
pub fn run_dir(mut self, dir: impl Into<String>) -> Self {
self.run_dir = Some(dir.into());
self
}
pub fn build(self) -> Result<Trainer> {
let vocab_size = self.vocab_size.ok_or_else(|| {
Error::Config("vocab_size is required; call .vocab_size(n) before .build()".into())
})?;
let device = match self.device {
Some(d) => d,
None => default_device()?,
};
let config =
MultiscreenModelConfig::for_parameter_budget(self.budget, vocab_size, self.seq_len);
let model = DefaultMultiscreenModel::new(config, &device)?;
let training_config = ModelTrainingConfig {
steps: self.steps,
batch_size: self.batch_size,
learning_rate: self.learning_rate,
weight_decay: self.weight_decay,
grad_clip_norm: self.grad_clip_norm,
pad_token_id: 0,
checkpoint_dir: None, checkpoint_interval: 0,
};
let run_dir = self.run_dir.or_else(|| Some("runs/latest".to_string()));
Ok(Trainer {
model,
training_config,
checkpoint_dir: self.checkpoint_dir,
checkpoint_interval: self.checkpoint_interval,
run_dir,
})
}
}
impl Default for TrainerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `build` must refuse to construct a trainer until a vocabulary size is supplied, and
    /// the error should tell the caller which option is missing.
    #[test]
    fn builder_requires_vocab_size() {
        let result = Trainer::builder().build();
        assert!(result.is_err(), "build should fail without vocab_size");
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("vocab_size"),
            "error should mention vocab_size: {}",
            msg
        );
    }

    /// The public report copies every exposed field and attaches the checkpoint path.
    #[test]
    fn training_report_from_model_report() {
        let model_report = ModelTrainingReport {
            steps: 500,
            final_loss: 0.123,
            best_loss: 0.100,
            best_loss_step: 420,
            training_window_count: 100,
            parameter_count: 10_000_000,
        };
        let report =
            TrainingReport::from_model_report(&model_report, Some("runs/checkpoint.mpk".into()));
        assert_eq!(report.steps, 500);
        assert!((report.final_loss - 0.123).abs() < f32::EPSILON);
        assert!((report.best_loss - 0.100).abs() < f32::EPSILON);
        assert_eq!(report.best_loss_step, 420);
        assert_eq!(report.parameter_count, 10_000_000);
        assert_eq!(
            report.checkpoint_path.as_deref(),
            Some("runs/checkpoint.mpk")
        );
    }

    /// Pins every default chosen by `TrainerBuilder::new`, including the optional settings.
    #[test]
    fn builder_defaults() {
        let builder = TrainerBuilder::new();
        assert!(builder.vocab_size.is_none());
        assert!(matches!(builder.budget, ParameterBudget::Params10M));
        assert!(builder.device.is_none());
        assert_eq!(builder.batch_size, 4);
        assert_eq!(builder.seq_len, 128);
        assert_eq!(builder.steps, 1000);
        assert!((builder.learning_rate - 2e-4).abs() < f64::EPSILON);
        assert!((builder.weight_decay - 0.01).abs() < f64::EPSILON);
        assert!(matches!(
            builder.grad_clip_norm,
            Some(norm) if (norm - 1.0).abs() < f64::EPSILON
        ));
        assert!(builder.checkpoint_dir.is_none());
        assert_eq!(builder.checkpoint_interval, 1000);
        assert!(builder.run_dir.is_none());
    }

    /// Each fluent setter must store its argument in the corresponding builder field.
    #[test]
    fn builder_setters_store_values() {
        let builder = Trainer::builder()
            .vocab_size(256)
            .batch_size(8)
            .seq_len(64)
            .steps(10)
            .learning_rate(1e-3)
            .weight_decay(0.5)
            .grad_clip_norm(None)
            .checkpoint_dir("ckpt")
            .checkpoint_interval(5)
            .run_dir("runs/test");
        assert_eq!(builder.vocab_size, Some(256));
        assert_eq!(builder.batch_size, 8);
        assert_eq!(builder.seq_len, 64);
        assert_eq!(builder.steps, 10);
        assert!((builder.learning_rate - 1e-3).abs() < f64::EPSILON);
        assert!((builder.weight_decay - 0.5).abs() < f64::EPSILON);
        assert!(builder.grad_clip_norm.is_none());
        assert_eq!(builder.checkpoint_dir.as_deref(), Some("ckpt"));
        assert_eq!(builder.checkpoint_interval, 5);
        assert_eq!(builder.run_dir.as_deref(), Some("runs/test"));
    }
}