1#![allow(dead_code)] use std::path::{Path, PathBuf};
42use std::time::Instant;
43
44use rand::rngs::StdRng;
45use rand::{Rng, SeedableRng};
46use serde::{Deserialize, Serialize};
47
/// Reasons the pretraining loop aborts a run.
///
/// Returned by the gate functions in this module and surfaced from
/// `PretrainLoop::run` as `RunStatus::Aborted`.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum PretrainAbort {
    /// GATE-TRAIN-005: validation loss at `epoch` exceeded `DIVERGENCE_RATIO_LIMIT`
    /// times the previous epoch's value. `ratio` is `curr / max(prev, 1e-9)`.
    Divergence { epoch: usize, prev_val_loss: f32, curr_val_loss: f32, ratio: f32 },
    /// GATE-TRAIN-005 at epoch 0: the first validation loss was non-finite or above
    /// the regime's cap (`TrainingRegime::epoch_zero_val_loss_limit`).
    DivergenceAtEpochZero { val_loss: f32 },
    /// A metric named by `field` was NaN or ±∞. For training metrics `step` is the
    /// training step; for `field == "val_loss"` it is not an executed step
    /// (`run_epoch` reports the epoch's end bound, `check_non_divergence` reports
    /// `u64::MAX`).
    NumericalInstability { step: u64, field: &'static str, value: f32 },
    /// A derived throughput/telemetry metric named by `field` was non-finite or
    /// outside its permitted range (see `StepMetrics::validate_finite`).
    ThroughputOutOfRange { step: u64, field: &'static str, value: f32 },
}
72
73impl std::fmt::Display for PretrainAbort {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 Self::Divergence { epoch, prev_val_loss, curr_val_loss, ratio } => write!(
77 f,
78 "DIVERGENCE at epoch {epoch}: val_loss {curr_val_loss:.4} > 2.0 × {prev_val_loss:.4} (ratio {ratio:.2})",
79 ),
80 Self::DivergenceAtEpochZero { val_loss } => write!(
81 f,
82 "DIVERGENCE at epoch 0: val_loss {val_loss} is non-finite or > 10.0",
83 ),
84 Self::NumericalInstability { step, field, value } => write!(
85 f,
86 "NUMERICAL_INSTABILITY at step {step}: {field} = {value} is non-finite",
87 ),
88 Self::ThroughputOutOfRange { step, field, value } => write!(
89 f,
90 "THROUGHPUT_OUT_OF_RANGE at step {step}: {field} = {value} outside permitted range",
91 ),
92 }
93 }
94}
95
/// Makes `PretrainAbort` usable as a standard error (e.g. boxed as
/// `Box<dyn std::error::Error>`). The message comes from the `Display` impl, and
/// there is no underlying `source`.
impl std::error::Error for PretrainAbort {}
97
/// Telemetry recorded for one training step by `PretrainLoop::train_step`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StepMetrics {
    /// Global step index (counts across epochs).
    pub step: u64,
    /// Training loss returned by the `StepFn`.
    pub train_loss: f32,
    /// Gradient norm returned by the `StepFn`.
    pub grad_norm: f32,
    /// Learning rate used for this step.
    pub lr: f32,
    /// Tokens in the step (`batch_size * seq_length`) divided by the step's wall time.
    pub tokens_per_sec: f32,
    /// GPU utilisation percentage; must lie in `[0, 100]`.
    /// NOTE(review): `PretrainLoop` currently fills this with seeded noise around
    /// 50, not a measurement.
    pub gpu_util_pct: f32,
    /// Wall time of the step in milliseconds. `#[serde(default)]` lets records that
    /// lack this field deserialize with `0.0` (presumably older logs — confirm).
    #[serde(default)]
    pub wall_ms: f32,
}
129
130impl StepMetrics {
131 pub fn validate_finite(&self) -> Result<(), PretrainAbort> {
137 if !self.train_loss.is_finite() {
138 return Err(PretrainAbort::NumericalInstability {
139 step: self.step,
140 field: "train_loss",
141 value: self.train_loss,
142 });
143 }
144 if !self.grad_norm.is_finite() {
145 return Err(PretrainAbort::NumericalInstability {
146 step: self.step,
147 field: "grad_norm",
148 value: self.grad_norm,
149 });
150 }
151 if !self.lr.is_finite() {
152 return Err(PretrainAbort::NumericalInstability {
153 step: self.step,
154 field: "lr",
155 value: self.lr,
156 });
157 }
158 if !self.tokens_per_sec.is_finite() || self.tokens_per_sec < 0.0 {
159 return Err(PretrainAbort::ThroughputOutOfRange {
160 step: self.step,
161 field: "tokens_per_sec",
162 value: self.tokens_per_sec,
163 });
164 }
165 if !self.gpu_util_pct.is_finite() || self.gpu_util_pct < 0.0 || self.gpu_util_pct > 100.0 {
166 return Err(PretrainAbort::ThroughputOutOfRange {
167 step: self.step,
168 field: "gpu_util_pct",
169 value: self.gpu_util_pct,
170 });
171 }
172 if !self.wall_ms.is_finite() || self.wall_ms < 0.0 {
173 return Err(PretrainAbort::ThroughputOutOfRange {
174 step: self.step,
175 field: "wall_ms",
176 value: self.wall_ms,
177 });
178 }
179 Ok(())
180 }
181}
182
/// Per-epoch summary produced by `PretrainLoop::run_epoch`. It is serialized to
/// the companion `epoch-NNN.metadata.json` file when a checkpoint hook succeeds.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EpochMetadata {
    /// Zero-based epoch index.
    pub epoch: usize,
    /// Mean training loss over the epoch's steps (0.0 if the epoch ran no steps).
    pub train_loss: f32,
    /// Validation loss reported by the `ValFn` for this epoch.
    pub val_loss: f32,
    /// `exp(train_loss)`.
    pub train_ppl: f32,
    /// `exp(val_loss)`.
    pub val_ppl: f32,
    /// Optimizer-state fingerprint from `StepFn::optimizer_state_sha256`, or a
    /// deterministic synthetic SHA-256 hex digest when the `StepFn` provides none.
    pub optimizer_state_sha: String,
    /// Seconds from the start of the epoch's training steps until the divergence
    /// gate passed (checkpoint writing is not included).
    pub wall_seconds: f32,
    /// Cumulative tokens trained so far in the run (not just this epoch).
    pub tokens_seen: u64,
    /// Largest per-step gradient norm seen in this epoch (0.0 if no steps ran).
    pub grad_norm_max: f32,
}
201
/// File locations and metadata for one epoch that passed the divergence gate.
#[derive(Debug, Clone)]
pub struct EpochArtifact {
    /// `<run_dir>/ckpt/epoch-NNN.apr` (epoch zero-padded to at least three digits).
    pub checkpoint_path: PathBuf,
    /// `<run_dir>/ckpt/epoch-NNN.metadata.json`, the JSON form of `metadata`.
    pub metadata_path: PathBuf,
    /// Summary statistics for the epoch.
    pub metadata: EpochMetadata,
}
212
213impl EpochArtifact {
214 pub fn new(run_dir: &Path, epoch: usize, metadata: EpochMetadata) -> Self {
216 let ckpt_dir = run_dir.join("ckpt");
217 let filename = format!("epoch-{epoch:03}.apr");
218 let metafile = format!("epoch-{epoch:03}.metadata.json");
219 Self {
220 checkpoint_path: ckpt_dir.join(filename),
221 metadata_path: ckpt_dir.join(metafile),
222 metadata,
223 }
224 }
225}
226
/// Maximum allowed epoch-over-epoch growth factor of validation loss (GATE-TRAIN-005).
///
/// `check_non_divergence` aborts when `val_loss[e] > DIVERGENCE_RATIO_LIMIT * val_loss[e - 1]`;
/// an exact 2× increase is still accepted. The same factor is the multiplier in the
/// from-scratch epoch-0 cap `DIVERGENCE_RATIO_LIMIT * ln(vocab_size)`.
pub const DIVERGENCE_RATIO_LIMIT: f32 = 2.0;
235
/// Epoch-0 validation-loss cap for `TrainingRegime::Finetune` runs.
///
/// A first validation loss that is non-finite or above this value aborts with
/// `PretrainAbort::DivergenceAtEpochZero`.
pub const EPOCH_ZERO_VAL_LOSS_LIMIT: f32 = 10.0;
242
243#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
254#[serde(tag = "kind", rename_all = "snake_case")]
255pub enum TrainingRegime {
256 Finetune,
257 FromScratch { vocab_size: u32 },
258}
259
260impl TrainingRegime {
261 pub fn epoch_zero_val_loss_limit(&self) -> f32 {
268 match self {
269 Self::Finetune => EPOCH_ZERO_VAL_LOSS_LIMIT,
270 Self::FromScratch { vocab_size } => {
271 let v = (*vocab_size).max(2) as f32;
272 DIVERGENCE_RATIO_LIMIT * v.ln()
273 }
274 }
275 }
276}
277
278impl Default for TrainingRegime {
279 fn default() -> Self {
280 Self::Finetune
281 }
282}
283
284pub fn check_non_divergence(
295 epoch: usize,
296 val_loss_history: &[f32],
297 regime: &TrainingRegime,
298) -> Result<(), PretrainAbort> {
299 let Some(&curr) = val_loss_history.get(epoch) else {
300 return Ok(());
302 };
303
304 if epoch == 0 {
306 let cap = regime.epoch_zero_val_loss_limit();
307 if !curr.is_finite() || curr > cap {
308 return Err(PretrainAbort::DivergenceAtEpochZero { val_loss: curr });
309 }
310 return Ok(());
311 }
312
313 let prev = val_loss_history[epoch - 1];
315 if !curr.is_finite() {
316 return Err(PretrainAbort::NumericalInstability {
317 step: u64::MAX,
318 field: "val_loss",
319 value: curr,
320 });
321 }
322 let ratio = curr / prev.max(1e-9);
323 if curr > DIVERGENCE_RATIO_LIMIT * prev {
324 return Err(PretrainAbort::Divergence {
325 epoch,
326 prev_val_loss: prev,
327 curr_val_loss: curr,
328 ratio,
329 });
330 }
331 Ok(())
332}
333
334pub fn check_numerical_stability(
341 step: u64,
342 train_loss: f32,
343 grad_norm: f32,
344) -> Result<(), PretrainAbort> {
345 if !train_loss.is_finite() {
346 return Err(PretrainAbort::NumericalInstability {
347 step,
348 field: "train_loss",
349 value: train_loss,
350 });
351 }
352 if !grad_norm.is_finite() {
353 return Err(PretrainAbort::NumericalInstability {
354 step,
355 field: "grad_norm",
356 value: grad_norm,
357 });
358 }
359 Ok(())
360}
361
/// Hyper-parameters and paths for one pretraining run.
///
/// `regime` may be omitted from serialized input; it then takes
/// `TrainingRegime::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PretrainConfig {
    /// Training corpus location. Not read by `PretrainLoop` itself; presumably
    /// consumed by the `StepFn`/`ValFn` implementations (TODO confirm).
    pub dataset_path: PathBuf,
    /// Tokenizer directory. Not read by `PretrainLoop` itself (TODO confirm consumer).
    pub tokenizer_dir: PathBuf,
    /// Run output root; epoch artifacts are placed under `<run_dir>/ckpt/`.
    pub run_dir: PathBuf,
    /// Peak learning rate, reached at the end of warmup.
    pub lr_max: f32,
    /// Floor of the cosine decay phase of the LR schedule.
    pub lr_min: f32,
    /// Number of linear-warmup steps, during which the LR ramps from 0 toward `lr_max`.
    pub warmup_steps: usize,
    /// Length of the LR schedule. `PretrainLoop::run` runs
    /// `ceil(total_steps / steps_per_epoch)` epochs.
    pub total_steps: usize,
    /// Batch size. `batch_size * seq_length` is the token count passed to each step.
    pub batch_size: usize,
    /// Sequence length (see `batch_size`).
    pub seq_length: usize,
    /// Training steps per epoch, i.e. between validations and checkpoints.
    pub steps_per_epoch: usize,
    /// Seeds the loop's RNG and the synthetic optimizer-state fingerprint.
    pub seed: u64,
    /// Gradient-clipping threshold. Not applied by `PretrainLoop` itself; presumably
    /// honoured by the `StepFn` (TODO confirm).
    pub grad_clip: f32,
    /// Weight-decay coefficient. Not applied by `PretrainLoop` itself (TODO confirm consumer).
    pub weight_decay: f32,
    /// `run` returns `RunStatus::Ok` early once validation loss is at or below this
    /// value, provided `min_epochs_before_early_stop` epochs have completed.
    pub target_val_loss: f32,
    /// Non-improving epochs tolerated; early stop triggers once this count is exceeded.
    pub patience_epochs: usize,
    /// Neither early stop nor the target-reached exit happens before this many epochs.
    pub min_epochs_before_early_stop: usize,
    /// Selects the epoch-0 divergence cap; defaulted when missing from serialized input.
    #[serde(default)]
    pub regime: TrainingRegime,
}
408
409impl PretrainConfig {
410 pub fn model_2_defaults(
415 dataset_path: PathBuf,
416 tokenizer_dir: PathBuf,
417 run_dir: PathBuf,
418 ) -> Self {
419 Self {
420 dataset_path,
421 tokenizer_dir,
422 run_dir,
423 lr_max: 5.0e-5,
424 lr_min: 1.0e-6,
425 warmup_steps: 100,
426 total_steps: 1000,
427 batch_size: 16,
428 seq_length: 1024,
429 steps_per_epoch: 100,
430 seed: 42,
431 grad_clip: 1.0,
432 weight_decay: 0.01,
433 target_val_loss: 2.2,
434 patience_epochs: 2,
435 min_epochs_before_early_stop: 3,
436 regime: TrainingRegime::Finetune,
437 }
438 }
439}
440
/// Terminal outcome of `PretrainLoop::run`.
#[derive(Debug, Clone, Serialize)]
pub enum RunStatus {
    /// The target validation loss was reached, or all scheduled epochs ran. This is
    /// also returned when the target was never reached.
    Ok { final_val_loss: f32, epochs_completed: usize },
    /// Validation loss stopped improving for longer than `patience_epochs` allows.
    EarlyStop { best_val_loss: f32, epochs_completed: usize },
    /// A gate fired; the payload says which one and why.
    Aborted(PretrainAbort),
}
455
/// Epoch-structured pretraining driver with divergence (GATE-TRAIN-005) and
/// numerical-stability (INV-TRAIN-007) gates.
///
/// `S` performs training steps and `V` computes the per-epoch validation loss. An
/// optional `CheckpointFn` persists every epoch that passes the gates.
pub struct PretrainLoop<S: StepFn, V: ValFn> {
    config: PretrainConfig,
    // Seeded from `config.seed`; its only use is the synthetic `gpu_util_pct`.
    rng: StdRng,
    // One entry per successful training step.
    step_metrics: Vec<StepMetrics>,
    // One entry per epoch that passed the divergence gate.
    epoch_artifacts: Vec<EpochArtifact>,
    // Finite validation losses in the order epochs were run. Indexed by epoch by
    // `check_non_divergence`.
    val_loss_history: Vec<f32>,
    // Cumulative tokens across all successful steps.
    tokens_seen: u64,
    // Early-stopping state maintained by `check_convergence`.
    best_val_loss: f32,
    patience_counter: usize,
    step_fn: S,
    val_fn: V,
    checkpoint_fn: Option<Box<dyn CheckpointFn>>,
}
479
/// One optimisation step of the model being trained, driven by `PretrainLoop`.
pub trait StepFn {
    /// Runs global training step `step` at learning rate `lr` with a batch of
    /// `batch_tokens` tokens and returns `(train_loss, grad_norm)`. The loop aborts
    /// the run if either value is non-finite.
    fn step(&mut self, step: u64, lr: f32, batch_tokens: u64) -> (f32, f32);

    /// Fingerprint of the current optimizer state, recorded as
    /// `EpochMetadata::optimizer_state_sha`. By its name it is expected to be a
    /// hex SHA-256 digest (confirm with implementors). The default returns `None`,
    /// and the loop then substitutes a deterministic synthetic digest.
    fn optimizer_state_sha256(&self) -> Option<String> {
        None
    }
}
499
/// Validation pass run by `PretrainLoop` at the end of each epoch.
pub trait ValFn {
    /// Returns the validation loss for `epoch`. A non-finite value aborts the run.
    fn validate(&mut self, epoch: usize) -> f32;
}
504
/// Hook that persists model state for an epoch that passed the divergence gate.
pub trait CheckpointFn {
    /// Called once per passing epoch, after the loop has tried to create the parent
    /// directory of `artifact.checkpoint_path`. On `Ok`, the loop then writes
    /// `artifact.metadata` as JSON to `artifact.metadata_path`. An `Err` is logged
    /// to stderr, skips that metadata write, and does not abort the run.
    fn save(&mut self, epoch: usize, artifact: &EpochArtifact) -> Result<(), String>;
}
518
519impl<S: StepFn, V: ValFn> PretrainLoop<S, V> {
520 pub fn new(config: PretrainConfig, step_fn: S, val_fn: V) -> Self {
522 let rng = StdRng::seed_from_u64(config.seed);
523 Self {
524 config,
525 rng,
526 step_metrics: Vec::new(),
527 epoch_artifacts: Vec::new(),
528 val_loss_history: Vec::new(),
529 tokens_seen: 0,
530 best_val_loss: f32::INFINITY,
531 patience_counter: 0,
532 step_fn,
533 val_fn,
534 checkpoint_fn: None,
535 }
536 }
537
538 #[must_use]
541 pub fn with_checkpoint_fn(mut self, ckpt: Box<dyn CheckpointFn>) -> Self {
542 self.checkpoint_fn = Some(ckpt);
543 self
544 }
545
546 fn lr_at(&self, step: u64) -> f32 {
550 let step = step as usize;
551 let w = self.config.warmup_steps;
552 let total = self.config.total_steps;
553 let lr_max = self.config.lr_max;
554 let lr_min = self.config.lr_min;
555
556 if step < w {
557 if w == 0 {
558 return lr_max;
559 }
560 return lr_max * (step as f32 / w as f32);
561 }
562 let decay_steps = total.saturating_sub(w);
563 if decay_steps == 0 {
564 return lr_min;
565 }
566 let decay_step = step - w;
567 if decay_step >= decay_steps {
568 return lr_min;
569 }
570 let progress = decay_step as f32 / decay_steps as f32;
571 let cosine_decay = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
572 lr_min + (lr_max - lr_min) * cosine_decay
573 }
574
575 pub fn train_step(&mut self, step: u64) -> Result<StepMetrics, PretrainAbort> {
578 let lr = self.lr_at(step);
579 let batch_tokens = (self.config.batch_size * self.config.seq_length) as u64;
580 let t0 = Instant::now();
581 let (train_loss, grad_norm) = self.step_fn.step(step, lr, batch_tokens);
582 let elapsed = t0.elapsed().as_secs_f32().max(1.0e-9);
583
584 check_numerical_stability(step, train_loss, grad_norm)?;
588
589 let tokens_per_sec = batch_tokens as f32 / elapsed;
590 let wall_ms = elapsed * 1000.0;
591 let gpu_util_pct = 50.0 + (self.rng.random_range(-5.0..5.0) as f32);
595
596 let metrics = StepMetrics {
597 step,
598 train_loss,
599 grad_norm,
600 lr,
601 tokens_per_sec,
602 gpu_util_pct: gpu_util_pct.clamp(0.0, 100.0),
603 wall_ms,
604 };
605 metrics.validate_finite()?;
606
607 self.tokens_seen += batch_tokens;
608 self.step_metrics.push(metrics.clone());
609 Ok(metrics)
610 }
611
612 pub fn run_epoch(&mut self, epoch: usize) -> Result<EpochArtifact, PretrainAbort> {
615 let first_step = (epoch * self.config.steps_per_epoch) as u64;
616 let last_step = first_step + self.config.steps_per_epoch as u64;
617
618 let t0 = Instant::now();
619 let mut epoch_loss_sum = 0.0_f32;
620 let mut epoch_grad_norm_max = 0.0_f32;
621 let mut steps_taken = 0_u32;
622
623 for step in first_step..last_step {
624 let m = self.train_step(step)?;
625 epoch_loss_sum += m.train_loss;
626 if m.grad_norm > epoch_grad_norm_max {
627 epoch_grad_norm_max = m.grad_norm;
628 }
629 steps_taken += 1;
630 }
631
632 let mean_train_loss = epoch_loss_sum / steps_taken.max(1) as f32;
633 let val_loss = self.val_fn.validate(epoch);
634
635 if !val_loss.is_finite() {
637 return Err(PretrainAbort::NumericalInstability {
638 step: last_step,
639 field: "val_loss",
640 value: val_loss,
641 });
642 }
643
644 self.val_loss_history.push(val_loss);
645
646 check_non_divergence(epoch, &self.val_loss_history, &self.config.regime)?;
649
650 let wall_seconds = t0.elapsed().as_secs_f32();
651 let optimizer_state_sha =
655 self.step_fn.optimizer_state_sha256().unwrap_or_else(|| self.fake_optimizer_sha(epoch));
656 let metadata = EpochMetadata {
657 epoch,
658 train_loss: mean_train_loss,
659 val_loss,
660 train_ppl: mean_train_loss.exp(),
661 val_ppl: val_loss.exp(),
662 optimizer_state_sha,
663 wall_seconds,
664 tokens_seen: self.tokens_seen,
665 grad_norm_max: epoch_grad_norm_max,
666 };
667 let artifact = EpochArtifact::new(&self.config.run_dir, epoch, metadata);
668
669 if let Some(ckpt) = self.checkpoint_fn.as_mut() {
674 if let Some(parent) = artifact.checkpoint_path.parent() {
675 let _ = std::fs::create_dir_all(parent);
676 }
677 if let Err(e) = ckpt.save(epoch, &artifact) {
678 eprintln!("[pretrain] checkpoint write failed for epoch {}: {}", epoch, e);
679 } else {
680 match serde_json::to_string_pretty(&artifact.metadata) {
684 Ok(json) => {
685 if let Err(e) = std::fs::write(&artifact.metadata_path, json) {
686 eprintln!(
687 "[pretrain] metadata write failed for epoch {}: {}",
688 epoch, e
689 );
690 }
691 }
692 Err(e) => eprintln!(
693 "[pretrain] metadata serialization failed for epoch {}: {}",
694 epoch, e
695 ),
696 }
697 }
698 }
699
700 self.epoch_artifacts.push(artifact.clone());
701 Ok(artifact)
702 }
703
704 pub fn check_convergence(&mut self, epoch: usize) -> bool {
707 let Some(&val_loss) = self.val_loss_history.last() else {
708 return false;
709 };
710 if val_loss < self.best_val_loss {
711 self.best_val_loss = val_loss;
712 self.patience_counter = 0;
713 return false;
714 }
715 self.patience_counter += 1;
716 if epoch + 1 < self.config.min_epochs_before_early_stop {
717 return false;
718 }
719 self.patience_counter > self.config.patience_epochs
720 }
721
722 pub fn run(&mut self) -> RunStatus {
724 let num_epochs = self.config.total_steps.div_ceil(self.config.steps_per_epoch.max(1));
725 for epoch in 0..num_epochs {
726 match self.run_epoch(epoch) {
727 Ok(_) => {}
728 Err(abort) => return RunStatus::Aborted(abort),
729 }
730 if self.check_convergence(epoch) {
731 return RunStatus::EarlyStop {
732 best_val_loss: self.best_val_loss,
733 epochs_completed: epoch + 1,
734 };
735 }
736 let last = *self.val_loss_history.last().unwrap_or(&f32::INFINITY);
737 if last <= self.config.target_val_loss
738 && epoch + 1 >= self.config.min_epochs_before_early_stop
739 {
740 return RunStatus::Ok { final_val_loss: last, epochs_completed: epoch + 1 };
741 }
742 }
743 let last = *self.val_loss_history.last().unwrap_or(&f32::INFINITY);
744 RunStatus::Ok { final_val_loss: last, epochs_completed: num_epochs }
745 }
746
747 pub fn step_metrics(&self) -> &[StepMetrics] {
749 &self.step_metrics
750 }
751
752 pub fn epoch_artifacts(&self) -> &[EpochArtifact] {
753 &self.epoch_artifacts
754 }
755
756 pub fn val_loss_history(&self) -> &[f32] {
757 &self.val_loss_history
758 }
759
760 fn fake_optimizer_sha(&self, epoch: usize) -> String {
766 use sha2::{Digest, Sha256};
767 let mut hasher = Sha256::new();
768 hasher.update(b"aprender-train:pretrain:optstate:v1:");
769 hasher.update(self.config.seed.to_le_bytes());
770 hasher.update((epoch as u64).to_le_bytes());
771 hasher.update(self.tokens_seen.to_le_bytes());
772 format!("{:x}", hasher.finalize())
773 }
774}
775
/// Synthetic `StepFn` for tests: the loss decreases linearly with the step index.
pub struct LinearDecaySynthetic {
    /// Loss returned at step 0.
    pub start_loss: f32,
    /// Amount subtracted per step; the resulting loss is floored at `1.0e-4`.
    pub decay_per_step: f32,
    /// Gradient norm returned for every step.
    pub grad_norm: f32,
}
787
788impl StepFn for LinearDecaySynthetic {
789 fn step(&mut self, step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
790 let loss = (self.start_loss - self.decay_per_step * step as f32).max(1.0e-4);
791 (loss, self.grad_norm)
792 }
793}
794
/// Synthetic `ValFn` for tests that replays a fixed list of validation losses.
pub struct ScriptedVal {
    /// `sequence[e]` is returned for epoch `e`; epochs past the end get NaN.
    pub sequence: Vec<f32>,
}
801
802impl ValFn for ScriptedVal {
803 fn validate(&mut self, epoch: usize) -> f32 {
804 *self.sequence.get(epoch).unwrap_or(&f32::NAN)
805 }
806}
807
/// Synthetic `StepFn` for tests that injects a NaN training loss at one step.
pub struct NanAtStepSynthetic {
    /// Global step at which the returned loss is NaN.
    pub nan_step: u64,
}
812
813impl StepFn for NanAtStepSynthetic {
814 fn step(&mut self, step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
815 if step == self.nan_step {
816 return (f32::NAN, 1.0);
817 }
818 (1.0, 1.0)
819 }
820}
821
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::rc::Rc;
    use tempfile::TempDir;

    /// Small, fast config: 5 epochs of 5 steps, early exits allowed from epoch 1.
    fn test_config(tmp: &Path) -> PretrainConfig {
        PretrainConfig {
            dataset_path: tmp.join("data.jsonl"),
            tokenizer_dir: tmp.join("tok"),
            run_dir: tmp.join("run"),
            lr_max: 1.0e-4,
            lr_min: 1.0e-6,
            warmup_steps: 2,
            total_steps: 25,
            batch_size: 2,
            seq_length: 4,
            steps_per_epoch: 5,
            seed: 42,
            grad_clip: 1.0,
            weight_decay: 0.01,
            target_val_loss: 2.2,
            patience_epochs: 2,
            min_epochs_before_early_stop: 1,
            regime: TrainingRegime::Finetune,
        }
    }

    /// GATE-TRAIN-005: a >2× epoch-over-epoch jump aborts with `Divergence`.
    #[test]
    fn gate_train_005_aborts_on_doubling_val_loss() {
        let trace = vec![3.5, 7.1];
        let res = check_non_divergence(1, &trace, &TrainingRegime::Finetune);
        match res {
            Err(PretrainAbort::Divergence { epoch, prev_val_loss, curr_val_loss, ratio }) => {
                assert_eq!(epoch, 1);
                assert!((prev_val_loss - 3.5).abs() < 1e-6);
                assert!((curr_val_loss - 7.1).abs() < 1e-6);
                assert!(ratio > 2.0);
            }
            other => panic!("GATE-TRAIN-005 did not abort: got {other:?}"),
        }
    }

    /// An epoch-0 loss above the finetune cap aborts with `DivergenceAtEpochZero`.
    #[test]
    fn gate_train_005_aborts_on_epoch_zero_blowup() {
        let trace = vec![31.99];
        let res = check_non_divergence(0, &trace, &TrainingRegime::Finetune);
        match res {
            Err(PretrainAbort::DivergenceAtEpochZero { val_loss }) => {
                assert!((val_loss - 31.99).abs() < 1e-4);
            }
            other => panic!("epoch-0 guard missed: got {other:?}"),
        }
    }

    /// A monotonically decreasing trace passes at every epoch.
    #[test]
    fn gate_train_005_allows_healthy_decrease() {
        let trace = vec![3.5, 3.0, 2.5, 2.2];
        for epoch in 0..trace.len() {
            assert!(check_non_divergence(epoch, &trace, &TrainingRegime::Finetune).is_ok());
        }
    }

    /// The bound is strict: exactly 2× is still accepted.
    #[test]
    fn gate_train_005_allows_exact_two_x() {
        let trace = vec![2.0, 4.0];
        assert!(check_non_divergence(1, &trace, &TrainingRegime::Finetune).is_ok());
    }

    /// From-scratch runs get a vocabulary-scaled epoch-0 cap that admits losses the
    /// finetune cap rejects.
    #[test]
    fn gate_train_005_from_scratch_permits_near_random_baseline() {
        let trace = vec![18.0_f32];
        let regime = TrainingRegime::FromScratch { vocab_size: 50_257 };
        assert!(
            check_non_divergence(0, &trace, &regime).is_ok(),
            "val_loss[0]=18 must be within 2·ln(50257)≈21.64 from_scratch cap"
        );
        // The same loss is rejected under the finetune cap.
        assert!(matches!(
            check_non_divergence(0, &trace, &TrainingRegime::Finetune),
            Err(PretrainAbort::DivergenceAtEpochZero { .. }),
        ));
    }

    /// From-scratch runs still abort above 2·ln(vocab_size).
    #[test]
    fn gate_train_005_from_scratch_aborts_above_2_ln_vocab() {
        let trace = vec![25.0_f32];
        let regime = TrainingRegime::FromScratch { vocab_size: 50_257 };
        match check_non_divergence(0, &trace, &regime) {
            Err(PretrainAbort::DivergenceAtEpochZero { val_loss }) => {
                assert!((val_loss - 25.0).abs() < 1e-4);
            }
            other => panic!("from_scratch cap missed: got {other:?}"),
        }
    }

    /// Pins the cap formulas for both regimes.
    #[test]
    fn training_regime_from_scratch_cap_matches_formula() {
        let v = 50_257u32;
        let regime = TrainingRegime::FromScratch { vocab_size: v };
        let expected = DIVERGENCE_RATIO_LIMIT * (v as f32).ln();
        assert!(
            (regime.epoch_zero_val_loss_limit() - expected).abs() < 1e-4,
            "cap formula drift: got {} expected {}",
            regime.epoch_zero_val_loss_limit(),
            expected
        );
        assert!(
            (TrainingRegime::Finetune.epoch_zero_val_loss_limit() - EPOCH_ZERO_VAL_LOSS_LIMIT)
                .abs()
                < 1e-6
        );
    }

    /// INV-TRAIN-007: a NaN training loss is reported with its step and field.
    #[test]
    fn gate_train_007_aborts_on_nan_train_loss() {
        let res = check_numerical_stability(42, f32::NAN, 1.0);
        match res {
            Err(PretrainAbort::NumericalInstability { step, field, .. }) => {
                assert_eq!(step, 42);
                assert_eq!(field, "train_loss");
            }
            other => panic!("nan guard missed: got {other:?}"),
        }
    }

    #[test]
    fn gate_train_007_aborts_on_inf_grad_norm() {
        let res = check_numerical_stability(7, 1.0, f32::INFINITY);
        assert!(matches!(res, Err(PretrainAbort::NumericalInstability { .. })));
    }

    /// A fully in-range record validates cleanly.
    #[test]
    fn step_metrics_validate_finite_accepts_healthy() {
        let m = StepMetrics {
            step: 0,
            train_loss: 3.2,
            grad_norm: 0.5,
            lr: 1e-4,
            tokens_per_sec: 1000.0,
            gpu_util_pct: 75.0,
            wall_ms: 5.0,
        };
        assert!(m.validate_finite().is_ok());
    }

    #[test]
    fn step_metrics_rejects_negative_throughput() {
        let m = StepMetrics {
            step: 1,
            train_loss: 3.2,
            grad_norm: 0.5,
            lr: 1e-4,
            tokens_per_sec: -1.0,
            gpu_util_pct: 75.0,
            wall_ms: 5.0,
        };
        assert!(matches!(m.validate_finite(), Err(PretrainAbort::ThroughputOutOfRange { .. })));
    }

    #[test]
    fn step_metrics_rejects_gpu_util_over_100() {
        let m = StepMetrics {
            step: 1,
            train_loss: 3.2,
            grad_norm: 0.5,
            lr: 1e-4,
            tokens_per_sec: 1000.0,
            gpu_util_pct: 150.0,
            wall_ms: 5.0,
        };
        assert!(matches!(m.validate_finite(), Err(PretrainAbort::ThroughputOutOfRange { .. })));
    }

    /// `wall_ms` must be non-negative.
    #[test]
    fn step_metrics_rejects_negative_wall_ms() {
        let m = StepMetrics {
            step: 1,
            train_loss: 3.2,
            grad_norm: 0.5,
            lr: 1e-4,
            tokens_per_sec: 1000.0,
            gpu_util_pct: 75.0,
            wall_ms: -1.0,
        };
        assert!(matches!(m.validate_finite(), Err(PretrainAbort::ThroughputOutOfRange { .. })));
    }

    /// `wall_ms` must be finite.
    #[test]
    fn step_metrics_rejects_nan_wall_ms() {
        let m = StepMetrics {
            step: 1,
            train_loss: 3.2,
            grad_norm: 0.5,
            lr: 1e-4,
            tokens_per_sec: 1000.0,
            gpu_util_pct: 75.0,
            wall_ms: f32::NAN,
        };
        assert!(matches!(m.validate_finite(), Err(PretrainAbort::ThroughputOutOfRange { .. })));
    }

    /// `tokens_per_sec` and `wall_ms` derived from the same elapsed time must
    /// reconstruct the batch token count.
    #[test]
    fn step_metrics_wall_ms_consistent_with_tokens_per_sec() {
        let batch_tokens: u64 = 1024;
        let elapsed_secs: f32 = 0.5;
        let tokens_per_sec = batch_tokens as f32 / elapsed_secs;
        let wall_ms = elapsed_secs * 1000.0;

        let m = StepMetrics {
            step: 0,
            train_loss: 3.2,
            grad_norm: 0.5,
            lr: 1e-4,
            tokens_per_sec,
            gpu_util_pct: 50.0,
            wall_ms,
        };
        assert!(m.validate_finite().is_ok());
        let derived_tokens = m.tokens_per_sec * (m.wall_ms / 1000.0);
        let diff = (derived_tokens - batch_tokens as f32).abs();
        assert!(
            diff < 0.5,
            "tokens_per_sec * (wall_ms/1000) = {derived_tokens} should equal batch_tokens={batch_tokens} within FP rounding"
        );
    }

    /// A healthy run reaches the target and records sane metrics and artifacts.
    #[test]
    fn pretrain_loop_happy_path_decreasing_loss() {
        let tmp = TempDir::new().expect("tempdir");
        let cfg = test_config(tmp.path());
        let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
        let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);

        let status = loop_.run();
        match status {
            RunStatus::Ok { final_val_loss, epochs_completed } => {
                assert!(final_val_loss <= 2.2);
                assert!(epochs_completed >= 1);
            }
            other => panic!("healthy run did not converge cleanly: {other:?}"),
        }

        assert!(!loop_.step_metrics().is_empty());
        for m in loop_.step_metrics() {
            assert!(m.train_loss.is_finite());
            assert!(m.grad_norm.is_finite());
            assert!(m.lr.is_finite());
            assert!(m.tokens_per_sec >= 0.0);
            assert!((0.0..=100.0).contains(&m.gpu_util_pct));
        }
        assert_eq!(loop_.epoch_artifacts().len(), loop_.val_loss_history().len());
        for art in loop_.epoch_artifacts() {
            assert!(!art.metadata.optimizer_state_sha.is_empty());
            assert!(art.metadata.train_ppl.is_finite());
            assert!(art.metadata.val_ppl.is_finite());
        }
    }

    /// The divergence gate fires inside the full loop, not just in isolation.
    #[test]
    fn pretrain_loop_aborts_on_doubling_val_loss() {
        let tmp = TempDir::new().expect("tempdir");
        let cfg = test_config(tmp.path());
        let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
        let val_fn = ScriptedVal { sequence: vec![3.5, 7.1, 2.0] };
        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);

        let status = loop_.run();
        match status {
            RunStatus::Aborted(PretrainAbort::Divergence { epoch, ratio, .. }) => {
                assert_eq!(epoch, 1);
                assert!(ratio > 2.0);
            }
            other => panic!("GATE-TRAIN-005 did not fire: {other:?}"),
        }
    }

    /// A NaN training loss aborts the loop at the exact step.
    #[test]
    fn pretrain_loop_aborts_on_nan_in_train_loss() {
        let tmp = TempDir::new().expect("tempdir");
        let cfg = test_config(tmp.path());
        let step_fn = NanAtStepSynthetic { nan_step: 3 };
        let val_fn = ScriptedVal { sequence: vec![3.0] };
        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);

        let status = loop_.run();
        match status {
            RunStatus::Aborted(PretrainAbort::NumericalInstability { step, field, .. }) => {
                assert_eq!(step, 3);
                assert_eq!(field, "train_loss");
            }
            other => panic!("INV-TRAIN-007 did not fire: {other:?}"),
        }
    }

    /// Two runs with the same seed record identical step metrics (excluding timing).
    #[test]
    fn pretrain_loop_reproducibility_seed_42() {
        let tmp1 = TempDir::new().expect("tempdir1");
        let tmp2 = TempDir::new().expect("tempdir2");
        let cfg1 = test_config(tmp1.path());
        let cfg2 = test_config(tmp2.path());

        let step_fn1 =
            LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
        let step_fn2 =
            LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
        let val_fn1 = ScriptedVal { sequence: vec![3.0, 2.8, 2.6, 2.4, 2.2] };
        let val_fn2 = ScriptedVal { sequence: vec![3.0, 2.8, 2.6, 2.4, 2.2] };

        let mut loop1 = PretrainLoop::new(cfg1, step_fn1, val_fn1);
        let mut loop2 = PretrainLoop::new(cfg2, step_fn2, val_fn2);
        let _ = loop1.run();
        let _ = loop2.run();

        assert_eq!(loop1.step_metrics().len(), loop2.step_metrics().len());
        for (a, b) in loop1.step_metrics().iter().zip(loop2.step_metrics().iter()) {
            assert_eq!(a.step, b.step);
            assert!((a.train_loss - b.train_loss).abs() < 1e-6);
            assert!((a.grad_norm - b.grad_norm).abs() < 1e-6);
            assert!((a.lr - b.lr).abs() < 1e-6);
            assert!((a.gpu_util_pct - b.gpu_util_pct).abs() < 1e-6);
        }
    }

    /// LR schedule endpoints: 0 at step 0, `lr_max` at end of warmup, `lr_min` at the end.
    #[test]
    fn lr_schedule_warmup_cosine_boundaries() {
        let tmp = TempDir::new().expect("tempdir");
        let cfg = PretrainConfig {
            warmup_steps: 10,
            total_steps: 100,
            lr_max: 1.0e-3,
            lr_min: 1.0e-5,
            ..test_config(tmp.path())
        };
        let step_fn = LinearDecaySynthetic { start_loss: 1.0, decay_per_step: 0.0, grad_norm: 0.1 };
        let val_fn = ScriptedVal { sequence: vec![1.0] };
        let loop_ = PretrainLoop::new(cfg, step_fn, val_fn);

        assert!((loop_.lr_at(0) - 0.0).abs() < 1e-9);
        assert!((loop_.lr_at(10) - 1.0e-3).abs() < 1e-6);
        assert!((loop_.lr_at(100) - 1.0e-5).abs() < 1e-6);
    }

    /// Artifact file names follow the `ckpt/epoch-NNN.*` template.
    #[test]
    fn epoch_artifact_paths_match_contract_template() {
        let tmp = TempDir::new().expect("tempdir");
        let run_dir = tmp.path().join("run");
        let metadata = EpochMetadata {
            epoch: 7,
            train_loss: 3.0,
            val_loss: 2.8,
            train_ppl: 20.0,
            val_ppl: 16.4,
            optimizer_state_sha: "deadbeef".into(),
            wall_seconds: 42.0,
            tokens_seen: 1_000_000,
            grad_norm_max: 1.5,
        };
        let art = EpochArtifact::new(&run_dir, 7, metadata);
        assert!(art.checkpoint_path.ends_with("ckpt/epoch-007.apr"));
        assert!(art.metadata_path.ends_with("ckpt/epoch-007.metadata.json"));
    }

    /// Test hook that records `(epoch, checkpoint_path)` for every call and always
    /// succeeds. The `Rc<RefCell<..>>` lets the test inspect calls after the loop
    /// takes ownership of the box.
    struct RecordingCheckpointFn {
        calls: Rc<RefCell<Vec<(usize, PathBuf)>>>,
    }

    impl CheckpointFn for RecordingCheckpointFn {
        fn save(&mut self, epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
            self.calls.borrow_mut().push((epoch, artifact.checkpoint_path.clone()));
            Ok(())
        }
    }

    /// The hook fires exactly once per passing epoch, and the companion metadata
    /// JSON is written.
    #[test]
    fn pretrain_loop_calls_checkpoint_fn_once_per_passing_epoch() {
        let tmp = TempDir::new().expect("tempdir");
        let cfg = test_config(tmp.path());
        let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
        let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
        let calls: Rc<RefCell<Vec<(usize, PathBuf)>>> = Rc::new(RefCell::new(Vec::new()));
        let ckpt = RecordingCheckpointFn { calls: Rc::clone(&calls) };

        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn).with_checkpoint_fn(Box::new(ckpt));
        let _status = loop_.run();

        let recorded = calls.borrow();
        let epoch_count = loop_.epoch_artifacts().len();
        assert!(epoch_count >= 1, "at least one epoch should have completed");
        assert_eq!(
            recorded.len(),
            epoch_count,
            "CheckpointFn must fire exactly once per epoch that passes GATE-TRAIN-005",
        );
        for (i, (epoch, path)) in recorded.iter().enumerate() {
            assert_eq!(*epoch, i, "checkpoint hook epoch indices must be monotonic from 0");
            assert!(
                path.to_string_lossy().contains(&format!("epoch-{:03}.apr", epoch)),
                "checkpoint path must match contract template: {:?}",
                path,
            );
            let meta_path = path.with_extension("metadata.json");
            assert!(
                meta_path.exists(),
                "companion metadata.json must be written for epoch {}",
                epoch,
            );
        }
    }

    /// A `StepFn`-provided optimizer fingerprint takes precedence over the fallback.
    #[test]
    fn pretrain_loop_uses_step_fn_optimizer_sha_when_available() {
        struct ShaOverride {
            inner: LinearDecaySynthetic,
            sha: String,
        }
        impl StepFn for ShaOverride {
            fn step(&mut self, s: u64, lr: f32, tokens: u64) -> (f32, f32) {
                self.inner.step(s, lr, tokens)
            }
            fn optimizer_state_sha256(&self) -> Option<String> {
                Some(self.sha.clone())
            }
        }

        let tmp = TempDir::new().expect("tempdir");
        let cfg = test_config(tmp.path());
        let step_fn = ShaOverride {
            inner: LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 },
            sha: "a".repeat(64),
        };
        let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
        let _ = loop_.run();

        let arts = loop_.epoch_artifacts();
        assert!(!arts.is_empty(), "at least one epoch should have completed");
        for art in arts {
            assert_eq!(
                art.metadata.optimizer_state_sha,
                "a".repeat(64),
                "StepFn override must win over fake_optimizer_sha fallback",
            );
        }
    }

    /// Without a `StepFn` fingerprint, the fallback is still a 64-char hex digest.
    #[test]
    fn pretrain_loop_falls_back_to_fake_optimizer_sha_for_synthetic() {
        let tmp = TempDir::new().expect("tempdir");
        let cfg = test_config(tmp.path());
        let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
        let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
        let _ = loop_.run();

        for art in loop_.epoch_artifacts() {
            assert_eq!(
                art.metadata.optimizer_state_sha.len(),
                64,
                "fallback fingerprint must still be a 64-char hex digest",
            );
            assert!(
                art.metadata.optimizer_state_sha.chars().all(|c| c.is_ascii_hexdigit()),
                "fallback fingerprint must be lowercase hex",
            );
        }
    }

    /// An epoch that aborts must not reach the checkpoint hook.
    #[test]
    fn pretrain_loop_skips_checkpoint_on_abort() {
        let tmp = TempDir::new().expect("tempdir");
        let cfg = test_config(tmp.path());
        let step_fn = NanAtStepSynthetic { nan_step: 1 };
        let val_fn = ScriptedVal { sequence: vec![3.0] };
        let calls: Rc<RefCell<Vec<(usize, PathBuf)>>> = Rc::new(RefCell::new(Vec::new()));
        let ckpt = RecordingCheckpointFn { calls: Rc::clone(&calls) };

        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn).with_checkpoint_fn(Box::new(ckpt));
        let status = loop_.run();

        assert!(
            matches!(status, RunStatus::Aborted(PretrainAbort::NumericalInstability { .. })),
            "NaN must abort the loop: got {status:?}",
        );
        assert!(
            calls.borrow().is_empty(),
            "CheckpointFn must NOT fire when the epoch aborts before GATE-TRAIN-005 passes",
        );
    }
}