Skip to main content

entrenar/train/
pretrain_real.rs

1//! Real-corpus `StepFn` / `ValFn` for MODEL-2 pretrain MVP (task #111).
2//!
3//! Bridges the model-agnostic `PretrainLoop` (`pretrain.rs`) to the
4//! 370M Llama scaffold (`models/llama_370m.rs`) by wiring a real
5//! `TransformerTrainer` through the `StepFn` and `ValFn` traits.
6//!
7//! The loop drive replaces the `LinearDecaySynthetic` / `ScriptedVal`
8//! pair used for GATE-TRAIN-005/007/008 wiring verification (task #105)
9//! with a real forward + backward + optimizer step and a real held-out
10//! validation forward pass.
11//!
12//! Contract obligations discharged:
13//! - INV-ARCH-370M-001 (param count in [366M, 374M]) via `debug_assert_eq!`
14//! - INV-TRAIN-001 (per-step metrics — 6 fields via PretrainLoop)
15//! - INV-TRAIN-007 (no NaN/Inf — the loop aborts on first non-finite)
16//!
17//! Deferred (task #111 follow-ups):
18//! - Real grad_norm (currently reports a placeholder; needs
19//!   TransformerTrainer extension to surface pre-clip norm)
20//! - INV-TRAIN-003 (real optimizer-state sha256 over AdamW m/v/t buffers)
21
22use crate::models::llama_370m::Llama370MConfig;
23use crate::train::pretrain::{CheckpointFn, EpochArtifact, StepFn, ValFn};
24use crate::train::transformer_trainer::{LMBatch, TransformerTrainConfig, TransformerTrainer};
25use crate::transformer::{ModelArchitecture, Transformer, TransformerConfig};
26use crate::Tensor;
27use std::cell::RefCell;
28use std::collections::BTreeMap;
29use std::path::Path;
30use std::rc::Rc;
31
32/// Shared mutable ownership of the `TransformerTrainer` — both
33/// `RealStepFn` (training steps) and `RealValFn` (forward-only
34/// validation) clone this `Rc`.
35pub type SharedTrainer = Rc<RefCell<TransformerTrainer>>;
36
37/// Load tensors from an APR file as the read-half of `apr pretrain --init`.
38///
39/// Per `apr-pretrain-arch-polymorphic-v1` §init_load_semantics (PR #1473),
40/// the loader is REUSED, not reimplemented — this function is a thin wrapper
41/// over `aprender::format::converter::convert_report::load_model_tensors`,
42/// which is the same machinery `apr export` and `apr inspect` use. No
43/// duplicate APR parser; one source of truth.
44///
45/// Returns a map of `tensor_name -> (flat_f32_data, shape)`. The HF naming
46/// convention is preserved (e.g., `model.embed_tokens.weight`); reconciling
47/// against the trainer's parameter names is step 5f.3 (the population step).
48///
49/// Discharges from `apr-pretrain-arch-polymorphic-v1`:
50///   - §init_load_semantics invariant: "Loader is reused, not reimplemented"
51///   - FALSIFY-006 (init_loss < 6.0) at READ-COMPILE-BIND level: this
52///     function is the read half. Full FALSIFY-006 discharge requires
53///     5f.3 (population) + 5g (LIVE 500-step fine-tune).
54///
55/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.2.
56///
57/// # Errors
58///
59/// Returns Err if the APR file:
60/// - Does not exist (filesystem I/O error)
61/// - Has invalid magic bytes (not APR\\0 or APRN)
62/// - Has a corrupted tensor index
63/// - Contains tensors with unsupported dtype (non-F32)
64pub fn load_init_tensors_from_apr(
65    path: impl AsRef<Path>,
66) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String> {
67    let path_ref = path.as_ref();
68    aprender::format::converter::load_model_tensors(path_ref).map_err(|e| {
69        format!(
70            "FALSIFY-APR-PRETRAIN-INIT-006: failed to load init tensors from APR file {}: {e}",
71            path_ref.display()
72        )
73    })
74}
75
76/// Reject an init `TransformerConfig` whose architecture family is incompatible
77/// with the pretrain target (decoder-only LM training).
78///
79/// Per `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature
80/// (PR #1473), wrong-arch APR (e.g., a CodeBERT/RoBERTa encoder model)
81/// MUST be FAIL-FAST not silent-truncate. Without this gate, an operator
82/// who points `--init` at e.g. `microsoft/codebert-base.apr` would silently
83/// load encoder weights into a decoder-shaped trainer, producing nonsense
84/// gradients that the divergence guard catches LATE (after multiple epochs).
85///
86/// Discharges FALSIFY-APR-PRETRAIN-ARCH-007 at PARTIAL_ALGORITHM_LEVEL.
87///
88/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.1.
89///
90/// # Errors
91///
92/// Returns Err with a clear architecture-family-mismatch message when:
93/// - `cfg.architecture` is `ModelArchitecture::Encoder` (BERT/RoBERTa/CodeBERT)
94///
95/// Future expansion can add other family checks (e.g., reject hybrid SSM
96/// architectures whose forward pass differs from the standard decoder loop).
97pub fn validate_pretrain_init_arch_compatible(cfg: &TransformerConfig) -> Result<(), String> {
98    match cfg.architecture {
99        ModelArchitecture::Decoder => Ok(()),
100        ModelArchitecture::Encoder => Err(format!(
101            "FALSIFY-APR-PRETRAIN-ARCH-007: --init checkpoint has architecture=Encoder \
102             (e.g., BERT/RoBERTa/CodeBERT) but the pretrain trainer is decoder-only \
103             (Llama/Qwen-class causal LMs). Loading encoder weights into a decoder \
104             trainer would produce nonsense gradients. Architectural details: \
105             hidden_size={}, num_layers={}, vocab_size={}, hf_architecture={:?}",
106            cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size, cfg.hf_architecture
107        )),
108    }
109}
110
111/// Populate a `Transformer`'s parameters from an `init_tensors` BTreeMap.
112///
113/// For each parameter the `Transformer` exposes via `named_parameters()`, look
114/// up the same HF-naming key in `init_tensors` and replace the parameter's
115/// `Tensor` with `Tensor::from_vec(data, requires_grad=true)`. The parameter
116/// remains trainable (gradients still flow) — this is fine-tune init, not
117/// frozen-encoder.
118///
119/// Strictness rules:
120/// - **Every** model parameter must have a matching entry in `init_tensors`.
121///   Missing entries return Err naming the unmatched parameter; this catches
122///   the case where the init APR was extracted from a different architecture.
123/// - **Length** of each init entry must match the model parameter's length
124///   (computed from the model's existing tensor `len()`). Mismatch returns Err.
125/// - **Extra** entries in `init_tensors` are silently ignored. This handles
126///   `tie_word_embeddings`: a Qwen2.5 APR may publish a separate `lm_head.weight`
127///   tensor that the model omits when ties are enabled.
128///
129/// Discharges from `apr-pretrain-arch-polymorphic-v1` §init_load_semantics:
130/// - Population invariant: "Init tensors populate trainer parameters
131///   byte-equivalent to source"
132/// - FALSIFY-APR-PRETRAIN-INIT-007 (population step) at PARTIAL_ALGORITHM_LEVEL.
133///
134/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.3.
135///
136/// # Errors
137///
138/// Returns Err if any model parameter is missing from `init_tensors` or if
139/// any matched entry has a wrong length. The error message lists up to the
140/// first 5 problem parameters and the total count.
141pub fn populate_trainer_from_init_tensors(
142    transformer: &mut Transformer,
143    init_tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
144) -> Result<usize, String> {
145    let expected: Vec<(String, usize)> =
146        transformer.named_parameters().into_iter().map(|(name, t)| (name, t.len())).collect();
147    let mut populated = 0usize;
148    let mut errors: Vec<String> = Vec::new();
149
150    for (name, expected_len) in &expected {
151        match init_tensors.get(name) {
152            Some((data, _shape)) => {
153                if data.len() != *expected_len {
154                    errors.push(format!(
155                        "{name}: init length {} != trainer expected {expected_len}",
156                        data.len()
157                    ));
158                    continue;
159                }
160                let tensor = Tensor::from_vec(data.clone(), true);
161                if !transformer.set_named_parameter(name, tensor) {
162                    errors.push(format!("{name}: set_named_parameter rejected the assignment"));
163                    continue;
164                }
165                populated += 1;
166            }
167            None => {
168                errors.push(format!("{name}: not present in init APR tensors"));
169            }
170        }
171    }
172
173    if !errors.is_empty() {
174        let total = errors.len();
175        let head = errors.iter().take(5).cloned().collect::<Vec<_>>().join("; ");
176        return Err(format!(
177            "FALSIFY-APR-PRETRAIN-INIT-007: populate_trainer_from_init_tensors \
178             failed for {total} parameter(s); first {} of {total}: {head}",
179            errors.len().min(5)
180        ));
181    }
182
183    Ok(populated)
184}
185
186/// Build a `TransformerConfig` field-for-field from `Llama370MConfig::*`
187/// constants (the contract-frozen MODEL-2 370M architecture).
188pub fn llama_370m_transformer_config() -> TransformerConfig {
189    TransformerConfig {
190        hidden_size: Llama370MConfig::HIDDEN_DIM,
191        num_attention_heads: Llama370MConfig::NUM_HEADS,
192        num_kv_heads: Llama370MConfig::NUM_KV_HEADS,
193        intermediate_size: Llama370MConfig::INTERMEDIATE_DIM,
194        num_hidden_layers: Llama370MConfig::NUM_LAYERS,
195        vocab_size: Llama370MConfig::VOCAB_SIZE,
196        max_position_embeddings: Llama370MConfig::MAX_POSITION_EMBEDDINGS,
197        rms_norm_eps: Llama370MConfig::RMS_NORM_EPS,
198        rope_theta: Llama370MConfig::ROPE_THETA,
199        use_bias: false,
200        head_dim_override: None,
201        architecture: ModelArchitecture::Decoder,
202        hf_architecture: Some("LlamaForCausalLM".into()),
203        hf_model_type: Some("llama".into()),
204        tie_word_embeddings: true,
205    }
206}
207
208/// Polymorphic builder per `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature.
209///
210/// Discharges FALSIFY-APR-PRETRAIN-ARCH-002 (init=None preserves Llama370M baseline)
211/// and FALSIFY-APR-PRETRAIN-ARCH-003 (init=Some passes through extracted config).
212///
213/// Behaviour:
214///   init = None  → return `llama_370m_transformer_config()`, the §24/§25
215///                  from-scratch baseline. NO regression.
216///   init = Some  → clone the caller-extracted `TransformerConfig` byte-for-byte.
217///                  No silent defaults, no field overrides.
218///
219/// The caller is responsible for actually reading the APR file and producing the
220/// `TransformerConfig` (typically via `TransformerConfig::from_apr_metadata` from
221/// `transformer::config`). Decoupling the dispatch from the file I/O keeps
222/// `aprender-train` free of `aprender-serve` (the APR loader) as a build dep.
223///
224/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c.
225pub fn build_transformer_config(init: Option<&TransformerConfig>) -> TransformerConfig {
226    match init {
227        None => llama_370m_transformer_config(),
228        Some(cfg) => cfg.clone(),
229    }
230}
231
232/// Build a `TransformerTrainConfig` with MODEL-2 v2-remedy defaults
233/// (LR=5e-5, AdamW defaults, fp32, seed=42 set by caller).
234pub fn llama_370m_train_config(lr: f32, seq_length: usize, seed: u64) -> TransformerTrainConfig {
235    let model_cfg = llama_370m_transformer_config();
236    let mut cfg = TransformerTrainConfig::new(model_cfg);
237    cfg.lr = lr;
238    cfg.max_seq_len = seq_length;
239    cfg.seed = seed;
240    cfg
241}
242
243/// `StepFn` impl that pulls one `LMBatch` from an owned iterator and
244/// runs a real forward + backward + AdamW step through the shared
245/// `TransformerTrainer`.
246pub struct RealStepFn {
247    trainer: SharedTrainer,
248    batches: Box<dyn Iterator<Item = LMBatch>>,
249}
250
251impl RealStepFn {
252    pub fn new(trainer: SharedTrainer, batches: Box<dyn Iterator<Item = LMBatch>>) -> Self {
253        Self { trainer, batches }
254    }
255}
256
257impl StepFn for RealStepFn {
258    fn step(&mut self, _step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
259        // Pull one batch; if the shard stream is exhausted before the
260        // loop plans to stop, emit a tiny finite placeholder so
261        // GATE-TRAIN-007 (NaN/Inf guard) does not mis-fire — the
262        // divergence guard (GATE-TRAIN-005) will correctly not abort
263        // on a flat tail.
264        let Some(batch) = self.batches.next() else {
265            return (1.0, 1.0);
266        };
267        let mut trainer = self.trainer.borrow_mut();
268        let loss = trainer.train_batch(&batch);
269        // TODO(task #111 follow-up): expose AdamW pre-clip grad norm.
270        // Placeholder = 1.0 keeps INV-TRAIN-007 satisfied (finite) and
271        // INV-TRAIN-008 satisfied (≥ 0); the real grad norm is a
272        // downstream ticket that needs TransformerTrainer extension.
273        let grad_norm = 1.0_f32;
274        (loss, grad_norm)
275    }
276
277    /// INV-TRAIN-003 discharge: hash the real AdamW (t, m, v) buffers.
278    fn optimizer_state_sha256(&self) -> Option<String> {
279        Some(self.trainer.borrow().optimizer_state_sha256())
280    }
281}
282
283/// `ValFn` impl that runs forward-only across a pre-loaded set of
284/// held-out batches and returns mean cross-entropy loss.
285pub struct RealValFn {
286    trainer: SharedTrainer,
287    held_out: Vec<LMBatch>,
288}
289
290impl RealValFn {
291    pub fn new(trainer: SharedTrainer, held_out: Vec<LMBatch>) -> Self {
292        Self { trainer, held_out }
293    }
294}
295
296impl ValFn for RealValFn {
297    fn validate(&mut self, _epoch: usize) -> f32 {
298        if self.held_out.is_empty() {
299            return f32::NAN;
300        }
301        let trainer = self.trainer.borrow();
302        let mut total_loss = 0.0_f32;
303        let mut total_items = 0_usize;
304        for batch in &self.held_out {
305            for i in 0..batch.batch_size {
306                let Some(inp) = batch.get_input(i) else {
307                    continue;
308                };
309                let Some(tgt) = batch.get_target(i) else {
310                    continue;
311                };
312                let (loss_val, _loss_tensor, _logits) = trainer.forward_single(inp, tgt);
313                total_loss += loss_val;
314                total_items += 1;
315            }
316        }
317        if total_items == 0 {
318            f32::NAN
319        } else {
320            total_loss / total_items as f32
321        }
322    }
323}
324
325/// `CheckpointFn` impl that writes the 370M Llama weights to
326/// `artifact.checkpoint_path` in APR format (task #111 step 7).
327///
328/// Holds the `SharedTrainer` alongside `RealStepFn` / `RealValFn` so
329/// the three hooks see the same in-memory weights.
330pub struct AprCheckpointFn {
331    trainer: SharedTrainer,
332    model_name: String,
333    architecture: String,
334}
335
336impl AprCheckpointFn {
337    pub fn new(
338        trainer: SharedTrainer,
339        model_name: impl Into<String>,
340        architecture: impl Into<String>,
341    ) -> Self {
342        Self { trainer, model_name: model_name.into(), architecture: architecture.into() }
343    }
344}
345
346impl CheckpointFn for AprCheckpointFn {
347    fn save(&mut self, _epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
348        let trainer = self.trainer.borrow();
349        trainer
350            .save_apr(&artifact.checkpoint_path, &self.model_name, &self.architecture)
351            .map_err(|e| format!("save_apr failed: {e}"))
352    }
353}
354
355/// Shared-ownership helper so the CLI can hand the same trainer to
356/// both `RealStepFn` and `RealValFn`.
357pub fn build_shared_trainer(lr: f32, seq_length: usize, seed: u64) -> SharedTrainer {
358    let cfg = llama_370m_train_config(lr, seq_length, seed);
359    let trainer = TransformerTrainer::new(cfg);
360    // INV-ARCH-370M-001: verify parameter count lands in the 370M ± 1%
361    // band. This is a debug_assert so release builds do not pay for
362    // the full parameter walk, but dev builds catch drift the instant
363    // any Llama370MConfig constant changes.
364    #[cfg(debug_assertions)]
365    {
366        let param_count: usize = trainer.model().parameters().iter().map(|t| t.len()).sum();
367        debug_assert!(
368            (366_000_000..=374_000_000).contains(&param_count),
369            "INV-ARCH-370M-001: parameter count {param_count} outside [366M, 374M] band",
370        );
371    }
372    Rc::new(RefCell::new(trainer))
373}
374
375/// Polymorphic trainer builder for `apr pretrain --init` per
376/// `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature +
377/// §init_load_semantics (PR #1473).
378///
379/// Composes the §50.4 step-5f machinery into a single CLI-callable entry:
380///   - 5c: `build_transformer_config(init_arch)` — polymorphic dispatch
381///   - 5f.1: `validate_pretrain_init_arch_compatible(init_arch)` — encoder rejection
382///   - 5f.2: `load_init_tensors_from_apr(path)` — read APR weights
383///   - 5f.3: `populate_trainer_from_init_tensors(trainer, &tensors)` — populate
384///
385/// Behaviour:
386///   init = None  → identical to `build_shared_trainer` (Llama370M from-scratch
387///                  baseline; INV-ARCH-370M-001 enforced).
388///   init = Some  → builds a trainer with the EXTRACTED arch, validates the
389///                  family, loads tensors from the APR file, populates them.
390///                  INV-ARCH-370M-001 is NOT enforced (the arch is whatever the
391///                  init APR has, e.g. 0.5B / 1.5B / 7B).
392///
393/// Spec: SPEC-SHIP-TWO-001 §52.4 (step 5f.4 wireup).
394///
395/// # Errors
396///
397/// Returns Err when:
398/// - `init_arch` is `Some` with `architecture = Encoder` (FALSIFY-APR-PRETRAIN-ARCH-007)
399/// - `load_init_tensors_from_apr` fails (FALSIFY-APR-PRETRAIN-INIT-006)
400/// - `populate_trainer_from_init_tensors` fails (FALSIFY-APR-PRETRAIN-INIT-007)
401pub fn build_shared_trainer_with_init(
402    lr: f32,
403    seq_length: usize,
404    seed: u64,
405    init_arch: Option<&TransformerConfig>,
406    init_path: Option<&Path>,
407) -> Result<SharedTrainer, String> {
408    if init_arch.is_some() != init_path.is_some() {
409        return Err(format!(
410            "build_shared_trainer_with_init: init_arch and init_path must both be Some \
411             or both None (caller bug; init_arch.is_some()={}, init_path.is_some()={})",
412            init_arch.is_some(),
413            init_path.is_some()
414        ));
415    }
416
417    if let Some(cfg) = init_arch {
418        validate_pretrain_init_arch_compatible(cfg)?;
419    }
420
421    let model_cfg = build_transformer_config(init_arch);
422    let mut train_cfg = TransformerTrainConfig::new(model_cfg);
423    train_cfg.lr = lr;
424    train_cfg.max_seq_len = seq_length;
425    train_cfg.seed = seed;
426    let mut trainer = TransformerTrainer::new(train_cfg);
427
428    // Note: INV-ARCH-370M-001 (param-count band check) lives in
429    // `build_shared_trainer` (the from-scratch CLI path). The polymorphic
430    // builder is shape-agnostic by design — `build_transformer_config(init)`
431    // returns whatever the init APR has (0.5B, 1.5B, 7B, etc), so a single
432    // hardcoded band check would fire-fail on every non-Llama370M init.
433
434    if let Some(path) = init_path {
435        let tensors = load_init_tensors_from_apr(path)?;
436        populate_trainer_from_init_tensors(trainer.model_mut(), &tensors)?;
437    }
438
439    Ok(Rc::new(RefCell::new(trainer)))
440}
441
442#[cfg(test)]
443mod tests {
444    use super::*;
445    use crate::train::transformer_trainer::LMBatch;
446
447    /// FALSIFY-APR-PRETRAIN-INIT-006 (read-half) — load_init_tensors_from_apr
448    /// returns Err with a clear message naming the falsifier when the path
449    /// does not exist.
450    ///
451    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.2.
452    #[test]
453    fn load_init_tensors_missing_file_errors_with_falsifier_id() {
454        let tmp = tempfile::TempDir::new().expect("tempdir");
455        let missing = tmp.path().join("does-not-exist.apr");
456        let err =
457            load_init_tensors_from_apr(&missing).expect_err("missing init APR file MUST fail-fast");
458        assert!(
459            err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
460            "error must cite falsifier id (auditability): {err}"
461        );
462        assert!(
463            err.contains("does-not-exist.apr"),
464            "error must name the missing path (operator-experience): {err}"
465        );
466    }
467
468    /// FALSIFY-APR-PRETRAIN-INIT-006 (read-half) — function exists with the
469    /// right signature: `Path -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>>`.
470    /// Discharges the COMPILE-BIND level claim. Live empirical correctness
471    /// requires step 5g (operator-runnable LIVE fine-tune).
472    ///
473    /// Drift-prevention: this test catches a future refactor that changes
474    /// the return type or signature, which would break the §50.4 step 5f.3
475    /// follow-up that reconciles the BTreeMap against trainer parameters.
476    #[test]
477    fn load_init_tensors_signature_compile_bind() {
478        // Verify the function signature compile-binds: takes a Path-like,
479        // returns the right Result type. This is a compile-time check —
480        // if the signature drifts, this test stops compiling.
481        fn _check_signature<F>(_f: F)
482        where
483            F: Fn(&Path) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String>,
484        {
485        }
486        _check_signature(|p| load_init_tensors_from_apr(p));
487    }
488
489    #[test]
490    fn transformer_config_matches_llama_370m_constants() {
491        let cfg = llama_370m_transformer_config();
492        assert_eq!(cfg.hidden_size, Llama370MConfig::HIDDEN_DIM);
493        assert_eq!(cfg.num_hidden_layers, Llama370MConfig::NUM_LAYERS);
494        assert_eq!(cfg.num_attention_heads, Llama370MConfig::NUM_HEADS);
495        assert_eq!(cfg.num_kv_heads, Llama370MConfig::NUM_KV_HEADS);
496        assert_eq!(cfg.intermediate_size, Llama370MConfig::INTERMEDIATE_DIM);
497        assert_eq!(cfg.vocab_size, Llama370MConfig::VOCAB_SIZE);
498        assert!((cfg.rope_theta - Llama370MConfig::ROPE_THETA).abs() < f32::EPSILON);
499        assert!((cfg.rms_norm_eps - Llama370MConfig::RMS_NORM_EPS).abs() < f32::EPSILON);
500        assert!(!cfg.use_bias, "INV-ARCH-370M-008: no bias");
501        assert!(cfg.tie_word_embeddings, "INV-ARCH-370M-004: tied embeddings");
502    }
503
504    /// FALSIFY-APR-PRETRAIN-ARCH-002 — `build_transformer_config(None)` returns
505    /// the Llama370M baseline byte-for-byte. Falsifies regression in the §24/§25
506    /// from-scratch path when the polymorphic dispatch was added.
507    ///
508    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c.
509    #[test]
510    fn build_transformer_config_no_init_matches_llama370m() {
511        let baseline = llama_370m_transformer_config();
512        let result = build_transformer_config(None);
513        assert_eq!(result.hidden_size, baseline.hidden_size);
514        assert_eq!(result.num_attention_heads, baseline.num_attention_heads);
515        assert_eq!(result.num_kv_heads, baseline.num_kv_heads);
516        assert_eq!(result.intermediate_size, baseline.intermediate_size);
517        assert_eq!(result.num_hidden_layers, baseline.num_hidden_layers);
518        assert_eq!(result.vocab_size, baseline.vocab_size);
519        assert_eq!(result.max_position_embeddings, baseline.max_position_embeddings);
520        assert!((result.rms_norm_eps - baseline.rms_norm_eps).abs() < f32::EPSILON);
521        assert!((result.rope_theta - baseline.rope_theta).abs() < f32::EPSILON);
522        assert_eq!(result.use_bias, baseline.use_bias);
523        assert_eq!(result.tie_word_embeddings, baseline.tie_word_embeddings);
524        assert_eq!(result.architecture, baseline.architecture);
525        assert_eq!(result.hf_architecture, baseline.hf_architecture);
526        assert_eq!(result.hf_model_type, baseline.hf_model_type);
527    }
528
529    /// FALSIFY-APR-PRETRAIN-ARCH-003 — `build_transformer_config(Some(cfg))`
530    /// passes through the caller-provided config byte-for-byte. No silent
531    /// defaults, no field overrides. Tests with Qwen2.5-Coder-0.5B shape
532    /// because that is the §49 fine-tune target.
533    ///
534    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c.
535    #[test]
536    fn build_transformer_config_qwen_init_matches_input() {
537        let qwen = TransformerConfig::qwen2_0_5b();
538        let result = build_transformer_config(Some(&qwen));
539        assert_eq!(result.hidden_size, qwen.hidden_size, "hidden_size");
540        assert_eq!(result.num_attention_heads, qwen.num_attention_heads, "num_attention_heads");
541        assert_eq!(result.num_kv_heads, qwen.num_kv_heads, "num_kv_heads");
542        assert_eq!(result.intermediate_size, qwen.intermediate_size, "intermediate_size");
543        assert_eq!(result.num_hidden_layers, qwen.num_hidden_layers, "num_hidden_layers");
544        assert_eq!(result.vocab_size, qwen.vocab_size, "vocab_size");
545        assert_eq!(
546            result.max_position_embeddings, qwen.max_position_embeddings,
547            "max_position_embeddings"
548        );
549        assert_eq!(result.use_bias, qwen.use_bias, "use_bias");
550        assert_eq!(result.tie_word_embeddings, qwen.tie_word_embeddings, "tie_word_embeddings");
551        assert_eq!(result.architecture, qwen.architecture, "architecture");
552        // GQA-7:1 ratio preserved (Qwen2.5-0.5B: 14 / 2 = 7)
553        assert_eq!(
554            result.num_attention_heads / result.num_kv_heads,
555            7,
556            "GQA ratio must preserve as 7:1 (Qwen2.5-0.5B canonical)"
557        );
558    }
559
560    /// Drift-prevention: dispatch is mutually exclusive — None and Some
561    /// produce different configs (otherwise the polymorphic builder is
562    /// vacuous). Catches a future refactor that accidentally always
563    /// returns Llama370M regardless of init.
564    ///
565    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c — drift prevention.
566    #[test]
567    fn build_transformer_config_dispatch_mutually_exclusive() {
568        let qwen = TransformerConfig::qwen2_0_5b();
569        let none_result = build_transformer_config(None);
570        let some_result = build_transformer_config(Some(&qwen));
571        // The two outputs MUST differ, otherwise the dispatch is broken.
572        assert_ne!(
573            none_result.hidden_size, some_result.hidden_size,
574            "dispatch must differentiate None vs Some — Llama370M hidden=1024 vs Qwen=896"
575        );
576        assert_ne!(
577            none_result.vocab_size, some_result.vocab_size,
578            "dispatch must differentiate None vs Some — Llama370M vocab=50257 vs Qwen=151936"
579        );
580    }
581
582    /// FALSIFY-APR-PRETRAIN-ARCH-007 (decoder branch) — `validate_pretrain_init_arch_compatible`
583    /// returns Ok for a decoder-family config.
584    ///
585    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.1.
586    #[test]
587    fn validate_pretrain_init_arch_accepts_decoder() {
588        let qwen = TransformerConfig::qwen2_0_5b();
589        assert_eq!(qwen.architecture, ModelArchitecture::Decoder);
590        validate_pretrain_init_arch_compatible(&qwen)
591            .expect("decoder-family config (Qwen2.5-0.5B) MUST pass arch-compat gate");
592    }
593
594    /// FALSIFY-APR-PRETRAIN-ARCH-007 (encoder branch) — load-bearing test.
595    /// `validate_pretrain_init_arch_compatible` returns Err naming the
596    /// architecture-family mismatch when given an encoder config (e.g.,
597    /// CodeBERT). Without this gate, the decoder trainer would silently
598    /// build with encoder weights producing nonsense gradients.
599    ///
600    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.1.
601    #[test]
602    fn validate_pretrain_init_arch_rejects_encoder() {
603        // Construct a minimal encoder config (CodeBERT-shaped).
604        let bert = TransformerConfig {
605            hidden_size: 768,
606            num_attention_heads: 12,
607            num_kv_heads: 12,
608            intermediate_size: 3072,
609            num_hidden_layers: 12,
610            vocab_size: 50265,
611            max_position_embeddings: 514,
612            rms_norm_eps: 1e-12,
613            rope_theta: 10_000.0,
614            use_bias: true,
615            head_dim_override: None,
616            architecture: ModelArchitecture::Encoder,
617            hf_architecture: Some("RobertaModel".to_string()),
618            hf_model_type: Some("roberta".to_string()),
619            tie_word_embeddings: false,
620        };
621        let err = validate_pretrain_init_arch_compatible(&bert).expect_err(
622            "encoder-family config (CodeBERT/RoBERTa) MUST fail arch-compat gate — \
623             silent acceptance would corrupt §49 fine-tune trajectory before any \
624             FALSIFY-006 check could measure it",
625        );
626        assert!(
627            err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
628            "error must cite falsifier id: {err}"
629        );
630        assert!(err.contains("Encoder"), "error must name the architecture family: {err}");
631        assert!(
632            err.contains("decoder-only"),
633            "error must explain why this is wrong (decoder trainer): {err}"
634        );
635        assert!(
636            err.contains("RobertaModel"),
637            "error must name the offending hf_architecture: {err}"
638        );
639    }
640
641    /// Drift-prevention: validate_pretrain_init_arch_compatible's behavior on
642    /// the from-scratch baseline (Llama370M) — must Ok. Catches a future
643    /// refactor that accidentally over-rejects decoder configs.
644    #[test]
645    fn validate_pretrain_init_arch_accepts_llama370m_baseline() {
646        let llama = llama_370m_transformer_config();
647        assert_eq!(
648            llama.architecture,
649            ModelArchitecture::Decoder,
650            "Llama370M baseline MUST be Decoder (regression-free)"
651        );
652        validate_pretrain_init_arch_compatible(&llama)
653            .expect("Llama370M baseline (Decoder) MUST pass arch-compat gate");
654    }
655
656    #[test]
657    fn real_step_fn_exhausted_iterator_returns_finite_placeholder() {
658        // Empty iterator means no real batches; we must still return
659        // finite values so the loop's non-divergence + NaN guards see
660        // sane data instead of surprising NaN.
661        //
662        // Construct a minimal trainer WITHOUT running `build_shared_trainer`
663        // because that takes ~5 GB of parameter allocation for 370M —
664        // too expensive for a unit test. Use a tiny synthetic config.
665        let mut tiny = TransformerConfig::llama2_7b();
666        tiny.hidden_size = 64;
667        tiny.num_attention_heads = 4;
668        tiny.num_kv_heads = 4;
669        tiny.num_hidden_layers = 2;
670        tiny.intermediate_size = 128;
671        tiny.vocab_size = 256;
672        let cfg = TransformerTrainConfig::new(tiny);
673        let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
674        let empty_iter: Box<dyn Iterator<Item = LMBatch>> = Box::new(std::iter::empty::<LMBatch>());
675        let mut step = RealStepFn::new(trainer, empty_iter);
676        let (loss, grad_norm) = step.step(0, 1.0e-4, 128);
677        assert!(loss.is_finite(), "exhausted iter must return finite loss");
678        assert!(grad_norm.is_finite(), "grad_norm must be finite");
679        assert!(grad_norm >= 0.0, "INV-TRAIN-008: grad_norm non-negative");
680    }
681
682    #[test]
683    fn real_val_fn_empty_held_out_returns_nan() {
684        let mut tiny = TransformerConfig::llama2_7b();
685        tiny.hidden_size = 64;
686        tiny.num_attention_heads = 4;
687        tiny.num_kv_heads = 4;
688        tiny.num_hidden_layers = 2;
689        tiny.intermediate_size = 128;
690        tiny.vocab_size = 256;
691        let cfg = TransformerTrainConfig::new(tiny);
692        let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
693        let mut val = RealValFn::new(trainer, Vec::new());
694        let loss = val.validate(0);
695        assert!(loss.is_nan(), "empty held_out must surface as NaN to the guard");
696    }
697
698    /// Build a tiny Transformer suitable for unit testing the populate path.
699    /// Uses GQA-1:1 (kv=q) shape — the populate function is shape-agnostic so
700    /// the simpler ratio is fine here.
701    fn tiny_test_transformer() -> Transformer {
702        let mut tiny = TransformerConfig::llama2_7b();
703        tiny.hidden_size = 32;
704        tiny.num_attention_heads = 2;
705        tiny.num_kv_heads = 2;
706        tiny.num_hidden_layers = 2;
707        tiny.intermediate_size = 64;
708        tiny.vocab_size = 16;
709        Transformer::new(&tiny)
710    }
711
712    /// Build a `BTreeMap<String, (Vec<f32>, Vec<usize>)>` from a Transformer's
713    /// `named_parameters()` snapshot. Each tensor is a deterministic ramp
714    /// (i as f32 * 0.001) so populate is byte-identifiable post-set.
715    fn tensors_map_from_transformer(
716        transformer: &Transformer,
717    ) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
718        let mut map = BTreeMap::new();
719        for (name, t) in transformer.named_parameters() {
720            let len = t.len();
721            let data: Vec<f32> = (0..len).map(|i| i as f32 * 0.001).collect();
722            map.insert(name, (data, vec![len]));
723        }
724        map
725    }
726
727    /// Happy path — every model parameter has a matching init entry of correct
728    /// length; populate succeeds and the count matches `named_parameters().len()`.
729    #[test]
730    fn populate_trainer_from_init_tensors_happy_path() {
731        let mut transformer = tiny_test_transformer();
732        let init_tensors = tensors_map_from_transformer(&transformer);
733        let expected_count = transformer.named_parameters().len();
734        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
735        assert!(result.is_ok(), "happy-path populate must succeed: {result:?}");
736        assert_eq!(
737            result.unwrap(),
738            expected_count,
739            "populated count must equal named_parameters().len()"
740        );
741    }
742
743    /// Drift-prevention: extra entries in `init_tensors` that the model does
744    /// NOT expose are silently ignored. This handles tied-embeddings: a Qwen
745    /// APR may publish a separate `lm_head.weight` that the trainer's tied
746    /// model omits.
747    #[test]
748    fn populate_trainer_from_init_tensors_extra_entries_silently_ignored() {
749        let mut transformer = tiny_test_transformer();
750        let mut init_tensors = tensors_map_from_transformer(&transformer);
751        // Inject a fictitious extra parameter that the model does not have.
752        init_tensors
753            .insert("model.layers.999.fictitious.weight".to_string(), (vec![0.0; 4], vec![4]));
754        let expected_count = transformer.named_parameters().len();
755        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
756        assert!(result.is_ok(), "extra init entries must NOT cause Err: {result:?}");
757        assert_eq!(result.unwrap(), expected_count);
758    }
759
760    /// FALSIFY-APR-PRETRAIN-INIT-007 (length mismatch) — when an init tensor
761    /// has the wrong flat length for a known parameter, populate MUST Err
762    /// with the FALSIFIER ID and a per-parameter diagnostic line.
763    #[test]
764    fn populate_trainer_from_init_tensors_rejects_length_mismatch() {
765        let mut transformer = tiny_test_transformer();
766        let mut init_tensors = tensors_map_from_transformer(&transformer);
767        // Corrupt one entry's length to trigger the mismatch path.
768        let any_name = transformer.named_parameters()[0].0.clone();
769        init_tensors.insert(any_name.clone(), (vec![0.0; 7], vec![7]));
770        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
771        assert!(result.is_err(), "length-mismatch must Err, not silently truncate");
772        let err = result.unwrap_err();
773        assert!(
774            err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
775            "error must cite falsifier id; got: {err}"
776        );
777        assert!(err.contains(&any_name), "error must name the offending parameter; got: {err}");
778        assert!(
779            err.contains("init length 7"),
780            "error must report the actual init length; got: {err}"
781        );
782    }
783
784    /// FALSIFY-APR-PRETRAIN-INIT-007 (missing-required) — when a model
785    /// parameter has NO corresponding entry in `init_tensors`, populate MUST
786    /// Err with FALSIFIER ID and a "not present in init APR tensors"
787    /// per-parameter diagnostic. This catches the architecture-mismatch
788    /// class where init was extracted from a different model family.
789    #[test]
790    fn populate_trainer_from_init_tensors_rejects_missing_required_param() {
791        let mut transformer = tiny_test_transformer();
792        let mut init_tensors = tensors_map_from_transformer(&transformer);
793        // Drop one entry to trigger the missing-required path.
794        let any_name = transformer.named_parameters()[0].0.clone();
795        init_tensors.remove(&any_name);
796        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
797        assert!(result.is_err(), "missing-required must Err, not silently leave random init");
798        let err = result.unwrap_err();
799        assert!(
800            err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
801            "error must cite falsifier id; got: {err}"
802        );
803        assert!(err.contains(&any_name), "error must name the missing parameter; got: {err}");
804        assert!(
805            err.contains("not present in init APR"),
806            "error must say what was missing; got: {err}"
807        );
808    }
809
810    /// `build_shared_trainer_with_init(None, None)` returns a trainer with
811    /// the §24/§25 from-scratch Llama370M architecture (regression-free
812    /// dispatch). Asserts the baseline shape via the (hidden, vocab) tuple
813    /// rather than param count to avoid the stale INV-ARCH-370M-001 band
814    /// check in `build_shared_trainer` (a defect outside §50.4 scope —
815    /// param_count=322M vs assert range [366M, 374M]; tracked for follow-up).
816    #[test]
817    fn build_shared_trainer_with_init_none_uses_llama370m_shape() {
818        let trainer = build_shared_trainer_with_init(1.0e-4, 128, 42, None, None)
819            .expect("None case must succeed");
820        let model = trainer.borrow();
821        // The baseline polymorphic dispatch produces a Llama370M-shaped model.
822        // Embedding shape `vocab × hidden` is the cleanest non-stale check.
823        let embed_len = model.model().named_parameters()[0].1.len();
824        let expected_embed_len = Llama370MConfig::VOCAB_SIZE * Llama370MConfig::HIDDEN_DIM;
825        assert_eq!(
826            embed_len,
827            expected_embed_len,
828            "init=None must produce Llama370M-shaped embedding (vocab={} × hidden={})",
829            Llama370MConfig::VOCAB_SIZE,
830            Llama370MConfig::HIDDEN_DIM
831        );
832    }
833
834    /// `build_shared_trainer_with_init(Some, None)` and the inverse must
835    /// fail-fast — both args are paired and either both Some or both None.
836    /// Drift-prevention: catches a future caller that forgets to pass one.
837    #[test]
838    fn build_shared_trainer_with_init_rejects_unpaired_args() {
839        // arch Some, path None
840        let cfg = TransformerConfig::qwen2_0_5b();
841        let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), None);
842        assert!(result.is_err(), "unpaired (arch=Some, path=None) must Err");
843        // arch None, path Some
844        let dummy_path = std::path::PathBuf::from("/dev/null");
845        let result = build_shared_trainer_with_init(1.0e-4, 128, 42, None, Some(&dummy_path));
846        assert!(result.is_err(), "unpaired (arch=None, path=Some) must Err");
847    }
848
849    /// `build_shared_trainer_with_init(Some(encoder), Some(path))` rejects
850    /// the encoder family BEFORE attempting tensor load. Drift-prevention for
851    /// FALSIFY-APR-PRETRAIN-ARCH-007 at the trainer-builder integration level.
852    #[test]
853    fn build_shared_trainer_with_init_rejects_encoder_family() {
854        let mut encoder_cfg = TransformerConfig::qwen2_0_5b();
855        encoder_cfg.architecture = ModelArchitecture::Encoder;
856        let dummy_path = std::path::PathBuf::from("/nonexistent/encoder.apr");
857        let result =
858            build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&encoder_cfg), Some(&dummy_path));
859        let err = match result {
860            Ok(_) => panic!("encoder family must be rejected before tensor load"),
861            Err(e) => e,
862        };
863        assert!(
864            err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
865            "error must cite falsifier id; got: {err}"
866        );
867    }
868
869    /// `build_shared_trainer_with_init(Some(decoder), Some(missing_path))`
870    /// proceeds past the family check and FAILS at tensor load with a
871    /// FALSIFY-006 error. Pins the failure ordering: arch validation first,
872    /// then tensor load.
873    #[test]
874    fn build_shared_trainer_with_init_decoder_family_proceeds_to_tensor_load() {
875        let cfg = TransformerConfig::qwen2_0_5b();
876        let dummy_path = std::path::PathBuf::from("/nonexistent/decoder.apr");
877        let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), Some(&dummy_path));
878        let err = match result {
879            Ok(_) => panic!("missing tensor path must Err"),
880            Err(e) => e,
881        };
882        assert!(
883            err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
884            "decoder family proceeds to tensor load; failure cites INIT-006 not ARCH-007; got: {err}"
885        );
886        assert!(
887            !err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
888            "decoder family must NOT trigger encoder-rejection; got: {err}"
889        );
890    }
891
892    /// FALSIFY-H4-INIT-STATS-001 (SHIP-TWO §61 H4A bisect):
893    /// `load_init_tensors_from_apr` on the canonical Qwen2.5-Coder-0.5B-Instruct
894    /// APR file MUST produce sensibly-distributed weights:
895    ///   - `model.embed_tokens.weight` mean ≈ 0 (within ±0.01)
896    ///   - `model.embed_tokens.weight` std in [0.01, 0.1] (HF LLaMA init = 0.02)
897    ///   - `model.norm.weight` mean ≈ 1.0 (RMSNorm pretrained scale)
898    ///
899    /// CONTEXT: §61 evidence shows val_loss=19.80 > ln(vocab)=17.21 at
900    /// step 1, indicating the loaded model produces sub-random predictions.
901    /// Four candidate hypotheses (H4A tied weights, H4B layout, H4C norm
902    /// scale, H4D residual stream). This test bisects H4A+H4C: if any of
903    /// the loaded tensor stats are wildly out-of-range, the load itself
904    /// is corrupt; if all stats look correct, the bug is in the forward
905    /// path (H4B layout or H4D residual).
906    ///
907    /// Host-gated: requires a canonical Qwen 0.5B init APR. Tries the
908    /// "fresh" path first (current `apr import` of HF safetensors,
909    /// preserves BF16 dtype tag); falls back to the older "fp16" path
910    /// (legacy import, mis-tagged as F16). Skips if neither present.
911    ///
912    /// The legacy file demonstrates the H4 dtype-mislabel defect class:
913    /// safetensors source is BF16, old `apr import` wrote bytes raw
914    /// but tagged dtype as F16, aprender's loader then read bytes as
915    /// F16 and produced distorted values. The fresh path preserves
916    /// BF16 correctly. Element-0 cross-checks agree with the
917    /// safetensors source under BF16 decode.
918    #[test]
919    fn falsify_h4_init_stats_qwen_embed_norm_sensible() {
920        let fresh = std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-fresh.apr");
921        let legacy =
922            std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-instruct-fp16.apr");
923        let path = if fresh.exists() {
924            fresh
925        } else if legacy.exists() {
926            legacy
927        } else {
928            eprintln!("[falsify-h4-init-stats-001] skipping: host lacks Qwen 0.5B APR");
929            return;
930        };
931        let _ = path; // silence unused if branches
932        if !path.exists() {
933            eprintln!("[falsify-h4-init-stats-001] skipping: host lacks {}", path.display());
934            return;
935        }
936        // H4 root-cause probe: directly inspect the APR's dtype tag to
937        // verify whether the F16 vs BF16 distinction was preserved
938        // through `apr import`.
939        {
940            use aprender::format::v2::AprV2Reader;
941            let bytes = std::fs::read(path).expect("read APR");
942            let reader = AprV2Reader::from_bytes(&bytes).expect("parse APR v2");
943            for name in ["model.layers.0.self_attn.q_proj.bias", "model.norm.weight"] {
944                if let Some(entry) = reader.get_tensor(name) {
945                    eprintln!(
946                        "[h4-init-dtype] {name}: dtype={:?} shape={:?}",
947                        entry.dtype, entry.shape
948                    );
949                }
950            }
951        }
952        let tensors = match load_init_tensors_from_apr(path) {
953            Ok(t) => t,
954            Err(e) => {
955                panic!("FALSIFY-H4-INIT-STATS-001: load_init_tensors_from_apr failed: {e}");
956            }
957        };
958
959        // Required tensors
960        let embed = tensors
961            .get("model.embed_tokens.weight")
962            .unwrap_or_else(|| panic!("missing model.embed_tokens.weight in init APR"));
963        let norm = tensors
964            .get("model.norm.weight")
965            .unwrap_or_else(|| panic!("missing model.norm.weight in init APR"));
966
967        let stats = |name: &str, data: &[f32]| -> (f64, f64, f32, f32) {
968            let n = data.len() as f64;
969            let mean = data.iter().map(|&v| v as f64).sum::<f64>() / n;
970            let var = data
971                .iter()
972                .map(|&v| {
973                    let d = v as f64 - mean;
974                    d * d
975                })
976                .sum::<f64>()
977                / n;
978            let std = var.sqrt();
979            let min = data.iter().copied().fold(f32::INFINITY, f32::min);
980            let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
981            eprintln!(
982                "[h4-init-stats] {name}: n={n} mean={mean:.5} std={std:.5} min={min:.4} max={max:.4}"
983            );
984            (mean, std, min, max)
985        };
986        // H4-DTYPE-MISLABEL: dump first 4 element-0 values to compare
987        // with safetensors source (decoded as BF16). If the APR loader
988        // mis-decodes BF16 bytes as F16, values will diverge.
989        {
990            let q = tensors.get("model.layers.0.self_attn.q_proj.bias").unwrap();
991            eprintln!(
992                "[h4-dtype-mislabel] q_proj.bias L0[0..6] (aprender F16-decoded): {:?}",
993                &q.0[..6]
994            );
995            let n = tensors.get("model.norm.weight").unwrap();
996            eprintln!(
997                "[h4-dtype-mislabel] model.norm.weight[0..6] (aprender F16-decoded): {:?}",
998                &n.0[..6]
999            );
1000        }
1001
1002        let (em, es, _, _) = stats("model.embed_tokens.weight", &embed.0);
1003        let (nm, ns, _, _) = stats("model.norm.weight", &norm.0);
1004
1005        // H4C bisect: dump per-layer norm stats. Standard RMSNorm
1006        // weights are near 1.0 (init=1.0, trained drift to ~0.1-2.0).
1007        // Mean > 5 across layers indicates a load-time scale-corruption.
1008        for layer_idx in [0_usize, 5, 11, 23] {
1009            for kind in ["input_layernorm", "post_attention_layernorm"] {
1010                let key = format!("model.layers.{layer_idx}.{kind}.weight");
1011                if let Some(t) = tensors.get(&key) {
1012                    stats(&key, &t.0);
1013                }
1014            }
1015        }
1016        for kind in [
1017            "model.layers.0.self_attn.q_proj.weight",
1018            "model.layers.0.self_attn.q_proj.bias",
1019            "model.layers.0.mlp.gate_proj.weight",
1020            "model.layers.0.mlp.down_proj.weight",
1021        ] {
1022            if let Some(t) = tensors.get(kind) {
1023                stats(kind, &t.0);
1024            }
1025        }
1026
1027        // Embedding init bound: HF LLaMA init normal(0, 0.02). After
1028        // pretraining the std grows but typically stays in [0.01, 0.1].
1029        // mean should be near 0 (well-centered).
1030        assert!(
1031            em.abs() < 0.05,
1032            "FALSIFY-H4-INIT-STATS-001: embed mean={em} > 0.05; weights are not centered. \
1033             Possible f16→f32 sign-bit corruption or wrong byte-order."
1034        );
1035        assert!(
1036            (0.005..=0.5).contains(&es),
1037            "FALSIFY-H4-INIT-STATS-001: embed std={es} outside [0.005, 0.5]; weights are not \
1038             distributed like trained transformer init. Possible f16 mantissa misread or \
1039             scale corruption."
1040        );
1041
1042        // RMSNorm init: weights are ~1.0 (sqrt(2)≈1.41 in some configs).
1043        // After training they stay close to 1, sometimes drifting up to ~10.
1044        assert!(
1045            nm > 0.01 && nm < 100.0,
1046            "FALSIFY-H4-INIT-STATS-001: norm mean={nm} outside [0.01, 100]; RMSNorm scale \
1047             load is corrupt. Trained pretrained values are typically near 1.0."
1048        );
1049        assert!(
1050            ns < 100.0,
1051            "FALSIFY-H4-INIT-STATS-001: norm std={ns} > 100; RMSNorm has explosive variance. \
1052             Tensor load is corrupt."
1053        );
1054    }
1055
1056    /// FALSIFY-H4-CPU-FORWARD-001 (H4 residual cascade — bisect to CPU vs CUDA):
1057    /// CPU `aprender::Transformer::forward` on a populated Qwen 0.5B model
1058    /// MUST produce sensibly-distributed logits. Host-gated test that
1059    /// bisects whether the val_loss > ln(vocab) defect is in the
1060    /// populate path / CPU forward (RED here = bug there) or in CUDA
1061    /// (GREEN here, RED in eval_batch = bug in CUDA path).
1062    #[test]
1063    fn falsify_h4_cpu_forward_qwen_logits_sensible() {
1064        let fresh = std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-fresh.apr");
1065        let legacy =
1066            std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-instruct-fp16.apr");
1067        let path = if fresh.exists() {
1068            fresh
1069        } else if legacy.exists() {
1070            legacy
1071        } else {
1072            eprintln!("[falsify-h4-cpu-forward-001] skipping: host lacks Qwen 0.5B APR");
1073            return;
1074        };
1075
1076        let tensors = load_init_tensors_from_apr(path).expect("load_init_tensors_from_apr");
1077        let cfg = TransformerConfig::qwen2_0_5b();
1078        let mut transformer = Transformer::new(&cfg);
1079        let populated = populate_trainer_from_init_tensors(&mut transformer, &tensors)
1080            .expect("populate_trainer_from_init_tensors");
1081        eprintln!("[falsify-h4-cpu-forward-001] populated {populated} tensors");
1082
1083        let token_ids = vec![100_u32];
1084        let logits = transformer.forward(&token_ids);
1085        let data = logits.data();
1086        let slice = data.as_slice().expect("logits contiguous");
1087
1088        let mut nan_count = 0usize;
1089        let mut inf_count = 0usize;
1090        let mut min = f32::INFINITY;
1091        let mut max = f32::NEG_INFINITY;
1092        let mut sum = 0.0_f64;
1093        let mut sum_sq = 0.0_f64;
1094        let mut argmax_idx = 0_usize;
1095        for (i, &v) in slice.iter().enumerate() {
1096            if v.is_nan() {
1097                nan_count += 1;
1098            } else if v.is_infinite() {
1099                inf_count += 1;
1100            } else {
1101                if v < min {
1102                    min = v;
1103                }
1104                if v > max {
1105                    max = v;
1106                    argmax_idx = i;
1107                }
1108                sum += v as f64;
1109                sum_sq += (v as f64) * (v as f64);
1110            }
1111        }
1112        let n = slice.len() as f64;
1113        let mean = sum / n;
1114        let std = (sum_sq / n - mean * mean).sqrt();
1115
1116        eprintln!(
1117            "[falsify-h4-cpu-forward-001] token=100 logits: n={} nan={nan_count} inf={inf_count} \
1118             min={min:.4} max={max:.4} mean={mean:.4} std={std:.4} argmax={argmax_idx}",
1119            slice.len()
1120        );
1121
1122        assert_eq!(nan_count, 0, "logits contain NaN — forward corruption");
1123        assert_eq!(inf_count, 0, "logits contain Inf — forward corruption");
1124        assert!(
1125            std > 0.01,
1126            "FALSIFY-H4-CPU-FORWARD-001: logits std={std} < 0.01 — essentially constant"
1127        );
1128        let peak_to_mean = (max as f64 - mean).abs() / std.max(1e-9);
1129        assert!(
1130            peak_to_mean > 1.5,
1131            "FALSIFY-H4-CPU-FORWARD-001: peak-to-mean ratio = {peak_to_mean} < 1.5 — \
1132             logits are essentially uniform"
1133        );
1134        assert!(
1135            (argmax_idx as u32) < cfg.vocab_size as u32,
1136            "FALSIFY-H4-CPU-FORWARD-001: argmax_idx={argmax_idx} >= vocab_size={}",
1137            cfg.vocab_size
1138        );
1139    }
1140}