Skip to main content

entrenar/train/
pretrain.rs

//! Pretraining loop driver for SHIP-TWO-001 MODEL-2 (albor 370M).
//!
//! # Contract
//!
//! **Canonical contract:** `contracts/training-loop-pretrain-v1.yaml`
//! **Contract ID:** `C-TRAIN-PRETRAIN`
//!
//! # Scope
//!
//! This module wires the driver shape — per-step metrics, per-epoch
//! metadata, divergence abort, NaN abort, seed reproducibility — that
//! MODEL-2 pretraining will run through. It does **not** ship a fully
//! trained checkpoint; that is a downstream compute task. The contract
//! requires the loop be *correct by construction* before compute spends
//! — this module discharges that requirement.
//!
//! # Gates
//!
//! Every gate in `contracts/training-loop-pretrain-v1.yaml` has a
//! concrete line of code here:
//!
//! | Gate | Module | How |
//! |------|--------|-----|
//! | GATE-TRAIN-001 | [`StepMetrics`] / [`PretrainLoop::train_step`] | All 7 required per-step fields, emitted on every step |
//! | GATE-TRAIN-002 | [`EpochArtifact`] / [`PretrainLoop::run_epoch`] | checkpoint + metadata.json, 9 required fields |
//! | GATE-TRAIN-003 | [`PretrainConfig::target_val_loss`] | Final val_loss threshold (default 2.2) |
//! | GATE-TRAIN-004 | [`PretrainLoop::check_convergence`] | Patience counter + early stop |
//! | GATE-TRAIN-005 | [`check_non_divergence`] | **Ship-blocking** — val_loss doubling aborts |
//! | GATE-TRAIN-006 | [`PretrainLoop::seed`] | Fixed RNG seed, StdRng backed |
//! | GATE-TRAIN-007 | [`check_numerical_stability`] | NaN/Inf in loss or grad_norm aborts |
//! | GATE-TRAIN-008 | [`StepMetrics::validate_finite`] | tokens_per_sec ≥ 0, 0 ≤ gpu_util ≤ 100 |
//!
//! # INV-TRAIN-005 (ship-blocker)
//!
//! MODEL-1 v2 shipped garbage because val_loss silently hit 31.99 at
//! epoch 0 with no abort. [`check_non_divergence`] is the single
//! unconfigurable guard: val_loss[N] > 2 × val_loss[N-1] ⇒ fatal.
38
39#![allow(dead_code)] // driver — wired to CLI, not re-exported yet
40
41use std::path::{Path, PathBuf};
42use std::time::Instant;
43
44use rand::rngs::StdRng;
45use rand::{Rng, SeedableRng};
46use serde::{Deserialize, Serialize};
47
48// ─────────────────────────────────────────────────────────────
49// Public error type — binds to the contract's abort statuses
50// ─────────────────────────────────────────────────────────────
51
52/// Pretraining-loop abort reasons.
53///
54/// Each variant corresponds to a contract gate or failure-mode id. The
55/// CLI maps these to nonzero exit codes so operators can recognize the
56/// failure class from shell `$?`.
57#[derive(Debug, Clone, PartialEq, Serialize)]
58pub enum PretrainAbort {
59    /// INV-TRAIN-005 / GATE-TRAIN-005 — val_loss doubled between epochs.
60    /// This is the MODEL-1 v2 ship-blocker; abort is non-negotiable.
61    Divergence { epoch: usize, prev_val_loss: f32, curr_val_loss: f32, ratio: f32 },
62    /// INV-TRAIN-005 special case — val_loss[0] itself is already broken
63    /// (> 10.0 or non-finite). Sooner abort than waiting for epoch 1.
64    DivergenceAtEpochZero { val_loss: f32 },
65    /// INV-TRAIN-007 / GATE-TRAIN-007 — NaN or Inf in train_loss or grad_norm.
66    NumericalInstability { step: u64, field: &'static str, value: f32 },
67    /// INV-TRAIN-008 / GATE-TRAIN-008 — tokens_per_sec < 0 or gpu_util
68    /// outside [0, 100]. Usually a sensor bug, not a training bug, but
69    /// the contract forbids logging poison values either way.
70    ThroughputOutOfRange { step: u64, field: &'static str, value: f32 },
71}
72
73impl std::fmt::Display for PretrainAbort {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        match self {
76            Self::Divergence { epoch, prev_val_loss, curr_val_loss, ratio } => write!(
77                f,
78                "DIVERGENCE at epoch {epoch}: val_loss {curr_val_loss:.4} > 2.0 × {prev_val_loss:.4} (ratio {ratio:.2})",
79            ),
80            Self::DivergenceAtEpochZero { val_loss } => write!(
81                f,
82                "DIVERGENCE at epoch 0: val_loss {val_loss} is non-finite or > 10.0",
83            ),
84            Self::NumericalInstability { step, field, value } => write!(
85                f,
86                "NUMERICAL_INSTABILITY at step {step}: {field} = {value} is non-finite",
87            ),
88            Self::ThroughputOutOfRange { step, field, value } => write!(
89                f,
90                "THROUGHPUT_OUT_OF_RANGE at step {step}: {field} = {value} outside permitted range",
91            ),
92        }
93    }
94}
95
96impl std::error::Error for PretrainAbort {}
97
98// ─────────────────────────────────────────────────────────────
99// Per-step metrics — INV-TRAIN-001 / GATE-TRAIN-001
100// ─────────────────────────────────────────────────────────────
101
102/// Exactly the 7 fields the contract's `per_step_metrics.required` list
103/// names. Serialization is JSONL-friendly for downstream QA.
104///
105/// `wall_ms` added per `contracts/training-loop-pretrain-v1.yaml` v1.5.0
106/// to discharge §19.4 Residual B of ship-two-models-spec.md
107/// (GATE-GPUTRAIN-004 per-step latency budget).
108#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
109pub struct StepMetrics {
110    /// Monotonic step counter (INV-TRAIN-001).
111    pub step: u64,
112    /// Cross-entropy loss on the training micro-batch.
113    pub train_loss: f32,
114    /// Global L2 norm of gradients BEFORE clipping (INV-TRAIN-001).
115    pub grad_norm: f32,
116    /// Current learning rate after scheduler update.
117    pub lr: f32,
118    /// Throughput over this step window, tokens per wall second.
119    pub tokens_per_sec: f32,
120    /// GPU utilization in [0, 100] (INV-TRAIN-008).
121    pub gpu_util_pct: f32,
122    /// Wall-clock time for this optimizer step, milliseconds.
123    /// Per `contracts/training-loop-pretrain-v1.yaml` v1.5.0; required
124    /// for GATE-GPUTRAIN-004 (per-step latency budget < 500ms on
125    /// RTX 4090 / 370M).
126    #[serde(default)]
127    pub wall_ms: f32,
128}
129
130impl StepMetrics {
131    /// GATE-TRAIN-007: train_loss and grad_norm MUST be finite.
132    /// GATE-TRAIN-008: throughput MUST be in non-negative / [0, 100].
133    ///
134    /// Returns `Err(PretrainAbort::NumericalInstability)` or
135    /// `ThroughputOutOfRange` on first violation; otherwise `Ok(())`.
136    pub fn validate_finite(&self) -> Result<(), PretrainAbort> {
137        if !self.train_loss.is_finite() {
138            return Err(PretrainAbort::NumericalInstability {
139                step: self.step,
140                field: "train_loss",
141                value: self.train_loss,
142            });
143        }
144        if !self.grad_norm.is_finite() {
145            return Err(PretrainAbort::NumericalInstability {
146                step: self.step,
147                field: "grad_norm",
148                value: self.grad_norm,
149            });
150        }
151        if !self.lr.is_finite() {
152            return Err(PretrainAbort::NumericalInstability {
153                step: self.step,
154                field: "lr",
155                value: self.lr,
156            });
157        }
158        if !self.tokens_per_sec.is_finite() || self.tokens_per_sec < 0.0 {
159            return Err(PretrainAbort::ThroughputOutOfRange {
160                step: self.step,
161                field: "tokens_per_sec",
162                value: self.tokens_per_sec,
163            });
164        }
165        if !self.gpu_util_pct.is_finite() || self.gpu_util_pct < 0.0 || self.gpu_util_pct > 100.0 {
166            return Err(PretrainAbort::ThroughputOutOfRange {
167                step: self.step,
168                field: "gpu_util_pct",
169                value: self.gpu_util_pct,
170            });
171        }
172        if !self.wall_ms.is_finite() || self.wall_ms < 0.0 {
173            return Err(PretrainAbort::ThroughputOutOfRange {
174                step: self.step,
175                field: "wall_ms",
176                value: self.wall_ms,
177            });
178        }
179        Ok(())
180    }
181}
182
183// ─────────────────────────────────────────────────────────────
184// Per-epoch artifacts — INV-TRAIN-002 / GATE-TRAIN-002
185// ─────────────────────────────────────────────────────────────
186
187/// All 9 required metadata.json fields from the contract.
188#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
189pub struct EpochMetadata {
190    pub epoch: usize,
191    pub train_loss: f32,
192    pub val_loss: f32,
193    pub train_ppl: f32,
194    pub val_ppl: f32,
195    /// sha256 of the on-disk optimizer state (INV-TRAIN-003).
196    pub optimizer_state_sha: String,
197    pub wall_seconds: f32,
198    pub tokens_seen: u64,
199    pub grad_norm_max: f32,
200}
201
202/// Disk layout for one epoch's artifacts — binds to
203/// `per_epoch_artifacts.path_template` in the contract.
204#[derive(Debug, Clone)]
205pub struct EpochArtifact {
206    /// `{run_dir}/ckpt/epoch-{N:03d}.apr`
207    pub checkpoint_path: PathBuf,
208    /// `{run_dir}/ckpt/epoch-{N:03d}.metadata.json`
209    pub metadata_path: PathBuf,
210    pub metadata: EpochMetadata,
211}
212
213impl EpochArtifact {
214    /// Build paths per contract template without performing I/O.
215    pub fn new(run_dir: &Path, epoch: usize, metadata: EpochMetadata) -> Self {
216        let ckpt_dir = run_dir.join("ckpt");
217        let filename = format!("epoch-{epoch:03}.apr");
218        let metafile = format!("epoch-{epoch:03}.metadata.json");
219        Self {
220            checkpoint_path: ckpt_dir.join(filename),
221            metadata_path: ckpt_dir.join(metafile),
222            metadata,
223        }
224    }
225}
226
227// ─────────────────────────────────────────────────────────────
228// Divergence guard — GATE-TRAIN-005 (ship-blocking)
229// ─────────────────────────────────────────────────────────────
230
/// Upper bound on `val_loss[N] / val_loss[N-1]` between consecutive
/// epochs. The contract fixes this at 2.0 and deliberately offers no
/// knob for it — see the `non_divergence.rule` block in
/// `training-loop-pretrain-v1.yaml`.
pub const DIVERGENCE_RATIO_LIMIT: f32 = 2.0;

/// Hard cap on `val_loss[0]` in the finetune regime. The contract
/// literal is 10.0 (MODEL-1 v2 failure mode: val_loss=31.99 at epoch 0
/// even though the base weights were already pretrained). Exposed as a
/// named constant so that `TrainingRegime::Finetune`, downstream
/// callers and tests all refer to the same literal.
pub const EPOCH_ZERO_VAL_LOSS_LIMIT: f32 = 10.0;
242
243/// Training regime selects which epoch-zero val_loss cap INV-TRAIN-005
244/// enforces. The doubling rule at N ≥ 1 is identical across regimes.
245///
246/// Contract: `training-loop-pretrain-v1.yaml` v1.2.0 `non_divergence`.
247///
248/// - [`TrainingRegime::Finetune`] — base weights pretrained; epoch-zero
249///   cap is 10.0, matching the MODEL-1 literal.
250/// - [`TrainingRegime::FromScratch`] — random init; epoch-zero cap is
251///   `2.0 × ln(vocab_size)`, i.e. 2× the uniform-random cross-entropy
252///   baseline. For vocab=50257 the cap is ≈21.64.
253#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
254#[serde(tag = "kind", rename_all = "snake_case")]
255pub enum TrainingRegime {
256    Finetune,
257    FromScratch { vocab_size: u32 },
258}
259
260impl TrainingRegime {
261    /// Regime-dependent cap on `val_loss[0]` — the v1.2.0 amendment.
262    ///
263    /// Finetune: [`EPOCH_ZERO_VAL_LOSS_LIMIT`] (10.0, MODEL-1 literal).
264    /// FromScratch: `DIVERGENCE_RATIO_LIMIT × ln(vocab_size)`. For
265    /// `vocab_size=50257` this is ≈21.64. `vocab_size < 2` is clamped
266    /// to 2 so `ln` stays positive — operator misuse, not worth a panic.
267    pub fn epoch_zero_val_loss_limit(&self) -> f32 {
268        match self {
269            Self::Finetune => EPOCH_ZERO_VAL_LOSS_LIMIT,
270            Self::FromScratch { vocab_size } => {
271                let v = (*vocab_size).max(2) as f32;
272                DIVERGENCE_RATIO_LIMIT * v.ln()
273            }
274        }
275    }
276}
277
278impl Default for TrainingRegime {
279    fn default() -> Self {
280        Self::Finetune
281    }
282}
283
284/// GATE-TRAIN-005 — the non-divergence guard, verbatim from the contract.
285///
286/// For every epoch boundary N ≥ 1, check that `val_loss[N] ≤ 2.0 × val_loss[N-1]`.
287/// For N == 0, check that `val_loss[0]` is finite and within the
288/// regime-dependent cap (see [`TrainingRegime::epoch_zero_val_loss_limit`]).
289/// Any violation returns `Err(PretrainAbort::Divergence{,AtEpochZero})`.
290///
291/// This function is the falsifier harness for `FALSIFY-SHIP-013`:
292/// inject `[3.5, 7.1]` as the val-loss trace, call this on N=1, and the
293/// return value MUST be `Err(Divergence)`. See the unit tests below.
294pub fn check_non_divergence(
295    epoch: usize,
296    val_loss_history: &[f32],
297    regime: &TrainingRegime,
298) -> Result<(), PretrainAbort> {
299    let Some(&curr) = val_loss_history.get(epoch) else {
300        // Nothing at this epoch yet — caller error, not divergence.
301        return Ok(());
302    };
303
304    // Special case N == 0 — regime-dependent cap (v1.2.0).
305    if epoch == 0 {
306        let cap = regime.epoch_zero_val_loss_limit();
307        if !curr.is_finite() || curr > cap {
308            return Err(PretrainAbort::DivergenceAtEpochZero { val_loss: curr });
309        }
310        return Ok(());
311    }
312
313    // N ≥ 1: compare to previous epoch.
314    let prev = val_loss_history[epoch - 1];
315    if !curr.is_finite() {
316        return Err(PretrainAbort::NumericalInstability {
317            step: u64::MAX,
318            field: "val_loss",
319            value: curr,
320        });
321    }
322    let ratio = curr / prev.max(1e-9);
323    if curr > DIVERGENCE_RATIO_LIMIT * prev {
324        return Err(PretrainAbort::Divergence {
325            epoch,
326            prev_val_loss: prev,
327            curr_val_loss: curr,
328            ratio,
329        });
330    }
331    Ok(())
332}
333
334/// INV-TRAIN-007 guard — returns error on first NaN/Inf seen.
335///
336/// Called as a defence-in-depth check at each step in addition to the
337/// per-metric `StepMetrics::validate_finite`. Useful when the caller
338/// has a loss value in hand before it is packaged into a full metrics
339/// struct (e.g. right after the backward pass).
340pub fn check_numerical_stability(
341    step: u64,
342    train_loss: f32,
343    grad_norm: f32,
344) -> Result<(), PretrainAbort> {
345    if !train_loss.is_finite() {
346        return Err(PretrainAbort::NumericalInstability {
347            step,
348            field: "train_loss",
349            value: train_loss,
350        });
351    }
352    if !grad_norm.is_finite() {
353        return Err(PretrainAbort::NumericalInstability {
354            step,
355            field: "grad_norm",
356            value: grad_norm,
357        });
358    }
359    Ok(())
360}
361
362// ─────────────────────────────────────────────────────────────
363// Configuration
364// ─────────────────────────────────────────────────────────────
365
366/// Pretraining configuration — directly maps to CLI flags plus the
367/// convergence-policy block from the contract.
368#[derive(Debug, Clone, Serialize, Deserialize)]
369pub struct PretrainConfig {
370    /// `--dataset` — path to tokenized shard index or raw corpus.
371    pub dataset_path: PathBuf,
372    /// `--tokenizer` — directory containing vocab.json + merges.txt.
373    pub tokenizer_dir: PathBuf,
374    /// `--output-dir` — training run root.
375    pub run_dir: PathBuf,
376    /// Peak learning rate (after warmup).
377    pub lr_max: f32,
378    /// Minimum learning rate at end of cosine decay.
379    pub lr_min: f32,
380    /// Number of warmup steps.
381    pub warmup_steps: usize,
382    /// Total training steps (including warmup).
383    pub total_steps: usize,
384    /// Micro-batch size.
385    pub batch_size: usize,
386    /// Sequence length per example.
387    pub seq_length: usize,
388    /// How many steps per epoch — the driver flushes per-epoch
389    /// artifacts every `steps_per_epoch` steps.
390    pub steps_per_epoch: usize,
391    /// Fixed seed for INV-TRAIN-006 reproducibility.
392    pub seed: u64,
393    /// Gradient-clip max L2 norm (spec default 1.0).
394    pub grad_clip: f32,
395    /// AdamW weight decay.
396    pub weight_decay: f32,
397    /// GATE-TRAIN-003 target final val_loss.
398    pub target_val_loss: f32,
399    /// Patience for convergence / early-stop (contract default 2).
400    pub patience_epochs: usize,
401    /// Minimum epochs before early-stop can trigger (contract default 3).
402    pub min_epochs_before_early_stop: usize,
403    /// INV-TRAIN-005 regime — selects the epoch-zero val_loss cap.
404    /// Defaults to `Finetune` (10.0) to preserve v1.1.0 behavior.
405    #[serde(default)]
406    pub regime: TrainingRegime,
407}
408
409impl PretrainConfig {
410    /// Recipe that aligns with the MODEL-1 v2 post-mortem remedy:
411    /// LR=5e-5, rank=32 (moot here — not LoRA), seed=42. Defaults
412    /// `regime` to `Finetune` (MODEL-1-style). Use
413    /// [`PretrainConfig::from_scratch`] to flip to the MODEL-2 regime.
414    pub fn model_2_defaults(
415        dataset_path: PathBuf,
416        tokenizer_dir: PathBuf,
417        run_dir: PathBuf,
418    ) -> Self {
419        Self {
420            dataset_path,
421            tokenizer_dir,
422            run_dir,
423            lr_max: 5.0e-5,
424            lr_min: 1.0e-6,
425            warmup_steps: 100,
426            total_steps: 1000,
427            batch_size: 16,
428            seq_length: 1024,
429            steps_per_epoch: 100,
430            seed: 42,
431            grad_clip: 1.0,
432            weight_decay: 0.01,
433            target_val_loss: 2.2,
434            patience_epochs: 2,
435            min_epochs_before_early_stop: 3,
436            regime: TrainingRegime::Finetune,
437        }
438    }
439}
440
441// ─────────────────────────────────────────────────────────────
442// PretrainLoop — the driver
443// ─────────────────────────────────────────────────────────────
444
445/// Status returned by [`PretrainLoop::run`].
446#[derive(Debug, Clone, Serialize)]
447pub enum RunStatus {
448    /// Converged at or below `target_val_loss` within budget.
449    Ok { final_val_loss: f32, epochs_completed: usize },
450    /// Cleanly early-stopped after patience exhausted (INV-TRAIN-004).
451    EarlyStop { best_val_loss: f32, epochs_completed: usize },
452    /// Aborted per one of the contract's fatal gates. CLI maps to non-zero exit.
453    Aborted(PretrainAbort),
454}
455
456/// Concrete driver for the 370M pretraining loop.
457///
458/// The model + autograd + optimizer are *injected* by the caller so
459/// this module does not take a hard dependency on a specific model
460/// crate — it is a pure driver around the contract's invariants.
461/// Tests in this module use a deterministic synthetic `StepFn` that
462/// does not require the 370M scaffold at all.
463pub struct PretrainLoop<S: StepFn, V: ValFn> {
464    config: PretrainConfig,
465    rng: StdRng,
466    step_metrics: Vec<StepMetrics>,
467    epoch_artifacts: Vec<EpochArtifact>,
468    val_loss_history: Vec<f32>,
469    tokens_seen: u64,
470    best_val_loss: f32,
471    patience_counter: usize,
472    step_fn: S,
473    val_fn: V,
474    /// Optional per-epoch APR checkpoint writer (task #111 step 7).
475    /// When `Some`, invoked after each epoch's divergence gate passes
476    /// so the artifact on disk is known-good.
477    checkpoint_fn: Option<Box<dyn CheckpointFn>>,
478}
479
/// Abstract per-step computation:
/// `(step, lr, batch_tokens) -> (train_loss, grad_norm)`.
///
/// In production this is wired to model.forward + loss.backward +
/// optimizer.step. The falsification harness injects synthetic traces
/// through the same trait to drive the divergence / NaN paths.
pub trait StepFn {
    /// Performs one optimizer step.
    ///
    /// * `step` — global step index.
    /// * `lr` — learning rate the driver's schedule produced for `step`.
    /// * `batch_tokens` — number of tokens in this micro-batch.
    ///
    /// Returns `(train_loss, grad_norm)` for the step.
    fn step(&mut self, step: u64, lr: f32, batch_tokens: u64) -> (f32, f32);

    /// INV-TRAIN-003 hook: sha256 over the real optimizer state bytes.
    ///
    /// Implementations that own an `AdamW` optimizer (or any optimizer
    /// whose state is deterministic given the seed) should override
    /// this to expose a reproducible digest. The default returns
    /// `None`, in which case the loop substitutes its deterministic
    /// epoch/seed/tokens fingerprint.
    fn optimizer_state_sha256(&self) -> Option<String> {
        None
    }
}
499
/// Per-epoch validation hook.
pub trait ValFn {
    /// Evaluates the held-out set after `epoch` and returns its val_loss.
    fn validate(&mut self, epoch: usize) -> f32;
}
504
505/// Per-epoch checkpoint hook (task #111 step 7).
506///
507/// Invoked by `PretrainLoop::run_epoch` **after** the divergence gate
508/// (GATE-TRAIN-005) has passed for the epoch, so aborted epochs never
509/// produce checkpoint files. The implementation must write to
510/// `artifact.checkpoint_path` (an `.apr` file per the contract's
511/// `per_epoch_artifacts.path_template`). Returning an error does not
512/// abort the loop — it records a warning to stderr and the epoch
513/// artifact is still added to history — so a slow or flaky disk does
514/// not lose training progress.
515pub trait CheckpointFn {
516    fn save(&mut self, epoch: usize, artifact: &EpochArtifact) -> Result<(), String>;
517}
518
519impl<S: StepFn, V: ValFn> PretrainLoop<S, V> {
520    /// Construct a loop with a fixed seed (GATE-TRAIN-006).
521    pub fn new(config: PretrainConfig, step_fn: S, val_fn: V) -> Self {
522        let rng = StdRng::seed_from_u64(config.seed);
523        Self {
524            config,
525            rng,
526            step_metrics: Vec::new(),
527            epoch_artifacts: Vec::new(),
528            val_loss_history: Vec::new(),
529            tokens_seen: 0,
530            best_val_loss: f32::INFINITY,
531            patience_counter: 0,
532            step_fn,
533            val_fn,
534            checkpoint_fn: None,
535        }
536    }
537
538    /// Attach a per-epoch APR checkpoint writer (task #111 step 7).
539    /// Returns `self` for builder-style chaining.
540    #[must_use]
541    pub fn with_checkpoint_fn(mut self, ckpt: Box<dyn CheckpointFn>) -> Self {
542        self.checkpoint_fn = Some(ckpt);
543        self
544    }
545
546    /// Warmup + cosine decay schedule, inline to avoid coupling to any
547    /// specific scheduler type from `optim::scheduler`. Matches the
548    /// `WarmupCosineDecayLR` behavior byte-for-byte.
549    fn lr_at(&self, step: u64) -> f32 {
550        let step = step as usize;
551        let w = self.config.warmup_steps;
552        let total = self.config.total_steps;
553        let lr_max = self.config.lr_max;
554        let lr_min = self.config.lr_min;
555
556        if step < w {
557            if w == 0 {
558                return lr_max;
559            }
560            return lr_max * (step as f32 / w as f32);
561        }
562        let decay_steps = total.saturating_sub(w);
563        if decay_steps == 0 {
564            return lr_min;
565        }
566        let decay_step = step - w;
567        if decay_step >= decay_steps {
568            return lr_min;
569        }
570        let progress = decay_step as f32 / decay_steps as f32;
571        let cosine_decay = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
572        lr_min + (lr_max - lr_min) * cosine_decay
573    }
574
575    /// Execute a single training step. Records metrics into `step_metrics`
576    /// and returns the metric record. Aborts on INV-TRAIN-007/008 violation.
577    pub fn train_step(&mut self, step: u64) -> Result<StepMetrics, PretrainAbort> {
578        let lr = self.lr_at(step);
579        let batch_tokens = (self.config.batch_size * self.config.seq_length) as u64;
580        let t0 = Instant::now();
581        let (train_loss, grad_norm) = self.step_fn.step(step, lr, batch_tokens);
582        let elapsed = t0.elapsed().as_secs_f32().max(1.0e-9);
583
584        // INV-TRAIN-007: abort BEFORE logging a poisoned metric. Logging
585        // first and aborting second would taint the JSONL and make GATE-
586        // TRAIN-007 look clean on a divergent run.
587        check_numerical_stability(step, train_loss, grad_norm)?;
588
589        let tokens_per_sec = batch_tokens as f32 / elapsed;
590        let wall_ms = elapsed * 1000.0;
591        // Synthetic GPU-util: the driver treats real nvml telemetry as
592        // out of scope (that belongs to the monitor module). Clamped to
593        // a contract-legal [0, 100] range, jitter seeded for GATE-TRAIN-006.
594        let gpu_util_pct = 50.0 + (self.rng.random_range(-5.0..5.0) as f32);
595
596        let metrics = StepMetrics {
597            step,
598            train_loss,
599            grad_norm,
600            lr,
601            tokens_per_sec,
602            gpu_util_pct: gpu_util_pct.clamp(0.0, 100.0),
603            wall_ms,
604        };
605        metrics.validate_finite()?;
606
607        self.tokens_seen += batch_tokens;
608        self.step_metrics.push(metrics.clone());
609        Ok(metrics)
610    }
611
612    /// Run one epoch: `steps_per_epoch` train steps, then validation +
613    /// divergence check + epoch artifact.
614    pub fn run_epoch(&mut self, epoch: usize) -> Result<EpochArtifact, PretrainAbort> {
615        let first_step = (epoch * self.config.steps_per_epoch) as u64;
616        let last_step = first_step + self.config.steps_per_epoch as u64;
617
618        let t0 = Instant::now();
619        let mut epoch_loss_sum = 0.0_f32;
620        let mut epoch_grad_norm_max = 0.0_f32;
621        let mut steps_taken = 0_u32;
622
623        for step in first_step..last_step {
624            let m = self.train_step(step)?;
625            epoch_loss_sum += m.train_loss;
626            if m.grad_norm > epoch_grad_norm_max {
627                epoch_grad_norm_max = m.grad_norm;
628            }
629            steps_taken += 1;
630        }
631
632        let mean_train_loss = epoch_loss_sum / steps_taken.max(1) as f32;
633        let val_loss = self.val_fn.validate(epoch);
634
635        // INV-TRAIN-007 on val_loss.
636        if !val_loss.is_finite() {
637            return Err(PretrainAbort::NumericalInstability {
638                step: last_step,
639                field: "val_loss",
640                value: val_loss,
641            });
642        }
643
644        self.val_loss_history.push(val_loss);
645
646        // GATE-TRAIN-005 — ship-blocking divergence guard.
647        // v1.2.0: epoch-zero cap is regime-dependent.
648        check_non_divergence(epoch, &self.val_loss_history, &self.config.regime)?;
649
650        let wall_seconds = t0.elapsed().as_secs_f32();
651        // INV-TRAIN-003: prefer the real AdamW-state digest if the
652        // StepFn exposes one; fall back to a deterministic fingerprint
653        // for synthetic harnesses that do not own an optimizer.
654        let optimizer_state_sha =
655            self.step_fn.optimizer_state_sha256().unwrap_or_else(|| self.fake_optimizer_sha(epoch));
656        let metadata = EpochMetadata {
657            epoch,
658            train_loss: mean_train_loss,
659            val_loss,
660            train_ppl: mean_train_loss.exp(),
661            val_ppl: val_loss.exp(),
662            optimizer_state_sha,
663            wall_seconds,
664            tokens_seen: self.tokens_seen,
665            grad_norm_max: epoch_grad_norm_max,
666        };
667        let artifact = EpochArtifact::new(&self.config.run_dir, epoch, metadata);
668
669        // Task #111 step 7: write the APR checkpoint now that the
670        // divergence gate (GATE-TRAIN-005) has passed. Failures do not
671        // abort the loop so a flaky disk cannot lose training
672        // progress — the artifact is still recorded in history.
673        if let Some(ckpt) = self.checkpoint_fn.as_mut() {
674            if let Some(parent) = artifact.checkpoint_path.parent() {
675                let _ = std::fs::create_dir_all(parent);
676            }
677            if let Err(e) = ckpt.save(epoch, &artifact) {
678                eprintln!("[pretrain] checkpoint write failed for epoch {}: {}", epoch, e);
679            } else {
680                // Also emit the companion metadata.json per contract's
681                // `per_epoch_artifacts.path_template`. Best-effort: a
682                // metadata-write failure is logged but non-fatal.
683                match serde_json::to_string_pretty(&artifact.metadata) {
684                    Ok(json) => {
685                        if let Err(e) = std::fs::write(&artifact.metadata_path, json) {
686                            eprintln!(
687                                "[pretrain] metadata write failed for epoch {}: {}",
688                                epoch, e
689                            );
690                        }
691                    }
692                    Err(e) => eprintln!(
693                        "[pretrain] metadata serialization failed for epoch {}: {}",
694                        epoch, e
695                    ),
696                }
697            }
698        }
699
700        self.epoch_artifacts.push(artifact.clone());
701        Ok(artifact)
702    }
703
704    /// INV-TRAIN-004 convergence/early-stop check. Returns `true` if the
705    /// loop should halt cleanly with early-stop status.
706    pub fn check_convergence(&mut self, epoch: usize) -> bool {
707        let Some(&val_loss) = self.val_loss_history.last() else {
708            return false;
709        };
710        if val_loss < self.best_val_loss {
711            self.best_val_loss = val_loss;
712            self.patience_counter = 0;
713            return false;
714        }
715        self.patience_counter += 1;
716        if epoch + 1 < self.config.min_epochs_before_early_stop {
717            return false;
718        }
719        self.patience_counter > self.config.patience_epochs
720    }
721
722    /// Execute the full pretraining loop. Returns the terminal status.
723    pub fn run(&mut self) -> RunStatus {
724        let num_epochs = self.config.total_steps.div_ceil(self.config.steps_per_epoch.max(1));
725        for epoch in 0..num_epochs {
726            match self.run_epoch(epoch) {
727                Ok(_) => {}
728                Err(abort) => return RunStatus::Aborted(abort),
729            }
730            if self.check_convergence(epoch) {
731                return RunStatus::EarlyStop {
732                    best_val_loss: self.best_val_loss,
733                    epochs_completed: epoch + 1,
734                };
735            }
736            let last = *self.val_loss_history.last().unwrap_or(&f32::INFINITY);
737            if last <= self.config.target_val_loss
738                && epoch + 1 >= self.config.min_epochs_before_early_stop
739            {
740                return RunStatus::Ok { final_val_loss: last, epochs_completed: epoch + 1 };
741            }
742        }
743        let last = *self.val_loss_history.last().unwrap_or(&f32::INFINITY);
744        RunStatus::Ok { final_val_loss: last, epochs_completed: num_epochs }
745    }
746
747    /// Accessors for test / CLI wiring.
748    pub fn step_metrics(&self) -> &[StepMetrics] {
749        &self.step_metrics
750    }
751
752    pub fn epoch_artifacts(&self) -> &[EpochArtifact] {
753        &self.epoch_artifacts
754    }
755
756    pub fn val_loss_history(&self) -> &[f32] {
757        &self.val_loss_history
758    }
759
760    /// INV-TRAIN-003 — sha256 of optimizer state. In the full driver this
761    /// hashes the AdamW m/v buffers; here the driver is model-agnostic,
762    /// so we derive a deterministic sha from epoch + step + config seed
763    /// to keep GATE-TRAIN-006 reproducible. Production wiring will
764    /// replace this with a real hash of the optimizer state bytes.
765    fn fake_optimizer_sha(&self, epoch: usize) -> String {
766        use sha2::{Digest, Sha256};
767        let mut hasher = Sha256::new();
768        hasher.update(b"aprender-train:pretrain:optstate:v1:");
769        hasher.update(self.config.seed.to_le_bytes());
770        hasher.update((epoch as u64).to_le_bytes());
771        hasher.update(self.tokens_seen.to_le_bytes());
772        format!("{:x}", hasher.finalize())
773    }
774}
775
776// ─────────────────────────────────────────────────────────────
777// Test helpers + unit tests
778// ─────────────────────────────────────────────────────────────
779
780/// Synthetic `StepFn` that drives train_loss down linearly — used by the
781/// positive path tests (INV-TRAIN-004, GATE-TRAIN-006).
782pub struct LinearDecaySynthetic {
783    pub start_loss: f32,
784    pub decay_per_step: f32,
785    pub grad_norm: f32,
786}
787
788impl StepFn for LinearDecaySynthetic {
789    fn step(&mut self, step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
790        let loss = (self.start_loss - self.decay_per_step * step as f32).max(1.0e-4);
791        (loss, self.grad_norm)
792    }
793}
794
795/// Synthetic `ValFn` that returns a fixed sequence of epoch val-losses.
796/// The falsification harness uses this to inject a doubling trace
797/// (e.g. `[3.5, 7.1]`) and prove `check_non_divergence` aborts.
798pub struct ScriptedVal {
799    pub sequence: Vec<f32>,
800}
801
802impl ValFn for ScriptedVal {
803    fn validate(&mut self, epoch: usize) -> f32 {
804        *self.sequence.get(epoch).unwrap_or(&f32::NAN)
805    }
806}
807
808/// NaN-injecting synthetic for INV-TRAIN-007 falsification.
809pub struct NanAtStepSynthetic {
810    pub nan_step: u64,
811}
812
813impl StepFn for NanAtStepSynthetic {
814    fn step(&mut self, step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
815        if step == self.nan_step {
816            return (f32::NAN, 1.0);
817        }
818        (1.0, 1.0)
819    }
820}
821
822#[cfg(test)]
823mod tests {
824    use super::*;
825    use std::cell::RefCell;
826    use std::rc::Rc;
827    use tempfile::TempDir;
828
829    fn test_config(tmp: &Path) -> PretrainConfig {
830        PretrainConfig {
831            dataset_path: tmp.join("data.jsonl"),
832            tokenizer_dir: tmp.join("tok"),
833            run_dir: tmp.join("run"),
834            lr_max: 1.0e-4,
835            lr_min: 1.0e-6,
836            warmup_steps: 2,
837            total_steps: 25,
838            batch_size: 2,
839            seq_length: 4,
840            steps_per_epoch: 5,
841            seed: 42,
842            grad_clip: 1.0,
843            weight_decay: 0.01,
844            target_val_loss: 2.2,
845            patience_epochs: 2,
846            min_epochs_before_early_stop: 1,
847            regime: TrainingRegime::Finetune,
848        }
849    }
850
851    // ── GATE-TRAIN-005 falsifier — the MODEL-1 v2 ship-blocker ──
852
853    /// Spec `FALSIFY-SHIP-013` harness: inject a doubling val-loss trace
854    /// and assert `check_non_divergence` returns `Err(Divergence)`.
855    #[test]
856    fn gate_train_005_aborts_on_doubling_val_loss() {
857        let trace = vec![3.5, 7.1];
858        let res = check_non_divergence(1, &trace, &TrainingRegime::Finetune);
859        match res {
860            Err(PretrainAbort::Divergence { epoch, prev_val_loss, curr_val_loss, ratio }) => {
861                assert_eq!(epoch, 1);
862                assert!((prev_val_loss - 3.5).abs() < 1e-6);
863                assert!((curr_val_loss - 7.1).abs() < 1e-6);
864                assert!(ratio > 2.0);
865            }
866            other => panic!("GATE-TRAIN-005 did not abort: got {other:?}"),
867        }
868    }
869
870    /// Special case: val_loss[0] > 10.0 is the MODEL-1 v2 defect (val_loss
871    /// 31.99 at epoch 0). Must abort at epoch 0, without waiting for N=1.
872    #[test]
873    fn gate_train_005_aborts_on_epoch_zero_blowup() {
874        let trace = vec![31.99];
875        let res = check_non_divergence(0, &trace, &TrainingRegime::Finetune);
876        match res {
877            Err(PretrainAbort::DivergenceAtEpochZero { val_loss }) => {
878                assert!((val_loss - 31.99).abs() < 1e-4);
879            }
880            other => panic!("epoch-0 guard missed: got {other:?}"),
881        }
882    }
883
884    /// Healthy trace — must NOT abort. Lower ratios preserve training.
885    #[test]
886    fn gate_train_005_allows_healthy_decrease() {
887        let trace = vec![3.5, 3.0, 2.5, 2.2];
888        for epoch in 0..trace.len() {
889            assert!(check_non_divergence(epoch, &trace, &TrainingRegime::Finetune).is_ok());
890        }
891    }
892
893    /// Boundary case — ratio exactly 2.0 MUST be allowed (strict `>` in contract).
894    #[test]
895    fn gate_train_005_allows_exact_two_x() {
896        let trace = vec![2.0, 4.0];
897        assert!(check_non_divergence(1, &trace, &TrainingRegime::Finetune).is_ok());
898    }
899
900    // ── INV-TRAIN-005 v1.2.0 regime split — from_scratch epoch-zero cap ──
901
902    /// v1.2.0: from_scratch epoch-zero cap is 2·ln(vocab_size). For
903    /// vocab=50257 this is ≈21.64. `val_loss[0]=18.0` is below cap and
904    /// would trip the old 10.0 literal — must be allowed in the new regime.
905    #[test]
906    fn gate_train_005_from_scratch_permits_near_random_baseline() {
907        let trace = vec![18.0_f32];
908        let regime = TrainingRegime::FromScratch { vocab_size: 50_257 };
909        assert!(
910            check_non_divergence(0, &trace, &regime).is_ok(),
911            "val_loss[0]=18 must be within 2·ln(50257)≈21.64 from_scratch cap"
912        );
913        // Same trace under Finetune still aborts — regime split is not a weakening.
914        assert!(matches!(
915            check_non_divergence(0, &trace, &TrainingRegime::Finetune),
916            Err(PretrainAbort::DivergenceAtEpochZero { .. }),
917        ));
918    }
919
920    /// v1.2.0: from_scratch above 2·ln(vocab_size) still aborts. Confirms
921    /// the gate is not degenerate — it catches truly broken inits.
922    #[test]
923    fn gate_train_005_from_scratch_aborts_above_2_ln_vocab() {
924        let trace = vec![25.0_f32];
925        let regime = TrainingRegime::FromScratch { vocab_size: 50_257 };
926        match check_non_divergence(0, &trace, &regime) {
927            Err(PretrainAbort::DivergenceAtEpochZero { val_loss }) => {
928                assert!((val_loss - 25.0).abs() < 1e-4);
929            }
930            other => panic!("from_scratch cap missed: got {other:?}"),
931        }
932    }
933
934    /// v1.2.0: the computed cap matches the formula 2·ln(vocab_size).
935    /// Locks the threshold so a silent code change cannot relax it.
936    #[test]
937    fn training_regime_from_scratch_cap_matches_formula() {
938        let v = 50_257u32;
939        let regime = TrainingRegime::FromScratch { vocab_size: v };
940        let expected = DIVERGENCE_RATIO_LIMIT * (v as f32).ln();
941        assert!(
942            (regime.epoch_zero_val_loss_limit() - expected).abs() < 1e-4,
943            "cap formula drift: got {} expected {}",
944            regime.epoch_zero_val_loss_limit(),
945            expected
946        );
947        // Finetune stays on the MODEL-1 literal.
948        assert!(
949            (TrainingRegime::Finetune.epoch_zero_val_loss_limit() - EPOCH_ZERO_VAL_LOSS_LIMIT)
950                .abs()
951                < 1e-6
952        );
953    }
954
955    // ── GATE-TRAIN-007 falsifier — NaN poisoning ──
956
957    #[test]
958    fn gate_train_007_aborts_on_nan_train_loss() {
959        let res = check_numerical_stability(42, f32::NAN, 1.0);
960        match res {
961            Err(PretrainAbort::NumericalInstability { step, field, .. }) => {
962                assert_eq!(step, 42);
963                assert_eq!(field, "train_loss");
964            }
965            other => panic!("nan guard missed: got {other:?}"),
966        }
967    }
968
969    #[test]
970    fn gate_train_007_aborts_on_inf_grad_norm() {
971        let res = check_numerical_stability(7, 1.0, f32::INFINITY);
972        assert!(matches!(res, Err(PretrainAbort::NumericalInstability { .. })));
973    }
974
975    // ── GATE-TRAIN-001 / INV-TRAIN-008 — metrics validation ──
976
977    #[test]
978    fn step_metrics_validate_finite_accepts_healthy() {
979        let m = StepMetrics {
980            step: 0,
981            train_loss: 3.2,
982            grad_norm: 0.5,
983            lr: 1e-4,
984            tokens_per_sec: 1000.0,
985            gpu_util_pct: 75.0,
986            wall_ms: 5.0,
987        };
988        assert!(m.validate_finite().is_ok());
989    }
990
991    #[test]
992    fn step_metrics_rejects_negative_throughput() {
993        let m = StepMetrics {
994            step: 1,
995            train_loss: 3.2,
996            grad_norm: 0.5,
997            lr: 1e-4,
998            tokens_per_sec: -1.0,
999            gpu_util_pct: 75.0,
1000            wall_ms: 5.0,
1001        };
1002        assert!(matches!(m.validate_finite(), Err(PretrainAbort::ThroughputOutOfRange { .. })));
1003    }
1004
1005    #[test]
1006    fn step_metrics_rejects_gpu_util_over_100() {
1007        let m = StepMetrics {
1008            step: 1,
1009            train_loss: 3.2,
1010            grad_norm: 0.5,
1011            lr: 1e-4,
1012            tokens_per_sec: 1000.0,
1013            gpu_util_pct: 150.0,
1014            wall_ms: 5.0,
1015        };
1016        assert!(matches!(m.validate_finite(), Err(PretrainAbort::ThroughputOutOfRange { .. })));
1017    }
1018
1019    /// Per `contracts/training-loop-pretrain-v1.yaml` v1.5.0:
1020    /// wall_ms must be finite and non-negative.
1021    #[test]
1022    fn step_metrics_rejects_negative_wall_ms() {
1023        let m = StepMetrics {
1024            step: 1,
1025            train_loss: 3.2,
1026            grad_norm: 0.5,
1027            lr: 1e-4,
1028            tokens_per_sec: 1000.0,
1029            gpu_util_pct: 75.0,
1030            wall_ms: -1.0,
1031        };
1032        assert!(matches!(m.validate_finite(), Err(PretrainAbort::ThroughputOutOfRange { .. })));
1033    }
1034
1035    /// wall_ms must be finite (NaN/Inf rejected).
1036    #[test]
1037    fn step_metrics_rejects_nan_wall_ms() {
1038        let m = StepMetrics {
1039            step: 1,
1040            train_loss: 3.2,
1041            grad_norm: 0.5,
1042            lr: 1e-4,
1043            tokens_per_sec: 1000.0,
1044            gpu_util_pct: 75.0,
1045            wall_ms: f32::NAN,
1046        };
1047        assert!(matches!(m.validate_finite(), Err(PretrainAbort::ThroughputOutOfRange { .. })));
1048    }
1049
1050    /// Consistency invariant from contract v1.5.0:
1051    /// `tokens_per_sec * (wall_ms / 1000.0) ≈ batch_tokens` within FP
1052    /// rounding. Both metrics derive from the same `Instant::now()`
1053    /// span so they cannot drift independently.
1054    #[test]
1055    fn step_metrics_wall_ms_consistent_with_tokens_per_sec() {
1056        let batch_tokens: u64 = 1024;
1057        let elapsed_secs: f32 = 0.5;
1058        let tokens_per_sec = batch_tokens as f32 / elapsed_secs;
1059        let wall_ms = elapsed_secs * 1000.0;
1060
1061        let m = StepMetrics {
1062            step: 0,
1063            train_loss: 3.2,
1064            grad_norm: 0.5,
1065            lr: 1e-4,
1066            tokens_per_sec,
1067            gpu_util_pct: 50.0,
1068            wall_ms,
1069        };
1070        assert!(m.validate_finite().is_ok());
1071        let derived_tokens = m.tokens_per_sec * (m.wall_ms / 1000.0);
1072        let diff = (derived_tokens - batch_tokens as f32).abs();
1073        assert!(
1074            diff < 0.5,
1075            "tokens_per_sec * (wall_ms/1000) = {derived_tokens} should equal batch_tokens={batch_tokens} within FP rounding"
1076        );
1077    }
1078
1079    // ── PretrainLoop — driver-level falsifications ──
1080
1081    #[test]
1082    fn pretrain_loop_happy_path_decreasing_loss() {
1083        let tmp = TempDir::new().expect("tempdir");
1084        let cfg = test_config(tmp.path());
1085        let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
1086        let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
1087        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
1088
1089        let status = loop_.run();
1090        match status {
1091            RunStatus::Ok { final_val_loss, epochs_completed } => {
1092                assert!(final_val_loss <= 2.2);
1093                assert!(epochs_completed >= 1);
1094            }
1095            other => panic!("healthy run did not converge cleanly: {other:?}"),
1096        }
1097
1098        // GATE-TRAIN-001 — every step recorded all 6 fields.
1099        assert!(!loop_.step_metrics().is_empty());
1100        for m in loop_.step_metrics() {
1101            assert!(m.train_loss.is_finite());
1102            assert!(m.grad_norm.is_finite());
1103            assert!(m.lr.is_finite());
1104            assert!(m.tokens_per_sec >= 0.0);
1105            assert!((0.0..=100.0).contains(&m.gpu_util_pct));
1106        }
1107        // GATE-TRAIN-002 — one metadata per completed epoch.
1108        assert_eq!(loop_.epoch_artifacts().len(), loop_.val_loss_history().len());
1109        for art in loop_.epoch_artifacts() {
1110            assert!(!art.metadata.optimizer_state_sha.is_empty());
1111            assert!(art.metadata.train_ppl.is_finite());
1112            assert!(art.metadata.val_ppl.is_finite());
1113        }
1114    }
1115
1116    /// INV-TRAIN-005 ship-blocker end-to-end: drive a doubling val-loss
1117    /// through the full `run_epoch` and prove the loop aborts, not a
1118    /// post-hoc audit. This is the falsifier GATE-TRAIN-005 mandates.
1119    #[test]
1120    fn pretrain_loop_aborts_on_doubling_val_loss() {
1121        let tmp = TempDir::new().expect("tempdir");
1122        let cfg = test_config(tmp.path());
1123        let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
1124        // val_loss doubles between epochs 0 and 1 — must abort.
1125        let val_fn = ScriptedVal { sequence: vec![3.5, 7.1, 2.0] };
1126        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
1127
1128        let status = loop_.run();
1129        match status {
1130            RunStatus::Aborted(PretrainAbort::Divergence { epoch, ratio, .. }) => {
1131                assert_eq!(epoch, 1);
1132                assert!(ratio > 2.0);
1133            }
1134            other => panic!("GATE-TRAIN-005 did not fire: {other:?}"),
1135        }
1136    }
1137
1138    /// INV-TRAIN-007 end-to-end: NaN in train_loss at step N aborts.
1139    #[test]
1140    fn pretrain_loop_aborts_on_nan_in_train_loss() {
1141        let tmp = TempDir::new().expect("tempdir");
1142        let cfg = test_config(tmp.path());
1143        let step_fn = NanAtStepSynthetic { nan_step: 3 };
1144        let val_fn = ScriptedVal { sequence: vec![3.0] };
1145        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
1146
1147        let status = loop_.run();
1148        match status {
1149            RunStatus::Aborted(PretrainAbort::NumericalInstability { step, field, .. }) => {
1150                assert_eq!(step, 3);
1151                assert_eq!(field, "train_loss");
1152            }
1153            other => panic!("INV-TRAIN-007 did not fire: {other:?}"),
1154        }
1155    }
1156
1157    /// INV-TRAIN-006: two runs with the same seed produce identical metrics
1158    /// for the first 100 steps. We use 10 steps here to keep the unit test
1159    /// fast; CI GATE-TRAIN-006 runs the full 100.
1160    #[test]
1161    fn pretrain_loop_reproducibility_seed_42() {
1162        let tmp1 = TempDir::new().expect("tempdir1");
1163        let tmp2 = TempDir::new().expect("tempdir2");
1164        let cfg1 = test_config(tmp1.path());
1165        let cfg2 = test_config(tmp2.path());
1166
1167        let step_fn1 =
1168            LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
1169        let step_fn2 =
1170            LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
1171        let val_fn1 = ScriptedVal { sequence: vec![3.0, 2.8, 2.6, 2.4, 2.2] };
1172        let val_fn2 = ScriptedVal { sequence: vec![3.0, 2.8, 2.6, 2.4, 2.2] };
1173
1174        let mut loop1 = PretrainLoop::new(cfg1, step_fn1, val_fn1);
1175        let mut loop2 = PretrainLoop::new(cfg2, step_fn2, val_fn2);
1176        let _ = loop1.run();
1177        let _ = loop2.run();
1178
1179        assert_eq!(loop1.step_metrics().len(), loop2.step_metrics().len());
1180        for (a, b) in loop1.step_metrics().iter().zip(loop2.step_metrics().iter()) {
1181            assert_eq!(a.step, b.step);
1182            assert!((a.train_loss - b.train_loss).abs() < 1e-6);
1183            assert!((a.grad_norm - b.grad_norm).abs() < 1e-6);
1184            assert!((a.lr - b.lr).abs() < 1e-6);
1185            // gpu_util_pct is RNG-driven; seed-matched ⇒ byte-identical.
1186            assert!((a.gpu_util_pct - b.gpu_util_pct).abs() < 1e-6);
1187        }
1188    }
1189
1190    /// `lr_at` must match WarmupCosineDecayLR behavior byte-for-byte at
1191    /// the boundary points (start of warmup, end of warmup, end of decay).
1192    #[test]
1193    fn lr_schedule_warmup_cosine_boundaries() {
1194        let tmp = TempDir::new().expect("tempdir");
1195        let cfg = PretrainConfig {
1196            warmup_steps: 10,
1197            total_steps: 100,
1198            lr_max: 1.0e-3,
1199            lr_min: 1.0e-5,
1200            ..test_config(tmp.path())
1201        };
1202        let step_fn = LinearDecaySynthetic { start_loss: 1.0, decay_per_step: 0.0, grad_norm: 0.1 };
1203        let val_fn = ScriptedVal { sequence: vec![1.0] };
1204        let loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
1205
1206        // Start of warmup — lr should be 0.
1207        assert!((loop_.lr_at(0) - 0.0).abs() < 1e-9);
1208        // End of warmup — lr at peak.
1209        assert!((loop_.lr_at(10) - 1.0e-3).abs() < 1e-6);
1210        // End of decay — lr at minimum.
1211        assert!((loop_.lr_at(100) - 1.0e-5).abs() < 1e-6);
1212    }
1213
1214    /// Artifact paths must match the contract's `path_template`.
1215    #[test]
1216    fn epoch_artifact_paths_match_contract_template() {
1217        let tmp = TempDir::new().expect("tempdir");
1218        let run_dir = tmp.path().join("run");
1219        let metadata = EpochMetadata {
1220            epoch: 7,
1221            train_loss: 3.0,
1222            val_loss: 2.8,
1223            train_ppl: 20.0,
1224            val_ppl: 16.4,
1225            optimizer_state_sha: "deadbeef".into(),
1226            wall_seconds: 42.0,
1227            tokens_seen: 1_000_000,
1228            grad_norm_max: 1.5,
1229        };
1230        let art = EpochArtifact::new(&run_dir, 7, metadata);
1231        assert!(art.checkpoint_path.ends_with("ckpt/epoch-007.apr"));
1232        assert!(art.metadata_path.ends_with("ckpt/epoch-007.metadata.json"));
1233    }
1234
1235    // ── Task #111 step 7 — CheckpointFn hook falsifiers ──
1236
1237    /// Mock `CheckpointFn` that records every (epoch, checkpoint_path) pair
1238    /// so tests can assert the loop invokes the hook exactly once per
1239    /// passing epoch and never on an aborted epoch.
1240    struct RecordingCheckpointFn {
1241        calls: Rc<RefCell<Vec<(usize, PathBuf)>>>,
1242    }
1243
1244    impl CheckpointFn for RecordingCheckpointFn {
1245        fn save(&mut self, epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
1246            self.calls.borrow_mut().push((epoch, artifact.checkpoint_path.clone()));
1247            Ok(())
1248        }
1249    }
1250
1251    /// INV-TRAIN-005 positive: one checkpoint call per epoch that
1252    /// passes GATE-TRAIN-005, with metadata.json emitted alongside.
1253    #[test]
1254    fn pretrain_loop_calls_checkpoint_fn_once_per_passing_epoch() {
1255        let tmp = TempDir::new().expect("tempdir");
1256        let cfg = test_config(tmp.path());
1257        let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
1258        let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
1259        let calls: Rc<RefCell<Vec<(usize, PathBuf)>>> = Rc::new(RefCell::new(Vec::new()));
1260        let ckpt = RecordingCheckpointFn { calls: Rc::clone(&calls) };
1261
1262        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn).with_checkpoint_fn(Box::new(ckpt));
1263        let _status = loop_.run();
1264
1265        let recorded = calls.borrow();
1266        let epoch_count = loop_.epoch_artifacts().len();
1267        assert!(epoch_count >= 1, "at least one epoch should have completed");
1268        assert_eq!(
1269            recorded.len(),
1270            epoch_count,
1271            "CheckpointFn must fire exactly once per epoch that passes GATE-TRAIN-005",
1272        );
1273        for (i, (epoch, path)) in recorded.iter().enumerate() {
1274            assert_eq!(*epoch, i, "checkpoint hook epoch indices must be monotonic from 0");
1275            assert!(
1276                path.to_string_lossy().contains(&format!("epoch-{:03}.apr", epoch)),
1277                "checkpoint path must match contract template: {:?}",
1278                path,
1279            );
1280            let meta_path = path.with_extension("metadata.json");
1281            assert!(
1282                meta_path.exists(),
1283                "companion metadata.json must be written for epoch {}",
1284                epoch,
1285            );
1286        }
1287    }
1288
1289    /// INV-TRAIN-003 positive: if the StepFn overrides
1290    /// `optimizer_state_sha256`, the loop uses it instead of the
1291    /// synthetic-seed fallback. Asserts that the recorded epoch
1292    /// metadata carries the sha from the StepFn.
1293    #[test]
1294    fn pretrain_loop_uses_step_fn_optimizer_sha_when_available() {
1295        struct ShaOverride {
1296            inner: LinearDecaySynthetic,
1297            sha: String,
1298        }
1299        impl StepFn for ShaOverride {
1300            fn step(&mut self, s: u64, lr: f32, tokens: u64) -> (f32, f32) {
1301                self.inner.step(s, lr, tokens)
1302            }
1303            fn optimizer_state_sha256(&self) -> Option<String> {
1304                Some(self.sha.clone())
1305            }
1306        }
1307
1308        let tmp = TempDir::new().expect("tempdir");
1309        let cfg = test_config(tmp.path());
1310        let step_fn = ShaOverride {
1311            inner: LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 },
1312            sha: "a".repeat(64),
1313        };
1314        let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
1315        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
1316        let _ = loop_.run();
1317
1318        let arts = loop_.epoch_artifacts();
1319        assert!(!arts.is_empty(), "at least one epoch should have completed");
1320        for art in arts {
1321            assert_eq!(
1322                art.metadata.optimizer_state_sha,
1323                "a".repeat(64),
1324                "StepFn override must win over fake_optimizer_sha fallback",
1325            );
1326        }
1327    }
1328
1329    /// INV-TRAIN-003 fallback: a synthetic StepFn that does not
1330    /// override `optimizer_state_sha256` still gets a non-empty,
1331    /// deterministic 64-char digest via the `fake_optimizer_sha`
1332    /// fingerprint. (The default impl returns `None`.)
1333    #[test]
1334    fn pretrain_loop_falls_back_to_fake_optimizer_sha_for_synthetic() {
1335        let tmp = TempDir::new().expect("tempdir");
1336        let cfg = test_config(tmp.path());
1337        let step_fn = LinearDecaySynthetic { start_loss: 3.5, decay_per_step: 0.1, grad_norm: 0.8 };
1338        let val_fn = ScriptedVal { sequence: vec![3.4, 3.0, 2.6, 2.2, 2.0] };
1339        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn);
1340        let _ = loop_.run();
1341
1342        for art in loop_.epoch_artifacts() {
1343            assert_eq!(
1344                art.metadata.optimizer_state_sha.len(),
1345                64,
1346                "fallback fingerprint must still be a 64-char hex digest",
1347            );
1348            assert!(
1349                art.metadata.optimizer_state_sha.chars().all(|c| c.is_ascii_hexdigit()),
1350                "fallback fingerprint must be lowercase hex",
1351            );
1352        }
1353    }
1354
1355    /// INV-TRAIN-007 negative: NaN in train_loss aborts the loop, and
1356    /// the checkpoint hook must NOT fire for the aborted epoch.
1357    #[test]
1358    fn pretrain_loop_skips_checkpoint_on_abort() {
1359        let tmp = TempDir::new().expect("tempdir");
1360        let cfg = test_config(tmp.path());
1361        let step_fn = NanAtStepSynthetic { nan_step: 1 };
1362        let val_fn = ScriptedVal { sequence: vec![3.0] };
1363        let calls: Rc<RefCell<Vec<(usize, PathBuf)>>> = Rc::new(RefCell::new(Vec::new()));
1364        let ckpt = RecordingCheckpointFn { calls: Rc::clone(&calls) };
1365
1366        let mut loop_ = PretrainLoop::new(cfg, step_fn, val_fn).with_checkpoint_fn(Box::new(ckpt));
1367        let status = loop_.run();
1368
1369        assert!(
1370            matches!(status, RunStatus::Aborted(PretrainAbort::NumericalInstability { .. })),
1371            "NaN must abort the loop: got {status:?}",
1372        );
1373        assert!(
1374            calls.borrow().is_empty(),
1375            "CheckpointFn must NOT fire when the epoch aborts before GATE-TRAIN-005 passes",
1376        );
1377    }
1378}