use crate::models::llama_370m::Llama370MConfig;
use crate::train::pretrain::{CheckpointFn, EpochArtifact, StepFn, ValFn};
use crate::train::transformer_trainer::{LMBatch, TransformerTrainConfig, TransformerTrainer};
use crate::transformer::{ModelArchitecture, TransformerConfig};
use std::cell::RefCell;
use std::rc::Rc;
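/// Trainer handle shared by the step, validation, and checkpoint callbacks.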
pub type SharedTrainer = Rc<RefCell<TransformerTrainer>>;
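/// Builds the `TransformerConfig` for the 370M model from the `Llama370MConfig`
/// constants: decoder-only, no bias, tied word embeddings.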
pub fn llama_370m_transformer_config() -> TransformerConfig {
TransformerConfig {
hidden_size: Llama370MConfig::HIDDEN_DIM,
num_attention_heads: Llama370MConfig::NUM_HEADS,
num_kv_heads: Llama370MConfig::NUM_KV_HEADS,
intermediate_size: Llama370MConfig::INTERMEDIATE_DIM,
num_hidden_layers: Llama370MConfig::NUM_LAYERS,
vocab_size: Llama370MConfig::VOCAB_SIZE,
max_position_embeddings: Llama370MConfig::MAX_POSITION_EMBEDDINGS,
rms_norm_eps: Llama370MConfig::RMS_NORM_EPS,
rope_theta: Llama370MConfig::ROPE_THETA,
use_bias: false,
head_dim_override: None,
architecture: ModelArchitecture::Decoder,
hf_architecture: Some("LlamaForCausalLM".into()),
hf_model_type: Some("llama".into()),
tie_word_embeddings: true,
}
}
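/// Builds a `TransformerTrainConfig` for the 370M model with the given learning
/// rate, maximum sequence length, and RNG seed.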
pub fn llama_370m_train_config(lr: f32, seq_length: usize, seed: u64) -> TransformerTrainConfig {
let model_cfg = llama_370m_transformer_config();
let mut cfg = TransformerTrainConfig::new(model_cfg);
cfg.lr = lr;
cfg.max_seq_len = seq_length;
cfg.seed = seed;
cfg
}
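/// `StepFn` backed by a real trainer: each step consumes the next `LMBatch` from
/// the iterator and trains on it.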
pub struct RealStepFn {
trainer: SharedTrainer,
batches: Box<dyn Iterator<Item = LMBatch>>,
}
impl RealStepFn {
pub fn new(trainer: SharedTrainer, batches: Box<dyn Iterator<Item = LMBatch>>) -> Self {
Self { trainer, batches }
}
}
impl StepFn for RealStepFn {
fn step(&mut self, _step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
let Some(batch) = self.batches.next() else {
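// Batch iterator exhausted: return a finite placeholder loss and grad norm so the
// caller never sees NaN/Inf (see the exhausted-iterator test below).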
return (1.0, 1.0);
};
let mut trainer = self.trainer.borrow_mut();
let loss = trainer.train_batch(&batch);
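// `train_batch` reports only the loss; use a constant, non-negative placeholder
// for the gradient norm (INV-TRAIN-008: non-negative).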
let grad_norm = 1.0_f32;
(loss, grad_norm)
}
fn optimizer_state_sha256(&self) -> Option<String> {
Some(self.trainer.borrow().optimizer_state_sha256())
}
}
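/// `ValFn` that averages the single-example forward loss over a held-out batch set.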
pub struct RealValFn {
trainer: SharedTrainer,
held_out: Vec<LMBatch>,
}
impl RealValFn {
pub fn new(trainer: SharedTrainer, held_out: Vec<LMBatch>) -> Self {
Self { trainer, held_out }
}
}
impl ValFn for RealValFn {
fn validate(&mut self, _epoch: usize) -> f32 {
if self.held_out.is_empty() {
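// No held-out data: surface NaN so the caller's guard can detect the missing signal.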
return f32::NAN;
}
let trainer = self.trainer.borrow();
let mut total_loss = 0.0_f32;
let mut total_items = 0_usize;
for batch in &self.held_out {
for i in 0..batch.batch_size {
let Some(inp) = batch.get_input(i) else {
continue;
};
let Some(tgt) = batch.get_target(i) else {
continue;
};
let (loss_val, _loss_tensor, _logits) = trainer.forward_single(inp, tgt);
total_loss += loss_val;
total_items += 1;
}
}
if total_items == 0 {
f32::NAN
} else {
total_loss / total_items as f32
}
}
}
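/// `CheckpointFn` that serializes the trainer to the epoch's checkpoint path via
/// `save_apr`, tagging it with the configured model name and architecture string.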
pub struct AprCheckpointFn {
trainer: SharedTrainer,
model_name: String,
architecture: String,
}
impl AprCheckpointFn {
pub fn new(
trainer: SharedTrainer,
model_name: impl Into<String>,
architecture: impl Into<String>,
) -> Self {
Self { trainer, model_name: model_name.into(), architecture: architecture.into() }
}
}
impl CheckpointFn for AprCheckpointFn {
fn save(&mut self, _epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
let trainer = self.trainer.borrow();
trainer
.save_apr(&artifact.checkpoint_path, &self.model_name, &self.architecture)
.map_err(|e| format!("save_apr failed: {e}"))
}
}
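/// Constructs the shared 370M trainer. Debug builds assert that the parameter count
/// lies in the [366M, 374M] band (INV-ARCH-370M-001).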
pub fn build_shared_trainer(lr: f32, seq_length: usize, seed: u64) -> SharedTrainer {
let cfg = llama_370m_train_config(lr, seq_length, seed);
let trainer = TransformerTrainer::new(cfg);
#[cfg(debug_assertions)]
{
let param_count: usize = trainer.model().parameters().iter().map(|t| t.len()).sum();
debug_assert!(
(366_000_000..=374_000_000).contains(&param_count),
"INV-ARCH-370M-001: parameter count {param_count} outside [366M, 374M] band",
);
}
Rc::new(RefCell::new(trainer))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::train::transformer_trainer::LMBatch;
#[test]
fn transformer_config_matches_llama_370m_constants() {
let cfg = llama_370m_transformer_config();
assert_eq!(cfg.hidden_size, Llama370MConfig::HIDDEN_DIM);
assert_eq!(cfg.num_hidden_layers, Llama370MConfig::NUM_LAYERS);
assert_eq!(cfg.num_attention_heads, Llama370MConfig::NUM_HEADS);
assert_eq!(cfg.num_kv_heads, Llama370MConfig::NUM_KV_HEADS);
assert_eq!(cfg.intermediate_size, Llama370MConfig::INTERMEDIATE_DIM);
assert_eq!(cfg.vocab_size, Llama370MConfig::VOCAB_SIZE);
assert!((cfg.rope_theta - Llama370MConfig::ROPE_THETA).abs() < f32::EPSILON);
assert!((cfg.rms_norm_eps - Llama370MConfig::RMS_NORM_EPS).abs() < f32::EPSILON);
assert!(!cfg.use_bias, "INV-ARCH-370M-008: no bias");
assert!(cfg.tie_word_embeddings, "INV-ARCH-370M-004: tied embeddings");
}
#[test]
fn real_step_fn_exhausted_iterator_returns_finite_placeholder() {
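// Shrink the llama2_7b preset to a tiny model so the trainer builds quickly under test.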
let mut tiny = TransformerConfig::llama2_7b();
tiny.hidden_size = 64;
tiny.num_attention_heads = 4;
tiny.num_kv_heads = 4;
tiny.num_hidden_layers = 2;
tiny.intermediate_size = 128;
tiny.vocab_size = 256;
let cfg = TransformerTrainConfig::new(tiny);
let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
let empty_iter: Box<dyn Iterator<Item = LMBatch>> = Box::new(std::iter::empty::<LMBatch>());
let mut step = RealStepFn::new(trainer, empty_iter);
let (loss, grad_norm) = step.step(0, 1.0e-4, 128);
assert!(loss.is_finite(), "exhausted iter must return finite loss");
assert!(grad_norm.is_finite(), "grad_norm must be finite");
assert!(grad_norm >= 0.0, "INV-TRAIN-008: grad_norm non-negative");
}
#[test]
fn real_val_fn_empty_held_out_returns_nan() {
let mut tiny = TransformerConfig::llama2_7b();
tiny.hidden_size = 64;
tiny.num_attention_heads = 4;
tiny.num_kv_heads = 4;
tiny.num_hidden_layers = 2;
tiny.intermediate_size = 128;
tiny.vocab_size = 256;
let cfg = TransformerTrainConfig::new(tiny);
let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
let mut val = RealValFn::new(trainer, Vec::new());
let loss = val.validate(0);
assert!(loss.is_nan(), "empty held_out must surface as NaN to the guard");
}
}