#![allow(dead_code)]
use std::path::{Path, PathBuf};
use std::time::Instant;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use serde::{Deserialize, Serialize};
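/// Abort conditions that terminate a pretraining run. Each variant carries
/// enough context to reconstruct the failing gate in logs and tests.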
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum PretrainAbort {
Divergence { epoch: usize, prev_val_loss: f32, curr_val_loss: f32, ratio: f32 },
DivergenceAtEpochZero { val_loss: f32 },
NumericalInstability { step: u64, field: &'static str, value: f32 },
ThroughputOutOfRange { step: u64, field: &'static str, value: f32 },
}
impl std::fmt::Display for PretrainAbort {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Divergence { epoch, prev_val_loss, curr_val_loss, ratio } => write!(
f,
"DIVERGENCE at epoch {epoch}: val_loss {curr_val_loss:.4} > 2.0 × {prev_val_loss:.4} (ratio {ratio:.2})",
),
Self::DivergenceAtEpochZero { val_loss } => write!(
f,
"DIVERGENCE at epoch 0: val_loss {val_loss} is non-finite or > 10.0",
),
Self::NumericalInstability { step, field, value } => write!(
f,
"NUMERICAL_INSTABILITY at step {step}: {field} = {value} is non-finite",
),
Self::ThroughputOutOfRange { step, field, value } => write!(
f,
"THROUGHPUT_OUT_OF_RANGE at step {step}: {field} = {value} outside permitted range",
),
}
}
}
impl std::error::Error for PretrainAbort {}
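/// Per-step training telemetry captured by `PretrainLoop::train_step`.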
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StepMetrics {
pub step: u64,
pub train_loss: f32,
pub grad_norm: f32,
pub lr: f32,
pub tokens_per_sec: f32,
pub gpu_util_pct: f32,
}
impl StepMetrics {
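/// Rejects non-finite loss, gradient-norm, or learning-rate values and
/// out-of-range throughput or GPU utilization, mapping each failure to the
/// matching `PretrainAbort` variant.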
pub fn validate_finite(&self) -> Result<(), PretrainAbort> {
if !self.train_loss.is_finite() {
return Err(PretrainAbort::NumericalInstability {
step: self.step,
field: "train_loss",
value: self.train_loss,
});
}
if !self.grad_norm.is_finite() {
return Err(PretrainAbort::NumericalInstability {
step: self.step,
field: "grad_norm",
value: self.grad_norm,
});
}
if !self.lr.is_finite() {
return Err(PretrainAbort::NumericalInstability {
step: self.step,
field: "lr",
value: self.lr,
});
}
if !self.tokens_per_sec.is_finite() || self.tokens_per_sec < 0.0 {
return Err(PretrainAbort::ThroughputOutOfRange {
step: self.step,
field: "tokens_per_sec",
value: self.tokens_per_sec,
});
}
if !self.gpu_util_pct.is_finite() || self.gpu_util_pct < 0.0 || self.gpu_util_pct > 100.0 {
return Err(PretrainAbort::ThroughputOutOfRange {
step: self.step,
field: "gpu_util_pct",
value: self.gpu_util_pct,
});
}
Ok(())
}
}
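/// Summary statistics persisted alongside each epoch checkpoint.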
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EpochMetadata {
pub epoch: usize,
pub train_loss: f32,
pub val_loss: f32,
pub train_ppl: f32,
pub val_ppl: f32,
pub optimizer_state_sha: String,
pub wall_seconds: f32,
pub tokens_seen: u64,
pub grad_norm_max: f32,
}
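/// Paths and metadata for one epoch's checkpoint pair, laid out as
/// `<run_dir>/ckpt/epoch-NNN.apr` plus `epoch-NNN.metadata.json`.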
#[derive(Debug, Clone)]
pub struct EpochArtifact {
pub checkpoint_path: PathBuf,
pub metadata_path: PathBuf,
pub metadata: EpochMetadata,
}
impl EpochArtifact {
pub fn new(run_dir: &Path, epoch: usize, metadata: EpochMetadata) -> Self {
let ckpt_dir = run_dir.join("ckpt");
let filename = format!("epoch-{epoch:03}.apr");
let metafile = format!("epoch-{epoch:03}.metadata.json");
Self {
checkpoint_path: ckpt_dir.join(filename),
metadata_path: ckpt_dir.join(metafile),
metadata,
}
}
}
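/// An epoch diverges when its validation loss exceeds this multiple of the
/// previous epoch's validation loss.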
pub const DIVERGENCE_RATIO_LIMIT: f32 = 2.0;
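/// Epoch 0 has no baseline, so it is gated by this absolute loss ceiling.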
pub const EPOCH_ZERO_VAL_LOSS_LIMIT: f32 = 10.0;
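/// Divergence gate (GATE-TRAIN-005 in the tests): epoch 0 must be finite and
/// at most `EPOCH_ZERO_VAL_LOSS_LIMIT`; later epochs must not exceed
/// `DIVERGENCE_RATIO_LIMIT` times the previous epoch's validation loss.
/// Exactly 2x the previous loss is allowed; only a strict excess aborts.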
pub fn check_non_divergence(epoch: usize, val_loss_history: &[f32]) -> Result<(), PretrainAbort> {
let Some(&curr) = val_loss_history.get(epoch) else {
return Ok(());
};
if epoch == 0 {
if !curr.is_finite() || curr > EPOCH_ZERO_VAL_LOSS_LIMIT {
return Err(PretrainAbort::DivergenceAtEpochZero { val_loss: curr });
}
return Ok(());
}
let prev = val_loss_history[epoch - 1];
if !curr.is_finite() {
return Err(PretrainAbort::NumericalInstability {
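// u64::MAX is a sentinel: validation loss has no per-step index.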
step: u64::MAX,
field: "val_loss",
value: curr,
});
}
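// prev.max(1e-9) guards the diagnostic ratio against division by zero;
// the abort condition itself compares against the raw prev.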
let ratio = curr / prev.max(1e-9);
if curr > DIVERGENCE_RATIO_LIMIT * prev {
return Err(PretrainAbort::Divergence {
epoch,
prev_val_loss: prev,
curr_val_loss: curr,
ratio,
});
}
Ok(())
}
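/// Numerical-stability gate (INV-TRAIN-007 in the tests): any NaN or infinite
/// train loss or gradient norm aborts immediately at the offending step.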
pub fn check_numerical_stability(
step: u64,
train_loss: f32,
grad_norm: f32,
) -> Result<(), PretrainAbort> {
if !train_loss.is_finite() {
return Err(PretrainAbort::NumericalInstability {
step,
field: "train_loss",
value: train_loss,
});
}
if !grad_norm.is_finite() {
return Err(PretrainAbort::NumericalInstability {
step,
field: "grad_norm",
value: grad_norm,
});
}
Ok(())
}
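/// Hyperparameters and filesystem layout for one pretraining run.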
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PretrainConfig {
pub dataset_path: PathBuf,
pub tokenizer_dir: PathBuf,
pub run_dir: PathBuf,
pub lr_max: f32,
pub lr_min: f32,
pub warmup_steps: usize,
pub total_steps: usize,
pub batch_size: usize,
pub seq_length: usize,
pub steps_per_epoch: usize,
pub seed: u64,
pub grad_clip: f32,
pub weight_decay: f32,
pub target_val_loss: f32,
pub patience_epochs: usize,
pub min_epochs_before_early_stop: usize,
}
impl PretrainConfig {
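/// Baseline hyperparameters for the "model 2" run: a 100-step warmup into a
/// cosine decay over 1000 total steps, batches of 16 x 1024 tokens, seed 42.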
pub fn model_2_defaults(
dataset_path: PathBuf,
tokenizer_dir: PathBuf,
run_dir: PathBuf,
) -> Self {
Self {
dataset_path,
tokenizer_dir,
run_dir,
lr_max: 5.0e-5,
lr_min: 1.0e-6,
warmup_steps: 100,
total_steps: 1000,
batch_size: 16,
seq_length: 1024,
steps_per_epoch: 100,
seed: 42,
grad_clip: 1.0,
weight_decay: 0.01,
target_val_loss: 2.2,
patience_epochs: 2,
min_epochs_before_early_stop: 3,
}
}
}
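/// Terminal outcome of `PretrainLoop::run`.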
#[derive(Debug, Clone, Serialize)]
pub enum RunStatus {
Ok { final_val_loss: f32, epochs_completed: usize },
EarlyStop { best_val_loss: f32, epochs_completed: usize },
Aborted(PretrainAbort),
}
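/// Drives the pretraining loop end to end: per-step metrics plus the
/// numerical and throughput gates, per-epoch validation with the divergence
/// gate, and optional checkpointing.
///
/// A minimal wiring sketch using the synthetic helpers defined below
/// (illustrative paths and values, not recommendations):
///
/// ```ignore
/// let cfg = PretrainConfig::model_2_defaults(
///     "data.jsonl".into(), "tok".into(), "run".into());
/// let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.01, grad_norm: 0.8 };
/// let val_fn = ScriptedVal { sequence: vec![3.0, 2.6, 2.2] };
/// let status = PretrainLoop::new(cfg, step_fn, val_fn).run();
/// ```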
pub struct PretrainLoop<S: StepFn, V: ValFn> {
config: PretrainConfig,
rng: StdRng,
step_metrics: Vec<StepMetrics>,
epoch_artifacts: Vec<EpochArtifact>,
val_loss_history: Vec<f32>,
tokens_seen: u64,
best_val_loss: f32,
patience_counter: usize,
step_fn: S,
val_fn: V,
checkpoint_fn: Option<Box<dyn CheckpointFn>>,
}
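/// One optimizer step. Returns `(train_loss, grad_norm)`; implementations may
/// expose a SHA-256 of their optimizer state for checkpoint metadata.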
pub trait StepFn {
fn step(&mut self, step: u64, lr: f32, batch_tokens: u64) -> (f32, f32);
fn optimizer_state_sha256(&self) -> Option<String> {
None
}
}
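/// Produces the validation loss for a completed epoch.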
pub trait ValFn {
fn validate(&mut self, epoch: usize) -> f32;
}
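/// Persists an epoch checkpoint; the loop writes the companion metadata JSON
/// only after this hook succeeds.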
pub trait CheckpointFn {
fn save(&mut self, epoch: usize, artifact: &EpochArtifact) -> Result<(), String>;
}
impl<S: StepFn, V: ValFn> PretrainLoop<S, V> {
pub fn new(config: PretrainConfig, step_fn: S, val_fn: V) -> Self {
let rng = StdRng::seed_from_u64(config.seed);
Self {
config,
rng,
step_metrics: Vec::new(),
epoch_artifacts: Vec::new(),
val_loss_history: Vec::new(),
tokens_seen: 0,
best_val_loss: f32::INFINITY,
patience_counter: 0,
step_fn,
val_fn,
checkpoint_fn: None,
}
}
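/// Builder-style setter for the checkpoint hook.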
#[must_use]
pub fn with_checkpoint_fn(mut self, ckpt: Box<dyn CheckpointFn>) -> Self {
self.checkpoint_fn = Some(ckpt);
self
}
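/// Warmup-then-cosine schedule: a linear ramp from 0 to `lr_max` over
/// `warmup_steps`, then cosine decay to `lr_min` following
/// `lr = lr_min + (lr_max - lr_min) * 0.5 * (1 + cos(pi * progress))`,
/// holding `lr_min` for any steps at or past `total_steps`.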
fn lr_at(&self, step: u64) -> f32 {
let step = step as usize;
let w = self.config.warmup_steps;
let total = self.config.total_steps;
let lr_max = self.config.lr_max;
let lr_min = self.config.lr_min;
// Linear warmup toward lr_max. When warmup_steps == 0 this branch is
// never entered (step < 0 is impossible), so no zero-division guard is
// needed.
if step < w {
return lr_max * (step as f32 / w as f32);
}
let decay_steps = total.saturating_sub(w);
if decay_steps == 0 {
return lr_min;
}
let decay_step = step - w;
if decay_step >= decay_steps {
return lr_min;
}
let progress = decay_step as f32 / decay_steps as f32;
let cosine_decay = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
lr_min + (lr_max - lr_min) * cosine_decay
}
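/// Runs one step at the scheduled learning rate, checks the numerical and
/// throughput gates, and records the resulting metrics on success.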
pub fn train_step(&mut self, step: u64) -> Result<StepMetrics, PretrainAbort> {
let lr = self.lr_at(step);
let batch_tokens = (self.config.batch_size * self.config.seq_length) as u64;
let t0 = Instant::now();
let (train_loss, grad_norm) = self.step_fn.step(step, lr, batch_tokens);
let elapsed = t0.elapsed().as_secs_f32().max(1.0e-9);
check_numerical_stability(step, train_loss, grad_norm)?;
let tokens_per_sec = batch_tokens as f32 / elapsed;
// Synthetic GPU telemetry: seeded jitter around 50% keeps runs
// reproducible while still exercising the 0..=100 range check.
let gpu_util_pct = 50.0 + self.rng.random_range(-5.0_f32..5.0);
let metrics = StepMetrics {
step,
train_loss,
grad_norm,
lr,
tokens_per_sec,
gpu_util_pct: gpu_util_pct.clamp(0.0, 100.0),
};
metrics.validate_finite()?;
self.tokens_seen += batch_tokens;
self.step_metrics.push(metrics.clone());
Ok(metrics)
}
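/// Runs one fixed-width epoch, validates, applies the divergence gate, then
/// (if configured) checkpoints and writes metadata JSON. Checkpoint and
/// metadata I/O failures are logged to stderr but do not abort training.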
pub fn run_epoch(&mut self, epoch: usize) -> Result<EpochArtifact, PretrainAbort> {
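// Epochs are fixed windows of steps_per_epoch steps; when total_steps
// is not a multiple, run()'s div_ceil means the final epoch overshoots
// and lr_at() simply holds lr_min for the extra steps.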
let first_step = (epoch * self.config.steps_per_epoch) as u64;
let last_step = first_step + self.config.steps_per_epoch as u64;
let t0 = Instant::now();
let mut epoch_loss_sum = 0.0_f32;
let mut epoch_grad_norm_max = 0.0_f32;
let mut steps_taken = 0_u32;
for step in first_step..last_step {
let m = self.train_step(step)?;
epoch_loss_sum += m.train_loss;
if m.grad_norm > epoch_grad_norm_max {
epoch_grad_norm_max = m.grad_norm;
}
steps_taken += 1;
}
let mean_train_loss = epoch_loss_sum / steps_taken.max(1) as f32;
let val_loss = self.val_fn.validate(epoch);
if !val_loss.is_finite() {
return Err(PretrainAbort::NumericalInstability {
step: last_step,
field: "val_loss",
value: val_loss,
});
}
self.val_loss_history.push(val_loss);
check_non_divergence(epoch, &self.val_loss_history)?;
let wall_seconds = t0.elapsed().as_secs_f32();
let optimizer_state_sha =
self.step_fn.optimizer_state_sha256().unwrap_or_else(|| self.fake_optimizer_sha(epoch));
let metadata = EpochMetadata {
epoch,
train_loss: mean_train_loss,
val_loss,
train_ppl: mean_train_loss.exp(),
val_ppl: val_loss.exp(),
optimizer_state_sha,
wall_seconds,
tokens_seen: self.tokens_seen,
grad_norm_max: epoch_grad_norm_max,
};
let artifact = EpochArtifact::new(&self.config.run_dir, epoch, metadata);
if let Some(ckpt) = self.checkpoint_fn.as_mut() {
if let Some(parent) = artifact.checkpoint_path.parent() {
let _ = std::fs::create_dir_all(parent);
}
if let Err(e) = ckpt.save(epoch, &artifact) {
eprintln!("[pretrain] checkpoint write failed for epoch {}: {}", epoch, e);
} else {
match serde_json::to_string_pretty(&artifact.metadata) {
Ok(json) => {
if let Err(e) = std::fs::write(&artifact.metadata_path, json) {
eprintln!(
"[pretrain] metadata write failed for epoch {}: {}",
epoch, e
);
}
}
Err(e) => eprintln!(
"[pretrain] metadata serialization failed for epoch {}: {}",
epoch, e
),
}
}
}
self.epoch_artifacts.push(artifact.clone());
Ok(artifact)
}
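/// Early-stop test: returns true once validation loss has failed to improve
/// for more than `patience_epochs` consecutive epochs, but never before
/// `min_epochs_before_early_stop` epochs have completed.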
pub fn check_convergence(&mut self, epoch: usize) -> bool {
let Some(&val_loss) = self.val_loss_history.last() else {
return false;
};
if val_loss < self.best_val_loss {
self.best_val_loss = val_loss;
self.patience_counter = 0;
return false;
}
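// Patience accrues even before the early-stop window opens; the
// min-epochs gate below only suppresses the stop decision itself.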
self.patience_counter += 1;
if epoch + 1 < self.config.min_epochs_before_early_stop {
return false;
}
self.patience_counter > self.config.patience_epochs
}
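/// Executes epochs until abort, early stop, target validation loss (after
/// the minimum epoch count), or the step budget is exhausted.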
pub fn run(&mut self) -> RunStatus {
let num_epochs = self.config.total_steps.div_ceil(self.config.steps_per_epoch.max(1));
for epoch in 0..num_epochs {
match self.run_epoch(epoch) {
Ok(_) => {}
Err(abort) => return RunStatus::Aborted(abort),
}
if self.check_convergence(epoch) {
return RunStatus::EarlyStop {
best_val_loss: self.best_val_loss,
epochs_completed: epoch + 1,
};
}
let last = *self.val_loss_history.last().unwrap_or(&f32::INFINITY);
if last <= self.config.target_val_loss
&& epoch + 1 >= self.config.min_epochs_before_early_stop
{
return RunStatus::Ok { final_val_loss: last, epochs_completed: epoch + 1 };
}
}
let last = *self.val_loss_history.last().unwrap_or(&f32::INFINITY);
RunStatus::Ok { final_val_loss: last, epochs_completed: num_epochs }
}
pub fn step_metrics(&self) -> &[StepMetrics] {
&self.step_metrics
}
pub fn epoch_artifacts(&self) -> &[EpochArtifact] {
&self.epoch_artifacts
}
pub fn val_loss_history(&self) -> &[f32] {
&self.val_loss_history
}
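/// Deterministic stand-in fingerprint derived from seed, epoch, and token
/// count, used when the `StepFn` does not expose real optimizer state.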
fn fake_optimizer_sha(&self, epoch: usize) -> String {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(b"aprender-train:pretrain:optstate:v1:");
hasher.update(self.config.seed.to_le_bytes());
hasher.update((epoch as u64).to_le_bytes());
hasher.update(self.tokens_seen.to_le_bytes());
format!("{:x}", hasher.finalize())
}
}
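/// Synthetic `StepFn` whose loss decays linearly per step, floored at 1e-4.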
pub struct LinearDecaySynthetic {
pub start_loss: f32,
pub decay_per_step: f32,
pub grad_norm: f32,
}
impl StepFn for LinearDecaySynthetic {
fn step(&mut self, step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
let loss = (self.start_loss - self.decay_per_step * step as f32).max(1.0e-4);
(loss, self.grad_norm)
}
}
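/// Synthetic `ValFn` that replays a scripted per-epoch loss sequence and
/// returns NaN once the script is exhausted, which trips the loop's
/// finiteness check on validation loss.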
pub struct ScriptedVal {
pub sequence: Vec<f32>,
}
impl ValFn for ScriptedVal {
fn validate(&mut self, epoch: usize) -> f32 {
*self.sequence.get(epoch).unwrap_or(&f32::NAN)
}
}
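/// Synthetic `StepFn` that injects a NaN train loss at exactly `nan_step`
/// and returns a healthy `(1.0, 1.0)` otherwise.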
pub struct NanAtStepSynthetic {
pub nan_step: u64,
}
impl StepFn for NanAtStepSynthetic {
fn step(&mut self, step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
if step == self.nan_step {
return (f32::NAN, 1.0);
}
(1.0, 1.0)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::cell::RefCell;
use std::rc::Rc;
use tempfile::TempDir;
fn test_config(tmp: &Path) -> PretrainConfig {
PretrainConfig {
dataset_path: tmp.join("data.jsonl"),
tokenizer_dir: tmp.join("tok"),
run_dir: tmp.join("run"),
lr_max: 1.0e-4,
lr_min: 1.0e-6,
warmup_steps: 2,
total_steps: 25,
batch_size: 2,
seq_length: 4,
steps_per_epoch: 5,
seed: 42,
grad_clip: 1.0,
weight_decay: 0.01,
target_val_loss: 2.2,
patience_epochs: 2,
min_epochs_before_early_stop: 1,
}
}
#[test]
fn gate_train_005_aborts_on_doubling_val_loss() {
let trace = vec![3.5, 7.1];
let res = check_non_divergence(1, &trace);
match res {
Err(PretrainAbort::Divergence { epoch, prev_val_loss, curr_val_loss, ratio }) => {
assert_eq!(epoch, 1);
assert!((prev_val_loss - 3.5).abs() < 1e-6);
assert!((curr_val_loss - 7.1).abs() < 1e-6);
assert!(ratio > 2.0);
}
other => panic!("GATE-TRAIN-005 did not abort: got {other:?}"),
}
}
#[test]
fn gate_train_005_aborts_on_epoch_zero_blowup() {
let trace = vec![31.99];
let res = check_non_divergence(0, &trace);
match res {
Err(PretrainAbort::DivergenceAtEpochZero { val_loss }) => {
assert!((val_loss - 31.99).abs() < 1e-4);
}
other => panic!("epoch-0 guard missed: got {other:?}"),
}
}
#[test]
fn gate_train_005_allows_healthy_decrease() {
let trace = vec![3.5, 3.0, 2.5, 2.2];
for epoch in 0..trace.len() {
assert!(check_non_divergence(epoch, &trace).is_ok());
}
}
#[test]
fn gate_train_005_allows_exact_two_x() {
let trace = vec![2.0, 4.0];
assert!(check_non_divergence(1, &trace).is_ok());
}
#[test]
fn gate_train_007_aborts_on_nan_train_loss() {
let res = check_numerical_stability(42, f32::NAN, 1.0);
match res {
Err(PretrainAbort::NumericalInstability { step, field, .. }) => {
assert_eq!(step, 42);
assert_eq!(field, "train_loss");
}
other => panic!("nan guard missed: got {other:?}"),
}
}
#[test]
fn gate_train_007_aborts_on_inf_grad_norm() {
let res = check_numerical_stability(7, 1.0, f32::INFINITY);
assert!(matches!(res, Err(PretrainAbort::NumericalInstability { .. })));
}
#[test]
fn step_metrics_validate_finite_accepts_healthy() {
let m = StepMetrics {
step: 0,
train_loss: 3.2,
grad_norm: 0.5,
lr: 1e-4,
tokens_per_sec: 1000.0,
gpu_util_pct: 75.0,
};
assert!(m.validate_finite().is_ok());
}
#[test]
fn step_metrics_rejects_negative_throughput() {
let m = StepMetrics {
step: 1,
train_loss: 3.2,
grad_norm: 0.5,
lr: 1e-4,
tokens_per_sec: -1.0,
gpu_util_pct: 75.0,
};
assert!(matches!(m.validate_finite(), Err(PretrainAbort::ThroughputOutOfRange { .. })));
}
#[test]
fn step_metrics_rejects_gpu_util_over_100() {
let m = StepMetrics {
step: 1,
train_loss: 3.2,
grad_norm: 0.5,
lr: 1e-4,
tokens_per_sec: 1000.0,
gpu_util_pct: 150.0,
};
assert!(matches!(m.validate_finite(), Err(PretrainAbort::ThroughputOutOfRange { .. })));
}
#[test]
fn pretrain_loop_happy_path_decreasing_loss() {
let tmp = TempDir::new().expect("tempdir");
let cfg = test_config(tmp.path());
let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
let status = loop_.run();
match status {
RunStatus::Ok { final_val_loss, epochs_completed } => {
assert!(final_val_loss <= 2.2);
assert!(epochs_completed >= 1);
}
other => panic!("healthy run did not converge cleanly: {other:?}"),
}
assert!(!loop_.step_metrics().is_empty());
for m in loop_.step_metrics() {
assert!(m.train_loss.is_finite());
assert!(m.grad_norm.is_finite());
assert!(m.lr.is_finite());
assert!(m.tokens_per_sec >= 0.0);
assert!((0.0..=100.0).contains(&m.gpu_util_pct));
}
assert_eq!(loop_.epoch_artifacts().len(), loop_.val_loss_history().len());
for art in loop_.epoch_artifacts() {
assert!(!art.metadata.optimizer_state_sha.is_empty());
assert!(art.metadata.train_ppl.is_finite());
assert!(art.metadata.val_ppl.is_finite());
}
}
#[test]
fn pretrain_loop_aborts_on_doubling_val_loss() {
let tmp = TempDir::new().expect("tempdir");
let cfg = test_config(tmp.path());
let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
let val_fn = ScriptedVal { sequence: vec![3.5, 7.1, 2.0] };
let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
let status = loop_.run();
match status {
RunStatus::Aborted(PretrainAbort::Divergence { epoch, ratio, .. }) => {
assert_eq!(epoch, 1);
assert!(ratio > 2.0);
}
other => panic!("GATE-TRAIN-005 did not fire: {other:?}"),
}
}
#[test]
fn pretrain_loop_aborts_on_nan_in_train_loss() {
let tmp = TempDir::new().expect("tempdir");
let cfg = test_config(tmp.path());
let step_fn = NanAtStepSynthetic { nan_step: 3 };
let val_fn = ScriptedVal { sequence: vec![3.0] };
let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
let status = loop_.run();
match status {
RunStatus::Aborted(PretrainAbort::NumericalInstability { step, field, .. }) => {
assert_eq!(step, 3);
assert_eq!(field, "train_loss");
}
other => panic!("INV-TRAIN-007 did not fire: {other:?}"),
}
}
#[test]
fn pretrain_loop_reproducibility_seed_42() {
let tmp1 = TempDir::new().expect("tempdir1");
let tmp2 = TempDir::new().expect("tempdir2");
let cfg1 = test_config(tmp1.path());
let cfg2 = test_config(tmp2.path());
let step_fn1 =
LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
let step_fn2 =
LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
let val_fn1 = ScriptedVal { sequence: vec![3.0, 2.8, 2.6, 2.4, 2.2] };
let val_fn2 = ScriptedVal { sequence: vec![3.0, 2.8, 2.6, 2.4, 2.2] };
let mut loop1 = PretrainLoop::new(cfg1, step_fn1, val_fn1);
let mut loop2 = PretrainLoop::new(cfg2, step_fn2, val_fn2);
let _ = loop1.run();
let _ = loop2.run();
assert_eq!(loop1.step_metrics().len(), loop2.step_metrics().len());
for (a, b) in loop1.step_metrics().iter().zip(loop2.step_metrics().iter()) {
assert_eq!(a.step, b.step);
assert!((a.train_loss - b.train_loss).abs() < 1e-6);
assert!((a.grad_norm - b.grad_norm).abs() < 1e-6);
assert!((a.lr - b.lr).abs() < 1e-6);
assert!((a.gpu_util_pct - b.gpu_util_pct).abs() < 1e-6);
}
}
#[test]
fn lr_schedule_warmup_cosine_boundaries() {
let tmp = TempDir::new().expect("tempdir");
let cfg = PretrainConfig {
warmup_steps: 10,
total_steps: 100,
lr_max: 1.0e-3,
lr_min: 1.0e-5,
..test_config(tmp.path())
};
let step_fn = LinearDecaySynthetic { start_loss: 1.0, decay_per_step: 0.0, grad_norm: 0.1 };
let val_fn = ScriptedVal { sequence: vec![1.0] };
let loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
assert!((loop_.lr_at(0) - 0.0).abs() < 1e-9);
assert!((loop_.lr_at(10) - 1.0e-3).abs() < 1e-6);
assert!((loop_.lr_at(100) - 1.0e-5).abs() < 1e-6);
}
#[test]
fn epoch_artifact_paths_match_contract_template() {
let tmp = TempDir::new().expect("tempdir");
let run_dir = tmp.path().join("run");
let metadata = EpochMetadata {
epoch: 7,
train_loss: 3.0,
val_loss: 2.8,
train_ppl: 20.0,
val_ppl: 16.4,
optimizer_state_sha: "deadbeef".into(),
wall_seconds: 42.0,
tokens_seen: 1_000_000,
grad_norm_max: 1.5,
};
let art = EpochArtifact::new(&run_dir, 7, metadata);
assert!(art.checkpoint_path.ends_with("ckpt/epoch-007.apr"));
assert!(art.metadata_path.ends_with("ckpt/epoch-007.metadata.json"));
}
struct RecordingCheckpointFn {
calls: Rc<RefCell<Vec<(usize, PathBuf)>>>,
}
impl CheckpointFn for RecordingCheckpointFn {
fn save(&mut self, epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
self.calls.borrow_mut().push((epoch, artifact.checkpoint_path.clone()));
Ok(())
}
}
#[test]
fn pretrain_loop_calls_checkpoint_fn_once_per_passing_epoch() {
let tmp = TempDir::new().expect("tempdir");
let cfg = test_config(tmp.path());
let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
let calls: Rc<RefCell<Vec<(usize, PathBuf)>>> = Rc::new(RefCell::new(Vec::new()));
let ckpt = RecordingCheckpointFn { calls: Rc::clone(&calls) };
let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn).with_checkpoint_fn(Box::new(ckpt));
let _status = loop_.run();
let recorded = calls.borrow();
let epoch_count = loop_.epoch_artifacts().len();
assert!(epoch_count >= 1, "at least one epoch should have completed");
assert_eq!(
recorded.len(),
epoch_count,
"CheckpointFn must fire exactly once per epoch that passes GATE-TRAIN-005",
);
for (i, (epoch, path)) in recorded.iter().enumerate() {
assert_eq!(*epoch, i, "checkpoint hook epoch indices must be monotonic from 0");
assert!(
path.to_string_lossy().contains(&format!("epoch-{:03}.apr", epoch)),
"checkpoint path must match contract template: {:?}",
path,
);
let meta_path = path.with_extension("metadata.json");
assert!(
meta_path.exists(),
"companion metadata.json must be written for epoch {}",
epoch,
);
}
}
#[test]
fn pretrain_loop_uses_step_fn_optimizer_sha_when_available() {
struct ShaOverride {
inner: LinearDecaySynthetic,
sha: String,
}
impl StepFn for ShaOverride {
fn step(&mut self, s: u64, lr: f32, tokens: u64) -> (f32, f32) {
self.inner.step(s, lr, tokens)
}
fn optimizer_state_sha256(&self) -> Option<String> {
Some(self.sha.clone())
}
}
let tmp = TempDir::new().expect("tempdir");
let cfg = test_config(tmp.path());
let step_fn = ShaOverride {
inner: LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 },
sha: "a".repeat(64),
};
let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
let _ = loop_.run();
let arts = loop_.epoch_artifacts();
assert!(!arts.is_empty(), "at least one epoch should have completed");
for art in arts {
assert_eq!(
art.metadata.optimizer_state_sha,
"a".repeat(64),
"StepFn override must win over fake_optimizer_sha fallback",
);
}
}
#[test]
fn pretrain_loop_falls_back_to_fake_optimizer_sha_for_synthetic() {
let tmp = TempDir::new().expect("tempdir");
let cfg = test_config(tmp.path());
let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
let _ = loop_.run();
for art in loop_.epoch_artifacts() {
assert_eq!(
art.metadata.optimizer_state_sha.len(),
64,
"fallback fingerprint must still be a 64-char hex digest",
);
assert!(
art.metadata.optimizer_state_sha.chars().all(|c| c.is_ascii_hexdigit()),
"fallback fingerprint must be lowercase hex",
);
}
}
#[test]
fn pretrain_loop_skips_checkpoint_on_abort() {
let tmp = TempDir::new().expect("tempdir");
let cfg = test_config(tmp.path());
let step_fn = NanAtStepSynthetic { nan_step: 1 };
let val_fn = ScriptedVal { sequence: vec![3.0] };
let calls: Rc<RefCell<Vec<(usize, PathBuf)>>> = Rc::new(RefCell::new(Vec::new()));
let ckpt = RecordingCheckpointFn { calls: Rc::clone(&calls) };
let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn).with_checkpoint_fn(Box::new(ckpt));
let status = loop_.run();
assert!(
matches!(status, RunStatus::Aborted(PretrainAbort::NumericalInstability { .. })),
"NaN must abort the loop: got {status:?}",
);
assert!(
calls.borrow().is_empty(),
"CheckpointFn must NOT fire when the epoch aborts before GATE-TRAIN-005 passes",
);
}
}