Skip to main content

entrenar/train/
pretrain_real.rs

1//! Real-corpus `StepFn` / `ValFn` for MODEL-2 pretrain MVP (task #111).
2//!
3//! Bridges the model-agnostic `PretrainLoop` (`pretrain.rs`) to the
4//! 370M Llama scaffold (`models/llama_370m.rs`) by wiring a real
5//! `TransformerTrainer` through the `StepFn` and `ValFn` traits.
6//!
7//! The loop drive replaces the `LinearDecaySynthetic` / `ScriptedVal`
8//! pair used for GATE-TRAIN-005/007/008 wiring verification (task #105)
9//! with a real forward + backward + optimizer step and a real held-out
10//! validation forward pass.
11//!
12//! Contract obligations discharged:
//! - INV-ARCH-370M-001 (param count in [366M, 374M]) via a debug-build `debug_assert!` in `build_shared_trainer`
14//! - INV-TRAIN-001 (per-step metrics — 6 fields via PretrainLoop)
15//! - INV-TRAIN-007 (no NaN/Inf — the loop aborts on first non-finite)
16//!
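//! Deferred (task #111 follow-ups):
//! - Real grad_norm (`RealStepFn` currently reports a placeholder of 1.0;
//!   needs a `TransformerTrainer` extension to surface the pre-clip norm)
//!
//! INV-TRAIN-003 (optimizer-state sha256 over the AdamW m/v/t buffers) is no
//! longer deferred: `RealStepFn::optimizer_state_sha256` delegates to the
//! trainer's `optimizer_state_sha256`.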
21
22use crate::models::llama_370m::Llama370MConfig;
23use crate::train::pretrain::{CheckpointFn, EpochArtifact, StepFn, ValFn};
24use crate::train::transformer_trainer::{LMBatch, TransformerTrainConfig, TransformerTrainer};
25use crate::transformer::{ModelArchitecture, Transformer, TransformerConfig};
26use crate::Tensor;
27use std::cell::RefCell;
28use std::collections::BTreeMap;
29use std::path::Path;
30use std::rc::Rc;
31
32/// Shared mutable ownership of the `TransformerTrainer` — both
33/// `RealStepFn` (training steps) and `RealValFn` (forward-only
34/// validation) clone this `Rc`.
35pub type SharedTrainer = Rc<RefCell<TransformerTrainer>>;
36
37/// Load tensors from an APR file as the read-half of `apr pretrain --init`.
38///
39/// Per `apr-pretrain-arch-polymorphic-v1` §init_load_semantics (PR #1473),
40/// the loader is REUSED, not reimplemented — this function is a thin wrapper
41/// over `aprender::format::converter::convert_report::load_model_tensors`,
42/// which is the same machinery `apr export` and `apr inspect` use. No
43/// duplicate APR parser; one source of truth.
44///
45/// Returns a map of `tensor_name -> (flat_f32_data, shape)`. The HF naming
46/// convention is preserved (e.g., `model.embed_tokens.weight`); reconciling
47/// against the trainer's parameter names is step 5f.3 (the population step).
48///
49/// Discharges from `apr-pretrain-arch-polymorphic-v1`:
50///   - §init_load_semantics invariant: "Loader is reused, not reimplemented"
51///   - FALSIFY-006 (init_loss < 6.0) at READ-COMPILE-BIND level: this
52///     function is the read half. Full FALSIFY-006 discharge requires
53///     5f.3 (population) + 5g (LIVE 500-step fine-tune).
54///
55/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.2.
56///
57/// # Errors
58///
59/// Returns Err if the APR file:
60/// - Does not exist (filesystem I/O error)
61/// - Has invalid magic bytes (not APR\\0 or APRN)
62/// - Has a corrupted tensor index
63/// - Contains tensors with unsupported dtype (non-F32)
64pub fn load_init_tensors_from_apr(
65    path: impl AsRef<Path>,
66) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String> {
67    let path_ref = path.as_ref();
68    aprender::format::converter::load_model_tensors(path_ref).map_err(|e| {
69        format!(
70            "FALSIFY-APR-PRETRAIN-INIT-006: failed to load init tensors from APR file {}: {e}",
71            path_ref.display()
72        )
73    })
74}
75
76/// Reject an init `TransformerConfig` whose architecture family is incompatible
77/// with the pretrain target (decoder-only LM training).
78///
79/// Per `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature
80/// (PR #1473), wrong-arch APR (e.g., a CodeBERT/RoBERTa encoder model)
81/// MUST be FAIL-FAST not silent-truncate. Without this gate, an operator
82/// who points `--init` at e.g. `microsoft/codebert-base.apr` would silently
83/// load encoder weights into a decoder-shaped trainer, producing nonsense
84/// gradients that the divergence guard catches LATE (after multiple epochs).
85///
86/// Discharges FALSIFY-APR-PRETRAIN-ARCH-007 at PARTIAL_ALGORITHM_LEVEL.
87///
88/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.1.
89///
90/// # Errors
91///
92/// Returns Err with a clear architecture-family-mismatch message when:
93/// - `cfg.architecture` is `ModelArchitecture::Encoder` (BERT/RoBERTa/CodeBERT)
94///
95/// Future expansion can add other family checks (e.g., reject hybrid SSM
96/// architectures whose forward pass differs from the standard decoder loop).
97pub fn validate_pretrain_init_arch_compatible(cfg: &TransformerConfig) -> Result<(), String> {
98    match cfg.architecture {
99        ModelArchitecture::Decoder => Ok(()),
100        ModelArchitecture::Encoder => Err(format!(
101            "FALSIFY-APR-PRETRAIN-ARCH-007: --init checkpoint has architecture=Encoder \
102             (e.g., BERT/RoBERTa/CodeBERT) but the pretrain trainer is decoder-only \
103             (Llama/Qwen-class causal LMs). Loading encoder weights into a decoder \
104             trainer would produce nonsense gradients. Architectural details: \
105             hidden_size={}, num_layers={}, vocab_size={}, hf_architecture={:?}",
106            cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size, cfg.hf_architecture
107        )),
108    }
109}
110
/// SPEC §86 / INV-INIT-ARCH-MATCH-001 — derive the family-arch slug purely
/// from tensor names; no tensor data is needed.
///
/// This is a names-only counterpart of the heavier `infer_architecture_from_names`
/// in `aprender-core::format::converter::tokenizer_loader`, so callers need not
/// materialise the full F32 tensor map. `validate_init_arch_matches_tensor_evidence`
/// uses it to catch the §86 case (Llama-stamped metadata on a Qwen2-tensored
/// APR) at the gate.
///
/// Checks run in a fixed priority order and the first match wins:
/// Mamba → RWKV → GPT-NeoX → OPT → BERT → the `model.layers` family
/// (Qwen3 → Qwen2 → Llama) → GPT-2 → `"unknown"`.
///
/// Returns one of: "qwen3", "qwen2", "llama", "mamba", "rwkv",
/// "gpt-neox", "opt", "bert", "gpt2", "unknown".
#[must_use]
pub fn family_from_tensor_names<'a, I>(names: I) -> &'static str
where
    I: IntoIterator<Item = &'a str>,
{
    // Every heuristic below rescans the full list, so materialise it once.
    // Even for a Qwen2-0.5B APR (~291 tensors) the cost is negligible.
    let all: Vec<&str> = names.into_iter().collect();
    let has = |needle: &str| all.iter().any(|name| name.contains(needle));
    let has_prefix = |prefix: &str| all.iter().any(|name| name.starts_with(prefix));

    if has("mixer.in_proj") || has("mixer.out_proj") {
        // PMAT-546: Mamba (SSM)
        "mamba"
    } else if has_prefix("rwkv.blocks.") || has("blocks.0.att.") {
        // PMAT-546: RWKV
        "rwkv"
    } else if has_prefix("gpt_neox.") {
        // GH-311: GPT-NeoX (must be decided before the `model.layers` branch)
        "gpt-neox"
    } else if has_prefix("model.decoder.layers.") {
        // GH-311: OPT
        "opt"
    } else if has_prefix("bert.") {
        // GH-311: BERT
        "bert"
    } else if has("model.layers") {
        if has("self_attn.q_norm.weight") {
            // Qwen3: QK-norm is a unique signal.
            "qwen3"
        } else if has("self_attn.q_proj.bias") || has("qkv_proj.weight") {
            // Qwen2: differs from Llama by attention bias / fused QKV.
            "qwen2"
        } else {
            "llama"
        }
    } else if has("transformer.h")
        || all.iter().any(|name| name.starts_with("h.") && name.contains(".attn."))
    {
        "gpt2"
    } else {
        // Nothing recognised. This includes GGUF-style `blk.*` names, which
        // cannot be disambiguated from names alone.
        "unknown"
    }
}
178
/// SPEC §86 / INV-INIT-ARCH-MATCH-001 — map an APR metadata `architecture`
/// string onto the canonical family slug that `family_from_tensor_names`
/// produces.
///
/// Three spellings of the field are accepted:
///
/// - **HF class name** (e.g. `"Qwen2ForCausalLM"`, `"LlamaForCausalLM"`):
///   what the §82 P0-H fallback stamps into the family field when
///   `hf_architecture` is absent.
/// - **Family slug** (e.g. `"qwen2"`, `"llama"`): the canonical form written
///   by a properly-imported APR (post-P0-K).
/// - **Capitalised legacy** (e.g. `"Qwen2"`, `"Llama"`): older imports.
///
/// Matching is exact and case-sensitive. Mistral and Phi spellings fold into
/// `"llama"` because they share the Llama tensor layout for our purposes.
///
/// Returns `None` for `"unknown"` or any unmappable string. Callers should
/// read that as "no metadata claim" and skip the cross-check.
#[must_use]
pub fn normalize_metadata_arch_family(arch: &str) -> Option<&'static str> {
    // One arm per target family. Within each arm the order is HF class
    // name(s), then lowercase slug(s), then capitalised legacy spelling(s).
    let family = match arch {
        "Qwen2ForCausalLM" | "Qwen2.5ForCausalLM"
        | "qwen2" | "qwen2.5" | "qwen"
        | "Qwen2" | "Qwen2.5" | "Qwen" => "qwen2",
        "Qwen3ForCausalLM" | "Qwen3MoeForCausalLM"
        | "qwen3" | "qwen3_5" | "qwen3.5"
        | "Qwen3" => "qwen3",
        "LlamaForCausalLM" | "MistralForCausalLM" | "Phi3ForCausalLM" | "PhiForCausalLM"
        | "llama" | "mistral" | "phi" | "phi3" | "phi4"
        | "Llama" | "Mistral" | "Phi" | "Phi3" | "Phi4" => "llama",
        "GPT2LMHeadModel" | "gpt2" | "Gpt2" | "GPT2" => "gpt2",
        "GPTNeoXForCausalLM" | "gpt-neox" | "gpt_neox" | "gptneox" | "pythia" => "gpt-neox",
        "MambaForCausalLM" | "mamba" => "mamba",
        "RwkvForCausalLM" | "Rwkv6ForCausalLM" | "rwkv" => "rwkv",
        "BertModel" | "BertForMaskedLM" | "bert" => "bert",
        "OPTForCausalLM" | "opt" => "opt",
        // "unknown" and anything unrecognised: no claim to cross-check.
        _ => return None,
    };
    Some(family)
}
228
229/// SPEC §86 / INV-INIT-ARCH-MATCH-001 — FAIL-FAST when an APR's
230/// metadata `architecture` claim contradicts what its tensor names imply.
231///
232/// This catches the §86 silent-failure case at the gate instead of at
233/// init eval: a pre-P0-K APR with `architecture = "LlamaForCausalLM"`
234/// (the §82 P0-H fallback) and Qwen2-style tensor names produces
235/// random-init training instead of resume-from-checkpoint, with
236/// val_loss at step 0 disagreeing with the init's recorded val_loss
237/// by orders of magnitude. The fix at the framework level is shipped
238/// via PR #1742 (P0-K stamping); this invariant prevents existing
239/// pre-P0-K artifacts from training silently from random init.
240///
241/// Discharges INV-INIT-ARCH-MATCH-001 in `contracts/apr-pretrain-from-init-v1.yaml`
242/// (forthcoming, scope-noted in SPEC §86.6).
243///
244/// # Errors
245///
246/// Returns Err with a clear naming-both-claims message when the
247/// metadata family slug differs from the tensor-evidence family slug.
248/// When the metadata claim is `"unknown"` (or doesn't parse to a known
249/// family) the gate is skipped — no false-positive on novel architectures.
250///
251/// # Salvage path
252///
253/// Operators with pre-P0-K Llama-stamped Qwen2 checkpoints can
254/// restamp the metadata in place via the §86.4 recipe:
255///
256/// ```ignore
257/// apr stamp <pre-p0k.apr> --architecture qwen2 --hf-architecture Qwen2ForCausalLM \
258///                          -o <stamped.apr>
259/// ```
260///
261/// See PR #1757 (apr stamp HF identity extension).
262pub fn validate_init_arch_matches_tensor_evidence(
263    metadata_arch: Option<&str>,
264    init_tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
265) -> Result<(), String> {
266    // If the metadata claim is absent or unmappable, we have no claim to
267    // contradict — skip the cross-check (a novel arch is not §86's case).
268    let Some(metadata_family) = metadata_arch.and_then(normalize_metadata_arch_family) else {
269        return Ok(());
270    };
271
272    let tensor_family = family_from_tensor_names(init_tensors.keys().map(String::as_str));
273
274    // If tensor inference returns "unknown" (e.g., GGUF blk.* names that
275    // can't be disambiguated), we trust the metadata claim. Only fail
276    // when BOTH inferences produce concrete family slugs AND they differ.
277    if tensor_family == "unknown" {
278        return Ok(());
279    }
280
281    if metadata_family != tensor_family {
282        return Err(format!(
283            "FALSIFY-INIT-ARCH-MATCH-001: --init APR metadata claims architecture \
284             family `{metadata_family}` (from `{}`) but tensor naming implies \
285             family `{tensor_family}`. This is the SPEC §86 silent-failure pattern: \
286             pre-P0-K APRs with the §82 P0-H \"LlamaForCausalLM\" fallback stamp + \
287             Qwen2 tensors load as random-init and train from scratch. Salvage with \
288             `apr stamp <input.apr> --architecture {tensor_family} --hf-architecture \
289             {} -o <stamped.apr>` (see PR #1757 / SPEC §86.4) then re-run \
290             `apr pretrain --init <stamped.apr>`.",
291            metadata_arch.unwrap_or("?"),
292            // Synthesize the canonical HF class name from the tensor-evidence family
293            match tensor_family {
294                "qwen2" => "Qwen2ForCausalLM",
295                "qwen3" => "Qwen3ForCausalLM",
296                "llama" => "LlamaForCausalLM",
297                "gpt2" => "GPT2LMHeadModel",
298                "gpt-neox" => "GPTNeoXForCausalLM",
299                "mamba" => "MambaForCausalLM",
300                "rwkv" => "RwkvForCausalLM",
301                "bert" => "BertModel",
302                "opt" => "OPTForCausalLM",
303                other => other,
304            }
305        ));
306    }
307
308    Ok(())
309}
310
311/// Populate a `Transformer`'s parameters from an `init_tensors` BTreeMap.
312///
313/// For each parameter the `Transformer` exposes via `named_parameters()`, look
314/// up the same HF-naming key in `init_tensors` and replace the parameter's
315/// `Tensor` with `Tensor::from_vec(data, requires_grad=true)`. The parameter
316/// remains trainable (gradients still flow) — this is fine-tune init, not
317/// frozen-encoder.
318///
319/// Strictness rules:
320/// - **Every** model parameter must have a matching entry in `init_tensors`.
321///   Missing entries return Err naming the unmatched parameter; this catches
322///   the case where the init APR was extracted from a different architecture.
323/// - **Length** of each init entry must match the model parameter's length
324///   (computed from the model's existing tensor `len()`). Mismatch returns Err.
325/// - **Extra** entries in `init_tensors` are silently ignored. This handles
326///   `tie_word_embeddings`: a Qwen2.5 APR may publish a separate `lm_head.weight`
327///   tensor that the model omits when ties are enabled.
328///
329/// Discharges from `apr-pretrain-arch-polymorphic-v1` §init_load_semantics:
330/// - Population invariant: "Init tensors populate trainer parameters
331///   byte-equivalent to source"
332/// - FALSIFY-APR-PRETRAIN-INIT-007 (population step) at PARTIAL_ALGORITHM_LEVEL.
333///
334/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.3.
335///
336/// # Errors
337///
338/// Returns Err if any model parameter is missing from `init_tensors` or if
339/// any matched entry has a wrong length. The error message lists up to the
340/// first 5 problem parameters and the total count.
341pub fn populate_trainer_from_init_tensors(
342    transformer: &mut Transformer,
343    init_tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
344) -> Result<usize, String> {
345    let expected: Vec<(String, usize)> =
346        transformer.named_parameters().into_iter().map(|(name, t)| (name, t.len())).collect();
347    let mut populated = 0usize;
348    let mut errors: Vec<String> = Vec::new();
349
350    for (name, expected_len) in &expected {
351        match init_tensors.get(name) {
352            Some((data, _shape)) => {
353                if data.len() != *expected_len {
354                    errors.push(format!(
355                        "{name}: init length {} != trainer expected {expected_len}",
356                        data.len()
357                    ));
358                    continue;
359                }
360                let tensor = Tensor::from_vec(data.clone(), true);
361                if !transformer.set_named_parameter(name, tensor) {
362                    errors.push(format!("{name}: set_named_parameter rejected the assignment"));
363                    continue;
364                }
365                populated += 1;
366            }
367            None => {
368                errors.push(format!("{name}: not present in init APR tensors"));
369            }
370        }
371    }
372
373    if !errors.is_empty() {
374        let total = errors.len();
375        let head = errors.iter().take(5).cloned().collect::<Vec<_>>().join("; ");
376        return Err(format!(
377            "FALSIFY-APR-PRETRAIN-INIT-007: populate_trainer_from_init_tensors \
378             failed for {total} parameter(s); first {} of {total}: {head}",
379            errors.len().min(5)
380        ));
381    }
382
383    Ok(populated)
384}
385
386/// Build a `TransformerConfig` field-for-field from `Llama370MConfig::*`
387/// constants (the contract-frozen MODEL-2 370M architecture).
388pub fn llama_370m_transformer_config() -> TransformerConfig {
389    TransformerConfig {
390        hidden_size: Llama370MConfig::HIDDEN_DIM,
391        num_attention_heads: Llama370MConfig::NUM_HEADS,
392        num_kv_heads: Llama370MConfig::NUM_KV_HEADS,
393        intermediate_size: Llama370MConfig::INTERMEDIATE_DIM,
394        num_hidden_layers: Llama370MConfig::NUM_LAYERS,
395        vocab_size: Llama370MConfig::VOCAB_SIZE,
396        max_position_embeddings: Llama370MConfig::MAX_POSITION_EMBEDDINGS,
397        rms_norm_eps: Llama370MConfig::RMS_NORM_EPS,
398        rope_theta: Llama370MConfig::ROPE_THETA,
399        use_bias: false,
400        head_dim_override: None,
401        architecture: ModelArchitecture::Decoder,
402        hf_architecture: Some("LlamaForCausalLM".into()),
403        hf_model_type: Some("llama".into()),
404        tie_word_embeddings: true,
405    }
406}
407
408/// Polymorphic builder per `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature.
409///
410/// Discharges FALSIFY-APR-PRETRAIN-ARCH-002 (init=None preserves Llama370M baseline)
411/// and FALSIFY-APR-PRETRAIN-ARCH-003 (init=Some passes through extracted config).
412///
413/// Behaviour:
414///   init = None  → return `llama_370m_transformer_config()`, the §24/§25
415///                  from-scratch baseline. NO regression.
416///   init = Some  → clone the caller-extracted `TransformerConfig` byte-for-byte.
417///                  No silent defaults, no field overrides.
418///
419/// The caller is responsible for actually reading the APR file and producing the
420/// `TransformerConfig` (typically via `TransformerConfig::from_apr_metadata` from
421/// `transformer::config`). Decoupling the dispatch from the file I/O keeps
422/// `aprender-train` free of `aprender-serve` (the APR loader) as a build dep.
423///
424/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c.
425pub fn build_transformer_config(init: Option<&TransformerConfig>) -> TransformerConfig {
426    match init {
427        None => llama_370m_transformer_config(),
428        Some(cfg) => cfg.clone(),
429    }
430}
431
432/// Build a `TransformerTrainConfig` with MODEL-2 v2-remedy defaults
433/// (LR=5e-5, AdamW defaults, fp32, seed=42 set by caller).
434pub fn llama_370m_train_config(lr: f32, seq_length: usize, seed: u64) -> TransformerTrainConfig {
435    let model_cfg = llama_370m_transformer_config();
436    let mut cfg = TransformerTrainConfig::new(model_cfg);
437    cfg.lr = lr;
438    cfg.max_seq_len = seq_length;
439    cfg.seed = seed;
440    cfg
441}
442
443/// `StepFn` impl that pulls one `LMBatch` from an owned iterator and
444/// runs a real forward + backward + AdamW step through the shared
445/// `TransformerTrainer`.
446pub struct RealStepFn {
447    trainer: SharedTrainer,
448    batches: Box<dyn Iterator<Item = LMBatch>>,
449}
450
451impl RealStepFn {
452    pub fn new(trainer: SharedTrainer, batches: Box<dyn Iterator<Item = LMBatch>>) -> Self {
453        Self { trainer, batches }
454    }
455}
456
457impl StepFn for RealStepFn {
458    fn step(&mut self, _step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
459        // Pull one batch; if the shard stream is exhausted before the
460        // loop plans to stop, emit a tiny finite placeholder so
461        // GATE-TRAIN-007 (NaN/Inf guard) does not mis-fire — the
462        // divergence guard (GATE-TRAIN-005) will correctly not abort
463        // on a flat tail.
464        let Some(batch) = self.batches.next() else {
465            return (1.0, 1.0);
466        };
467        let mut trainer = self.trainer.borrow_mut();
468        let loss = trainer.train_batch(&batch);
469        // TODO(task #111 follow-up): expose AdamW pre-clip grad norm.
470        // Placeholder = 1.0 keeps INV-TRAIN-007 satisfied (finite) and
471        // INV-TRAIN-008 satisfied (≥ 0); the real grad norm is a
472        // downstream ticket that needs TransformerTrainer extension.
473        let grad_norm = 1.0_f32;
474        (loss, grad_norm)
475    }
476
477    /// INV-TRAIN-003 discharge: hash the real AdamW (t, m, v) buffers.
478    fn optimizer_state_sha256(&self) -> Option<String> {
479        Some(self.trainer.borrow().optimizer_state_sha256())
480    }
481}
482
483/// `ValFn` impl that runs forward-only across a pre-loaded set of
484/// held-out batches and returns mean cross-entropy loss.
485pub struct RealValFn {
486    trainer: SharedTrainer,
487    held_out: Vec<LMBatch>,
488}
489
490impl RealValFn {
491    pub fn new(trainer: SharedTrainer, held_out: Vec<LMBatch>) -> Self {
492        Self { trainer, held_out }
493    }
494}
495
496impl ValFn for RealValFn {
497    fn validate(&mut self, _epoch: usize) -> f32 {
498        if self.held_out.is_empty() {
499            return f32::NAN;
500        }
501        let trainer = self.trainer.borrow();
502        let mut total_loss = 0.0_f32;
503        let mut total_items = 0_usize;
504        for batch in &self.held_out {
505            for i in 0..batch.batch_size {
506                let Some(inp) = batch.get_input(i) else {
507                    continue;
508                };
509                let Some(tgt) = batch.get_target(i) else {
510                    continue;
511                };
512                let (loss_val, _loss_tensor, _logits) = trainer.forward_single(inp, tgt);
513                total_loss += loss_val;
514                total_items += 1;
515            }
516        }
517        if total_items == 0 {
518            f32::NAN
519        } else {
520            total_loss / total_items as f32
521        }
522    }
523}
524
525/// `CheckpointFn` impl that writes the 370M Llama weights to
526/// `artifact.checkpoint_path` in APR format (task #111 step 7).
527///
528/// Holds the `SharedTrainer` alongside `RealStepFn` / `RealValFn` so
529/// the three hooks see the same in-memory weights.
530pub struct AprCheckpointFn {
531    trainer: SharedTrainer,
532    model_name: String,
533    architecture: String,
534}
535
536impl AprCheckpointFn {
537    pub fn new(
538        trainer: SharedTrainer,
539        model_name: impl Into<String>,
540        architecture: impl Into<String>,
541    ) -> Self {
542        Self { trainer, model_name: model_name.into(), architecture: architecture.into() }
543    }
544}
545
546impl CheckpointFn for AprCheckpointFn {
547    fn save(&mut self, _epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
548        let trainer = self.trainer.borrow();
549        trainer
550            .save_apr(&artifact.checkpoint_path, &self.model_name, &self.architecture)
551            .map_err(|e| format!("save_apr failed: {e}"))
552    }
553}
554
555/// Shared-ownership helper so the CLI can hand the same trainer to
556/// both `RealStepFn` and `RealValFn`.
557pub fn build_shared_trainer(lr: f32, seq_length: usize, seed: u64) -> SharedTrainer {
558    let cfg = llama_370m_train_config(lr, seq_length, seed);
559    let trainer = TransformerTrainer::new(cfg);
560    // INV-ARCH-370M-001: verify parameter count lands in the 370M ± 1%
561    // band. This is a debug_assert so release builds do not pay for
562    // the full parameter walk, but dev builds catch drift the instant
563    // any Llama370MConfig constant changes.
564    #[cfg(debug_assertions)]
565    {
566        let param_count: usize = trainer.model().parameters().iter().map(|t| t.len()).sum();
567        debug_assert!(
568            (366_000_000..=374_000_000).contains(&param_count),
569            "INV-ARCH-370M-001: parameter count {param_count} outside [366M, 374M] band",
570        );
571    }
572    Rc::new(RefCell::new(trainer))
573}
574
575/// Polymorphic trainer builder for `apr pretrain --init` per
576/// `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature +
577/// §init_load_semantics (PR #1473).
578///
579/// Composes the §50.4 step-5f machinery into a single CLI-callable entry:
580///   - 5c: `build_transformer_config(init_arch)` — polymorphic dispatch
581///   - 5f.1: `validate_pretrain_init_arch_compatible(init_arch)` — encoder rejection
582///   - 5f.2: `load_init_tensors_from_apr(path)` — read APR weights
583///   - 5f.3: `populate_trainer_from_init_tensors(trainer, &tensors)` — populate
584///
585/// Behaviour:
586///   init = None  → identical to `build_shared_trainer` (Llama370M from-scratch
587///                  baseline; INV-ARCH-370M-001 enforced).
588///   init = Some  → builds a trainer with the EXTRACTED arch, validates the
589///                  family, loads tensors from the APR file, populates them.
590///                  INV-ARCH-370M-001 is NOT enforced (the arch is whatever the
591///                  init APR has, e.g. 0.5B / 1.5B / 7B).
592///
593/// Spec: SPEC-SHIP-TWO-001 §52.4 (step 5f.4 wireup).
594///
595/// # Errors
596///
597/// Returns Err when:
598/// - `init_arch` is `Some` with `architecture = Encoder` (FALSIFY-APR-PRETRAIN-ARCH-007)
599/// - `load_init_tensors_from_apr` fails (FALSIFY-APR-PRETRAIN-INIT-006)
600/// - `populate_trainer_from_init_tensors` fails (FALSIFY-APR-PRETRAIN-INIT-007)
601pub fn build_shared_trainer_with_init(
602    lr: f32,
603    seq_length: usize,
604    seed: u64,
605    init_arch: Option<&TransformerConfig>,
606    init_path: Option<&Path>,
607) -> Result<SharedTrainer, String> {
608    if init_arch.is_some() != init_path.is_some() {
609        return Err(format!(
610            "build_shared_trainer_with_init: init_arch and init_path must both be Some \
611             or both None (caller bug; init_arch.is_some()={}, init_path.is_some()={})",
612            init_arch.is_some(),
613            init_path.is_some()
614        ));
615    }
616
617    if let Some(cfg) = init_arch {
618        validate_pretrain_init_arch_compatible(cfg)?;
619    }
620
621    let model_cfg = build_transformer_config(init_arch);
622    let mut train_cfg = TransformerTrainConfig::new(model_cfg);
623    train_cfg.lr = lr;
624    train_cfg.max_seq_len = seq_length;
625    train_cfg.seed = seed;
626    let mut trainer = TransformerTrainer::new(train_cfg);
627
628    // Note: INV-ARCH-370M-001 (param-count band check) lives in
629    // `build_shared_trainer` (the from-scratch CLI path). The polymorphic
630    // builder is shape-agnostic by design — `build_transformer_config(init)`
631    // returns whatever the init APR has (0.5B, 1.5B, 7B, etc), so a single
632    // hardcoded band check would fire-fail on every non-Llama370M init.
633
634    if let Some(path) = init_path {
635        let tensors = load_init_tensors_from_apr(path)?;
636        // SPEC §86 / INV-INIT-ARCH-MATCH-001 — fail-fast on the §86
637        // silent-failure pattern (pre-P0-K APR with wrong arch stamp +
638        // mismatched tensor names → random-init fallback at val_loss ≈ 8.6).
639        // Read the raw metadata.architecture string here (init_arch.hf_architecture
640        // is None for pre-P0-K APRs, which is precisely the §86 case — so the
641        // TransformerConfig isn't sufficient).
642        let raw_metadata_arch = read_apr_metadata_architecture_string(path);
643        validate_init_arch_matches_tensor_evidence(raw_metadata_arch.as_deref(), &tensors)?;
644        populate_trainer_from_init_tensors(trainer.model_mut(), &tensors)?;
645    }
646
647    Ok(Rc::new(RefCell::new(trainer)))
648}
649
650/// SPEC §86 helper — read the raw `architecture` string from an APR v2
651/// metadata block without going through `transformer_config_from_apr_metadata`
652/// (which converts to a `ModelArchitecture` enum and loses the original
653/// string). Used by INV-INIT-ARCH-MATCH-001 to detect the §86 case where
654/// the metadata claims "LlamaForCausalLM" but the tensors are Qwen2-shaped.
655///
656/// Returns `None` on any read / parse failure — the gate caller treats
657/// "no metadata claim" as "skip check" so this is safe.
658fn read_apr_metadata_architecture_string(path: &Path) -> Option<String> {
659    use aprender::format::v2::{AprV2Header, AprV2Metadata, HEADER_SIZE_V2, MAGIC_V2};
660    use std::io::{Read, Seek, SeekFrom};
661    let mut file = std::fs::File::open(path).ok()?;
662    let mut header_buf = [0u8; HEADER_SIZE_V2];
663    file.read_exact(&mut header_buf).ok()?;
664    if header_buf[..4] != MAGIC_V2 {
665        return None;
666    }
667    let header = AprV2Header::from_bytes(&header_buf).ok()?;
668    file.seek(SeekFrom::Start(header.metadata_offset)).ok()?;
669    let mut meta_buf = vec![0u8; header.metadata_size as usize];
670    file.read_exact(&mut meta_buf).ok()?;
671    let metadata = AprV2Metadata::from_json(&meta_buf).ok()?;
672    metadata.architecture
673}
674
675#[cfg(test)]
676mod tests {
677    use super::*;
678    use crate::train::transformer_trainer::LMBatch;
679
680    /// FALSIFY-APR-PRETRAIN-INIT-006 (read-half) — load_init_tensors_from_apr
681    /// returns Err with a clear message naming the falsifier when the path
682    /// does not exist.
683    ///
684    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.2.
685    #[test]
686    fn load_init_tensors_missing_file_errors_with_falsifier_id() {
687        let tmp = tempfile::TempDir::new().expect("tempdir");
688        let missing = tmp.path().join("does-not-exist.apr");
689        let err =
690            load_init_tensors_from_apr(&missing).expect_err("missing init APR file MUST fail-fast");
691        assert!(
692            err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
693            "error must cite falsifier id (auditability): {err}"
694        );
695        assert!(
696            err.contains("does-not-exist.apr"),
697            "error must name the missing path (operator-experience): {err}"
698        );
699    }
700
701    /// FALSIFY-APR-PRETRAIN-INIT-006 (read-half) — function exists with the
702    /// right signature: `Path -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>>`.
703    /// Discharges the COMPILE-BIND level claim. Live empirical correctness
704    /// requires step 5g (operator-runnable LIVE fine-tune).
705    ///
706    /// Drift-prevention: this test catches a future refactor that changes
707    /// the return type or signature, which would break the §50.4 step 5f.3
708    /// follow-up that reconciles the BTreeMap against trainer parameters.
709    #[test]
710    fn load_init_tensors_signature_compile_bind() {
711        // Verify the function signature compile-binds: takes a Path-like,
712        // returns the right Result type. This is a compile-time check —
713        // if the signature drifts, this test stops compiling.
714        fn _check_signature<F>(_f: F)
715        where
716            F: Fn(&Path) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String>,
717        {
718        }
719        _check_signature(|p| load_init_tensors_from_apr(p));
720    }
721
722    #[test]
723    fn transformer_config_matches_llama_370m_constants() {
724        let cfg = llama_370m_transformer_config();
725        assert_eq!(cfg.hidden_size, Llama370MConfig::HIDDEN_DIM);
726        assert_eq!(cfg.num_hidden_layers, Llama370MConfig::NUM_LAYERS);
727        assert_eq!(cfg.num_attention_heads, Llama370MConfig::NUM_HEADS);
728        assert_eq!(cfg.num_kv_heads, Llama370MConfig::NUM_KV_HEADS);
729        assert_eq!(cfg.intermediate_size, Llama370MConfig::INTERMEDIATE_DIM);
730        assert_eq!(cfg.vocab_size, Llama370MConfig::VOCAB_SIZE);
731        assert!((cfg.rope_theta - Llama370MConfig::ROPE_THETA).abs() < f32::EPSILON);
732        assert!((cfg.rms_norm_eps - Llama370MConfig::RMS_NORM_EPS).abs() < f32::EPSILON);
733        assert!(!cfg.use_bias, "INV-ARCH-370M-008: no bias");
734        assert!(cfg.tie_word_embeddings, "INV-ARCH-370M-004: tied embeddings");
735    }
736
737    /// FALSIFY-APR-PRETRAIN-ARCH-002 — `build_transformer_config(None)` returns
738    /// the Llama370M baseline byte-for-byte. Falsifies regression in the §24/§25
739    /// from-scratch path when the polymorphic dispatch was added.
740    ///
741    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c.
742    #[test]
743    fn build_transformer_config_no_init_matches_llama370m() {
744        let baseline = llama_370m_transformer_config();
745        let result = build_transformer_config(None);
746        assert_eq!(result.hidden_size, baseline.hidden_size);
747        assert_eq!(result.num_attention_heads, baseline.num_attention_heads);
748        assert_eq!(result.num_kv_heads, baseline.num_kv_heads);
749        assert_eq!(result.intermediate_size, baseline.intermediate_size);
750        assert_eq!(result.num_hidden_layers, baseline.num_hidden_layers);
751        assert_eq!(result.vocab_size, baseline.vocab_size);
752        assert_eq!(result.max_position_embeddings, baseline.max_position_embeddings);
753        assert!((result.rms_norm_eps - baseline.rms_norm_eps).abs() < f32::EPSILON);
754        assert!((result.rope_theta - baseline.rope_theta).abs() < f32::EPSILON);
755        assert_eq!(result.use_bias, baseline.use_bias);
756        assert_eq!(result.tie_word_embeddings, baseline.tie_word_embeddings);
757        assert_eq!(result.architecture, baseline.architecture);
758        assert_eq!(result.hf_architecture, baseline.hf_architecture);
759        assert_eq!(result.hf_model_type, baseline.hf_model_type);
760    }
761
762    /// FALSIFY-APR-PRETRAIN-ARCH-003 — `build_transformer_config(Some(cfg))`
763    /// passes through the caller-provided config byte-for-byte. No silent
764    /// defaults, no field overrides. Tests with Qwen2.5-Coder-0.5B shape
765    /// because that is the §49 fine-tune target.
766    ///
767    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c.
768    #[test]
769    fn build_transformer_config_qwen_init_matches_input() {
770        let qwen = TransformerConfig::qwen2_0_5b();
771        let result = build_transformer_config(Some(&qwen));
772        assert_eq!(result.hidden_size, qwen.hidden_size, "hidden_size");
773        assert_eq!(result.num_attention_heads, qwen.num_attention_heads, "num_attention_heads");
774        assert_eq!(result.num_kv_heads, qwen.num_kv_heads, "num_kv_heads");
775        assert_eq!(result.intermediate_size, qwen.intermediate_size, "intermediate_size");
776        assert_eq!(result.num_hidden_layers, qwen.num_hidden_layers, "num_hidden_layers");
777        assert_eq!(result.vocab_size, qwen.vocab_size, "vocab_size");
778        assert_eq!(
779            result.max_position_embeddings, qwen.max_position_embeddings,
780            "max_position_embeddings"
781        );
782        assert_eq!(result.use_bias, qwen.use_bias, "use_bias");
783        assert_eq!(result.tie_word_embeddings, qwen.tie_word_embeddings, "tie_word_embeddings");
784        assert_eq!(result.architecture, qwen.architecture, "architecture");
785        // GQA-7:1 ratio preserved (Qwen2.5-0.5B: 14 / 2 = 7)
786        assert_eq!(
787            result.num_attention_heads / result.num_kv_heads,
788            7,
789            "GQA ratio must preserve as 7:1 (Qwen2.5-0.5B canonical)"
790        );
791    }
792
793    /// Drift-prevention: dispatch is mutually exclusive — None and Some
794    /// produce different configs (otherwise the polymorphic builder is
795    /// vacuous). Catches a future refactor that accidentally always
796    /// returns Llama370M regardless of init.
797    ///
798    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c — drift prevention.
799    #[test]
800    fn build_transformer_config_dispatch_mutually_exclusive() {
801        let qwen = TransformerConfig::qwen2_0_5b();
802        let none_result = build_transformer_config(None);
803        let some_result = build_transformer_config(Some(&qwen));
804        // The two outputs MUST differ, otherwise the dispatch is broken.
805        assert_ne!(
806            none_result.hidden_size, some_result.hidden_size,
807            "dispatch must differentiate None vs Some — Llama370M hidden=1024 vs Qwen=896"
808        );
809        assert_ne!(
810            none_result.vocab_size, some_result.vocab_size,
811            "dispatch must differentiate None vs Some — Llama370M vocab=50257 vs Qwen=151936"
812        );
813    }
814
815    /// FALSIFY-APR-PRETRAIN-ARCH-007 (decoder branch) — `validate_pretrain_init_arch_compatible`
816    /// returns Ok for a decoder-family config.
817    ///
818    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.1.
819    #[test]
820    fn validate_pretrain_init_arch_accepts_decoder() {
821        let qwen = TransformerConfig::qwen2_0_5b();
822        assert_eq!(qwen.architecture, ModelArchitecture::Decoder);
823        validate_pretrain_init_arch_compatible(&qwen)
824            .expect("decoder-family config (Qwen2.5-0.5B) MUST pass arch-compat gate");
825    }
826
827    /// FALSIFY-APR-PRETRAIN-ARCH-007 (encoder branch) — load-bearing test.
828    /// `validate_pretrain_init_arch_compatible` returns Err naming the
829    /// architecture-family mismatch when given an encoder config (e.g.,
830    /// CodeBERT). Without this gate, the decoder trainer would silently
831    /// build with encoder weights producing nonsense gradients.
832    ///
833    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.1.
834    #[test]
835    fn validate_pretrain_init_arch_rejects_encoder() {
836        // Construct a minimal encoder config (CodeBERT-shaped).
837        let bert = TransformerConfig {
838            hidden_size: 768,
839            num_attention_heads: 12,
840            num_kv_heads: 12,
841            intermediate_size: 3072,
842            num_hidden_layers: 12,
843            vocab_size: 50265,
844            max_position_embeddings: 514,
845            rms_norm_eps: 1e-12,
846            rope_theta: 10_000.0,
847            use_bias: true,
848            head_dim_override: None,
849            architecture: ModelArchitecture::Encoder,
850            hf_architecture: Some("RobertaModel".to_string()),
851            hf_model_type: Some("roberta".to_string()),
852            tie_word_embeddings: false,
853        };
854        let err = validate_pretrain_init_arch_compatible(&bert).expect_err(
855            "encoder-family config (CodeBERT/RoBERTa) MUST fail arch-compat gate — \
856             silent acceptance would corrupt §49 fine-tune trajectory before any \
857             FALSIFY-006 check could measure it",
858        );
859        assert!(
860            err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
861            "error must cite falsifier id: {err}"
862        );
863        assert!(err.contains("Encoder"), "error must name the architecture family: {err}");
864        assert!(
865            err.contains("decoder-only"),
866            "error must explain why this is wrong (decoder trainer): {err}"
867        );
868        assert!(
869            err.contains("RobertaModel"),
870            "error must name the offending hf_architecture: {err}"
871        );
872    }
873
874    /// Drift-prevention: validate_pretrain_init_arch_compatible's behavior on
875    /// the from-scratch baseline (Llama370M) — must Ok. Catches a future
876    /// refactor that accidentally over-rejects decoder configs.
877    #[test]
878    fn validate_pretrain_init_arch_accepts_llama370m_baseline() {
879        let llama = llama_370m_transformer_config();
880        assert_eq!(
881            llama.architecture,
882            ModelArchitecture::Decoder,
883            "Llama370M baseline MUST be Decoder (regression-free)"
884        );
885        validate_pretrain_init_arch_compatible(&llama)
886            .expect("Llama370M baseline (Decoder) MUST pass arch-compat gate");
887    }
888
889    #[test]
890    fn real_step_fn_exhausted_iterator_returns_finite_placeholder() {
891        // Empty iterator means no real batches; we must still return
892        // finite values so the loop's non-divergence + NaN guards see
893        // sane data instead of surprising NaN.
894        //
895        // Construct a minimal trainer WITHOUT running `build_shared_trainer`
896        // because that takes ~5 GB of parameter allocation for 370M —
897        // too expensive for a unit test. Use a tiny synthetic config.
898        let mut tiny = TransformerConfig::llama2_7b();
899        tiny.hidden_size = 64;
900        tiny.num_attention_heads = 4;
901        tiny.num_kv_heads = 4;
902        tiny.num_hidden_layers = 2;
903        tiny.intermediate_size = 128;
904        tiny.vocab_size = 256;
905        let cfg = TransformerTrainConfig::new(tiny);
906        let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
907        let empty_iter: Box<dyn Iterator<Item = LMBatch>> = Box::new(std::iter::empty::<LMBatch>());
908        let mut step = RealStepFn::new(trainer, empty_iter);
909        let (loss, grad_norm) = step.step(0, 1.0e-4, 128);
910        assert!(loss.is_finite(), "exhausted iter must return finite loss");
911        assert!(grad_norm.is_finite(), "grad_norm must be finite");
912        assert!(grad_norm >= 0.0, "INV-TRAIN-008: grad_norm non-negative");
913    }
914
915    #[test]
916    fn real_val_fn_empty_held_out_returns_nan() {
917        let mut tiny = TransformerConfig::llama2_7b();
918        tiny.hidden_size = 64;
919        tiny.num_attention_heads = 4;
920        tiny.num_kv_heads = 4;
921        tiny.num_hidden_layers = 2;
922        tiny.intermediate_size = 128;
923        tiny.vocab_size = 256;
924        let cfg = TransformerTrainConfig::new(tiny);
925        let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
926        let mut val = RealValFn::new(trainer, Vec::new());
927        let loss = val.validate(0);
928        assert!(loss.is_nan(), "empty held_out must surface as NaN to the guard");
929    }
930
931    /// Build a tiny Transformer suitable for unit testing the populate path.
932    /// Uses GQA-1:1 (kv=q) shape — the populate function is shape-agnostic so
933    /// the simpler ratio is fine here.
934    fn tiny_test_transformer() -> Transformer {
935        let mut tiny = TransformerConfig::llama2_7b();
936        tiny.hidden_size = 32;
937        tiny.num_attention_heads = 2;
938        tiny.num_kv_heads = 2;
939        tiny.num_hidden_layers = 2;
940        tiny.intermediate_size = 64;
941        tiny.vocab_size = 16;
942        Transformer::new(&tiny)
943    }
944
945    /// Build a `BTreeMap<String, (Vec<f32>, Vec<usize>)>` from a Transformer's
946    /// `named_parameters()` snapshot. Each tensor is a deterministic ramp
947    /// (i as f32 * 0.001) so populate is byte-identifiable post-set.
948    fn tensors_map_from_transformer(
949        transformer: &Transformer,
950    ) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
951        let mut map = BTreeMap::new();
952        for (name, t) in transformer.named_parameters() {
953            let len = t.len();
954            let data: Vec<f32> = (0..len).map(|i| i as f32 * 0.001).collect();
955            map.insert(name, (data, vec![len]));
956        }
957        map
958    }
959
960    /// Happy path — every model parameter has a matching init entry of correct
961    /// length; populate succeeds and the count matches `named_parameters().len()`.
962    #[test]
963    fn populate_trainer_from_init_tensors_happy_path() {
964        let mut transformer = tiny_test_transformer();
965        let init_tensors = tensors_map_from_transformer(&transformer);
966        let expected_count = transformer.named_parameters().len();
967        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
968        assert!(result.is_ok(), "happy-path populate must succeed: {result:?}");
969        assert_eq!(
970            result.unwrap(),
971            expected_count,
972            "populated count must equal named_parameters().len()"
973        );
974    }
975
976    /// Drift-prevention: extra entries in `init_tensors` that the model does
977    /// NOT expose are silently ignored. This handles tied-embeddings: a Qwen
978    /// APR may publish a separate `lm_head.weight` that the trainer's tied
979    /// model omits.
980    #[test]
981    fn populate_trainer_from_init_tensors_extra_entries_silently_ignored() {
982        let mut transformer = tiny_test_transformer();
983        let mut init_tensors = tensors_map_from_transformer(&transformer);
984        // Inject a fictitious extra parameter that the model does not have.
985        init_tensors
986            .insert("model.layers.999.fictitious.weight".to_string(), (vec![0.0; 4], vec![4]));
987        let expected_count = transformer.named_parameters().len();
988        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
989        assert!(result.is_ok(), "extra init entries must NOT cause Err: {result:?}");
990        assert_eq!(result.unwrap(), expected_count);
991    }
992
993    /// FALSIFY-APR-PRETRAIN-INIT-007 (length mismatch) — when an init tensor
994    /// has the wrong flat length for a known parameter, populate MUST Err
995    /// with the FALSIFIER ID and a per-parameter diagnostic line.
996    #[test]
997    fn populate_trainer_from_init_tensors_rejects_length_mismatch() {
998        let mut transformer = tiny_test_transformer();
999        let mut init_tensors = tensors_map_from_transformer(&transformer);
1000        // Corrupt one entry's length to trigger the mismatch path.
1001        let any_name = transformer.named_parameters()[0].0.clone();
1002        init_tensors.insert(any_name.clone(), (vec![0.0; 7], vec![7]));
1003        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
1004        assert!(result.is_err(), "length-mismatch must Err, not silently truncate");
1005        let err = result.unwrap_err();
1006        assert!(
1007            err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
1008            "error must cite falsifier id; got: {err}"
1009        );
1010        assert!(err.contains(&any_name), "error must name the offending parameter; got: {err}");
1011        assert!(
1012            err.contains("init length 7"),
1013            "error must report the actual init length; got: {err}"
1014        );
1015    }
1016
1017    /// FALSIFY-APR-PRETRAIN-INIT-007 (missing-required) — when a model
1018    /// parameter has NO corresponding entry in `init_tensors`, populate MUST
1019    /// Err with FALSIFIER ID and a "not present in init APR tensors"
1020    /// per-parameter diagnostic. This catches the architecture-mismatch
1021    /// class where init was extracted from a different model family.
1022    #[test]
1023    fn populate_trainer_from_init_tensors_rejects_missing_required_param() {
1024        let mut transformer = tiny_test_transformer();
1025        let mut init_tensors = tensors_map_from_transformer(&transformer);
1026        // Drop one entry to trigger the missing-required path.
1027        let any_name = transformer.named_parameters()[0].0.clone();
1028        init_tensors.remove(&any_name);
1029        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
1030        assert!(result.is_err(), "missing-required must Err, not silently leave random init");
1031        let err = result.unwrap_err();
1032        assert!(
1033            err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
1034            "error must cite falsifier id; got: {err}"
1035        );
1036        assert!(err.contains(&any_name), "error must name the missing parameter; got: {err}");
1037        assert!(
1038            err.contains("not present in init APR"),
1039            "error must say what was missing; got: {err}"
1040        );
1041    }
1042
1043    /// `build_shared_trainer_with_init(None, None)` returns a trainer with
1044    /// the §24/§25 from-scratch Llama370M architecture (regression-free
1045    /// dispatch). Asserts the baseline shape via the (hidden, vocab) tuple
1046    /// rather than param count to avoid the stale INV-ARCH-370M-001 band
1047    /// check in `build_shared_trainer` (a defect outside §50.4 scope —
1048    /// param_count=322M vs assert range [366M, 374M]; tracked for follow-up).
1049    #[test]
1050    fn build_shared_trainer_with_init_none_uses_llama370m_shape() {
1051        let trainer = build_shared_trainer_with_init(1.0e-4, 128, 42, None, None)
1052            .expect("None case must succeed");
1053        let model = trainer.borrow();
1054        // The baseline polymorphic dispatch produces a Llama370M-shaped model.
1055        // Embedding shape `vocab × hidden` is the cleanest non-stale check.
1056        let embed_len = model.model().named_parameters()[0].1.len();
1057        let expected_embed_len = Llama370MConfig::VOCAB_SIZE * Llama370MConfig::HIDDEN_DIM;
1058        assert_eq!(
1059            embed_len,
1060            expected_embed_len,
1061            "init=None must produce Llama370M-shaped embedding (vocab={} × hidden={})",
1062            Llama370MConfig::VOCAB_SIZE,
1063            Llama370MConfig::HIDDEN_DIM
1064        );
1065    }
1066
1067    /// `build_shared_trainer_with_init(Some, None)` and the inverse must
1068    /// fail-fast — both args are paired and either both Some or both None.
1069    /// Drift-prevention: catches a future caller that forgets to pass one.
1070    #[test]
1071    fn build_shared_trainer_with_init_rejects_unpaired_args() {
1072        // arch Some, path None
1073        let cfg = TransformerConfig::qwen2_0_5b();
1074        let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), None);
1075        assert!(result.is_err(), "unpaired (arch=Some, path=None) must Err");
1076        // arch None, path Some
1077        let dummy_path = std::path::PathBuf::from("/dev/null");
1078        let result = build_shared_trainer_with_init(1.0e-4, 128, 42, None, Some(&dummy_path));
1079        assert!(result.is_err(), "unpaired (arch=None, path=Some) must Err");
1080    }
1081
1082    /// `build_shared_trainer_with_init(Some(encoder), Some(path))` rejects
1083    /// the encoder family BEFORE attempting tensor load. Drift-prevention for
1084    /// FALSIFY-APR-PRETRAIN-ARCH-007 at the trainer-builder integration level.
1085    #[test]
1086    fn build_shared_trainer_with_init_rejects_encoder_family() {
1087        let mut encoder_cfg = TransformerConfig::qwen2_0_5b();
1088        encoder_cfg.architecture = ModelArchitecture::Encoder;
1089        let dummy_path = std::path::PathBuf::from("/nonexistent/encoder.apr");
1090        let result =
1091            build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&encoder_cfg), Some(&dummy_path));
1092        let err = match result {
1093            Ok(_) => panic!("encoder family must be rejected before tensor load"),
1094            Err(e) => e,
1095        };
1096        assert!(
1097            err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
1098            "error must cite falsifier id; got: {err}"
1099        );
1100    }
1101
1102    /// `build_shared_trainer_with_init(Some(decoder), Some(missing_path))`
1103    /// proceeds past the family check and FAILS at tensor load with a
1104    /// FALSIFY-006 error. Pins the failure ordering: arch validation first,
1105    /// then tensor load.
1106    #[test]
1107    fn build_shared_trainer_with_init_decoder_family_proceeds_to_tensor_load() {
1108        let cfg = TransformerConfig::qwen2_0_5b();
1109        let dummy_path = std::path::PathBuf::from("/nonexistent/decoder.apr");
1110        let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), Some(&dummy_path));
1111        let err = match result {
1112            Ok(_) => panic!("missing tensor path must Err"),
1113            Err(e) => e,
1114        };
1115        assert!(
1116            err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
1117            "decoder family proceeds to tensor load; failure cites INIT-006 not ARCH-007; got: {err}"
1118        );
1119        assert!(
1120            !err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
1121            "decoder family must NOT trigger encoder-rejection; got: {err}"
1122        );
1123    }
1124
1125    /// FALSIFY-H4-INIT-STATS-001 (SHIP-TWO §61 H4A bisect):
1126    /// `load_init_tensors_from_apr` on the canonical Qwen2.5-Coder-0.5B-Instruct
1127    /// APR file MUST produce sensibly-distributed weights:
1128    ///   - `model.embed_tokens.weight` mean ≈ 0 (within ±0.01)
1129    ///   - `model.embed_tokens.weight` std in [0.01, 0.1] (HF LLaMA init = 0.02)
1130    ///   - `model.norm.weight` mean ≈ 1.0 (RMSNorm pretrained scale)
1131    ///
1132    /// CONTEXT: §61 evidence shows val_loss=19.80 > ln(vocab)=17.21 at
1133    /// step 1, indicating the loaded model produces sub-random predictions.
1134    /// Four candidate hypotheses (H4A tied weights, H4B layout, H4C norm
1135    /// scale, H4D residual stream). This test bisects H4A+H4C: if any of
1136    /// the loaded tensor stats are wildly out-of-range, the load itself
1137    /// is corrupt; if all stats look correct, the bug is in the forward
1138    /// path (H4B layout or H4D residual).
1139    ///
1140    /// Host-gated: requires a canonical Qwen 0.5B init APR. Tries the
1141    /// "fresh" path first (current `apr import` of HF safetensors,
1142    /// preserves BF16 dtype tag); falls back to the older "fp16" path
1143    /// (legacy import, mis-tagged as F16). Skips if neither present.
1144    ///
1145    /// The legacy file demonstrates the H4 dtype-mislabel defect class:
1146    /// safetensors source is BF16, old `apr import` wrote bytes raw
1147    /// but tagged dtype as F16, aprender's loader then read bytes as
1148    /// F16 and produced distorted values. The fresh path preserves
1149    /// BF16 correctly. Element-0 cross-checks agree with the
1150    /// safetensors source under BF16 decode.
1151    #[test]
1152    fn falsify_h4_init_stats_qwen_embed_norm_sensible() {
1153        let fresh = std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-fresh.apr");
1154        let legacy =
1155            std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-instruct-fp16.apr");
1156        let path = if fresh.exists() {
1157            fresh
1158        } else if legacy.exists() {
1159            legacy
1160        } else {
1161            eprintln!("[falsify-h4-init-stats-001] skipping: host lacks Qwen 0.5B APR");
1162            return;
1163        };
1164        let _ = path; // silence unused if branches
1165        if !path.exists() {
1166            eprintln!("[falsify-h4-init-stats-001] skipping: host lacks {}", path.display());
1167            return;
1168        }
1169        // H4 root-cause probe: directly inspect the APR's dtype tag to
1170        // verify whether the F16 vs BF16 distinction was preserved
1171        // through `apr import`.
1172        {
1173            use aprender::format::v2::AprV2Reader;
1174            let bytes = std::fs::read(path).expect("read APR");
1175            let reader = AprV2Reader::from_bytes(&bytes).expect("parse APR v2");
1176            for name in ["model.layers.0.self_attn.q_proj.bias", "model.norm.weight"] {
1177                if let Some(entry) = reader.get_tensor(name) {
1178                    eprintln!(
1179                        "[h4-init-dtype] {name}: dtype={:?} shape={:?}",
1180                        entry.dtype, entry.shape
1181                    );
1182                }
1183            }
1184        }
1185        let tensors = match load_init_tensors_from_apr(path) {
1186            Ok(t) => t,
1187            Err(e) => {
1188                panic!("FALSIFY-H4-INIT-STATS-001: load_init_tensors_from_apr failed: {e}");
1189            }
1190        };
1191
1192        // Required tensors
1193        let embed = tensors
1194            .get("model.embed_tokens.weight")
1195            .unwrap_or_else(|| panic!("missing model.embed_tokens.weight in init APR"));
1196        let norm = tensors
1197            .get("model.norm.weight")
1198            .unwrap_or_else(|| panic!("missing model.norm.weight in init APR"));
1199
1200        let stats = |name: &str, data: &[f32]| -> (f64, f64, f32, f32) {
1201            let n = data.len() as f64;
1202            let mean = data.iter().map(|&v| v as f64).sum::<f64>() / n;
1203            let var = data
1204                .iter()
1205                .map(|&v| {
1206                    let d = v as f64 - mean;
1207                    d * d
1208                })
1209                .sum::<f64>()
1210                / n;
1211            let std = var.sqrt();
1212            let min = data.iter().copied().fold(f32::INFINITY, f32::min);
1213            let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1214            eprintln!(
1215                "[h4-init-stats] {name}: n={n} mean={mean:.5} std={std:.5} min={min:.4} max={max:.4}"
1216            );
1217            (mean, std, min, max)
1218        };
1219        // H4-DTYPE-MISLABEL: dump first 4 element-0 values to compare
1220        // with safetensors source (decoded as BF16). If the APR loader
1221        // mis-decodes BF16 bytes as F16, values will diverge.
1222        {
1223            let q = tensors.get("model.layers.0.self_attn.q_proj.bias").unwrap();
1224            eprintln!(
1225                "[h4-dtype-mislabel] q_proj.bias L0[0..6] (aprender F16-decoded): {:?}",
1226                &q.0[..6]
1227            );
1228            let n = tensors.get("model.norm.weight").unwrap();
1229            eprintln!(
1230                "[h4-dtype-mislabel] model.norm.weight[0..6] (aprender F16-decoded): {:?}",
1231                &n.0[..6]
1232            );
1233        }
1234
1235        let (em, es, _, _) = stats("model.embed_tokens.weight", &embed.0);
1236        let (nm, ns, _, _) = stats("model.norm.weight", &norm.0);
1237
1238        // H4C bisect: dump per-layer norm stats. Standard RMSNorm
1239        // weights are near 1.0 (init=1.0, trained drift to ~0.1-2.0).
1240        // Mean > 5 across layers indicates a load-time scale-corruption.
1241        for layer_idx in [0_usize, 5, 11, 23] {
1242            for kind in ["input_layernorm", "post_attention_layernorm"] {
1243                let key = format!("model.layers.{layer_idx}.{kind}.weight");
1244                if let Some(t) = tensors.get(&key) {
1245                    stats(&key, &t.0);
1246                }
1247            }
1248        }
1249        for kind in [
1250            "model.layers.0.self_attn.q_proj.weight",
1251            "model.layers.0.self_attn.q_proj.bias",
1252            "model.layers.0.mlp.gate_proj.weight",
1253            "model.layers.0.mlp.down_proj.weight",
1254        ] {
1255            if let Some(t) = tensors.get(kind) {
1256                stats(kind, &t.0);
1257            }
1258        }
1259
1260        // Embedding init bound: HF LLaMA init normal(0, 0.02). After
1261        // pretraining the std grows but typically stays in [0.01, 0.1].
1262        // mean should be near 0 (well-centered).
1263        assert!(
1264            em.abs() < 0.05,
1265            "FALSIFY-H4-INIT-STATS-001: embed mean={em} > 0.05; weights are not centered. \
1266             Possible f16→f32 sign-bit corruption or wrong byte-order."
1267        );
1268        assert!(
1269            (0.005..=0.5).contains(&es),
1270            "FALSIFY-H4-INIT-STATS-001: embed std={es} outside [0.005, 0.5]; weights are not \
1271             distributed like trained transformer init. Possible f16 mantissa misread or \
1272             scale corruption."
1273        );
1274
1275        // RMSNorm init: weights are ~1.0 (sqrt(2)≈1.41 in some configs).
1276        // After training they stay close to 1, sometimes drifting up to ~10.
1277        assert!(
1278            nm > 0.01 && nm < 100.0,
1279            "FALSIFY-H4-INIT-STATS-001: norm mean={nm} outside [0.01, 100]; RMSNorm scale \
1280             load is corrupt. Trained pretrained values are typically near 1.0."
1281        );
1282        assert!(
1283            ns < 100.0,
1284            "FALSIFY-H4-INIT-STATS-001: norm std={ns} > 100; RMSNorm has explosive variance. \
1285             Tensor load is corrupt."
1286        );
1287    }
1288
1289    /// FALSIFY-H4-CPU-FORWARD-001 (H4 residual cascade — bisect to CPU vs CUDA):
1290    /// CPU `aprender::Transformer::forward` on a populated Qwen 0.5B model
1291    /// MUST produce sensibly-distributed logits. Host-gated test that
1292    /// bisects whether the val_loss > ln(vocab) defect is in the
1293    /// populate path / CPU forward (RED here = bug there) or in CUDA
1294    /// (GREEN here, RED in eval_batch = bug in CUDA path).
1295    #[test]
1296    fn falsify_h4_cpu_forward_qwen_logits_sensible() {
1297        let fresh = std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-fresh.apr");
1298        let legacy =
1299            std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-instruct-fp16.apr");
1300        let path = if fresh.exists() {
1301            fresh
1302        } else if legacy.exists() {
1303            legacy
1304        } else {
1305            eprintln!("[falsify-h4-cpu-forward-001] skipping: host lacks Qwen 0.5B APR");
1306            return;
1307        };
1308
1309        let tensors = load_init_tensors_from_apr(path).expect("load_init_tensors_from_apr");
1310        let cfg = TransformerConfig::qwen2_0_5b();
1311        let mut transformer = Transformer::new(&cfg);
1312        let populated = populate_trainer_from_init_tensors(&mut transformer, &tensors)
1313            .expect("populate_trainer_from_init_tensors");
1314        eprintln!("[falsify-h4-cpu-forward-001] populated {populated} tensors");
1315
1316        let token_ids = vec![100_u32];
1317        let logits = transformer.forward(&token_ids);
1318        let data = logits.data();
1319        let slice = data.as_slice().expect("logits contiguous");
1320
1321        let mut nan_count = 0usize;
1322        let mut inf_count = 0usize;
1323        let mut min = f32::INFINITY;
1324        let mut max = f32::NEG_INFINITY;
1325        let mut sum = 0.0_f64;
1326        let mut sum_sq = 0.0_f64;
1327        let mut argmax_idx = 0_usize;
1328        for (i, &v) in slice.iter().enumerate() {
1329            if v.is_nan() {
1330                nan_count += 1;
1331            } else if v.is_infinite() {
1332                inf_count += 1;
1333            } else {
1334                if v < min {
1335                    min = v;
1336                }
1337                if v > max {
1338                    max = v;
1339                    argmax_idx = i;
1340                }
1341                sum += v as f64;
1342                sum_sq += (v as f64) * (v as f64);
1343            }
1344        }
1345        let n = slice.len() as f64;
1346        let mean = sum / n;
1347        let std = (sum_sq / n - mean * mean).sqrt();
1348
1349        eprintln!(
1350            "[falsify-h4-cpu-forward-001] token=100 logits: n={} nan={nan_count} inf={inf_count} \
1351             min={min:.4} max={max:.4} mean={mean:.4} std={std:.4} argmax={argmax_idx}",
1352            slice.len()
1353        );
1354
1355        assert_eq!(nan_count, 0, "logits contain NaN — forward corruption");
1356        assert_eq!(inf_count, 0, "logits contain Inf — forward corruption");
1357        assert!(
1358            std > 0.01,
1359            "FALSIFY-H4-CPU-FORWARD-001: logits std={std} < 0.01 — essentially constant"
1360        );
1361        let peak_to_mean = (max as f64 - mean).abs() / std.max(1e-9);
1362        assert!(
1363            peak_to_mean > 1.5,
1364            "FALSIFY-H4-CPU-FORWARD-001: peak-to-mean ratio = {peak_to_mean} < 1.5 — \
1365             logits are essentially uniform"
1366        );
1367        assert!(
1368            (argmax_idx as u32) < cfg.vocab_size as u32,
1369            "FALSIFY-H4-CPU-FORWARD-001: argmax_idx={argmax_idx} >= vocab_size={}",
1370            cfg.vocab_size
1371        );
1372    }
1373
1374    // ========================================================================
1375    // SPEC §86 / INV-INIT-ARCH-MATCH-001 unit tests — arch-mismatch fail-fast
1376    // ========================================================================
1377
1378    fn qwen2_tensor_names() -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
1379        // Minimal Qwen2 signature: model.layers + self_attn.q_proj.bias (distinguishes from Llama)
1380        let mut m = BTreeMap::new();
1381        m.insert("model.layers.0.self_attn.q_proj.bias".to_string(), (vec![0.0_f32; 4], vec![4]));
1382        m.insert(
1383            "model.layers.0.self_attn.q_proj.weight".to_string(),
1384            (vec![0.0_f32; 16], vec![4, 4]),
1385        );
1386        m
1387    }
1388
1389    fn llama_tensor_names() -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
1390        // Llama signature: model.layers + NO attention bias + NO qkv_proj
1391        let mut m = BTreeMap::new();
1392        m.insert(
1393            "model.layers.0.self_attn.q_proj.weight".to_string(),
1394            (vec![0.0_f32; 16], vec![4, 4]),
1395        );
1396        m.insert("model.layers.0.input_layernorm.weight".to_string(), (vec![1.0_f32; 4], vec![4]));
1397        m
1398    }
1399
1400    /// SPEC §86 INV-INIT-ARCH-MATCH-001: the canonical §86 case — APR
1401    /// metadata claims "LlamaForCausalLM" (§82 P0-H fallback) but tensors
1402    /// are Qwen2-shaped (have q_proj.bias). MUST fail with the falsifier ID.
1403    #[test]
1404    fn inv_init_arch_match_001_rejects_llama_stamped_qwen2_tensors() {
1405        let tensors = qwen2_tensor_names();
1406        let err = validate_init_arch_matches_tensor_evidence(Some("LlamaForCausalLM"), &tensors)
1407            .expect_err("§86 case MUST be rejected");
1408        assert!(
1409            err.contains("FALSIFY-INIT-ARCH-MATCH-001"),
1410            "error must cite falsifier id; got: {err}"
1411        );
1412        assert!(
1413            err.contains("llama") && err.contains("qwen2"),
1414            "error must name both claimed and inferred families; got: {err}"
1415        );
1416        assert!(
1417            err.contains("apr stamp"),
1418            "error must include the §86.4 salvage recipe; got: {err}"
1419        );
1420    }
1421
1422    /// SPEC §86: the inverse — metadata claims "Qwen2ForCausalLM" but
1423    /// tensors are Llama-shaped (no q_proj.bias). MUST fail.
1424    #[test]
1425    fn inv_init_arch_match_001_rejects_qwen2_stamped_llama_tensors() {
1426        let tensors = llama_tensor_names();
1427        let err = validate_init_arch_matches_tensor_evidence(Some("Qwen2ForCausalLM"), &tensors)
1428            .expect_err("inverse §86 case MUST be rejected");
1429        assert!(err.contains("FALSIFY-INIT-ARCH-MATCH-001"));
1430        assert!(err.contains("qwen2") && err.contains("llama"));
1431    }
1432
1433    /// SPEC §86: matching family slug + Qwen2 tensors — must PASS (no false-positive).
1434    #[test]
1435    fn inv_init_arch_match_001_accepts_matching_qwen2() {
1436        let tensors = qwen2_tensor_names();
1437        validate_init_arch_matches_tensor_evidence(Some("Qwen2ForCausalLM"), &tensors)
1438            .expect("matching qwen2 + qwen2 must pass");
1439        validate_init_arch_matches_tensor_evidence(Some("qwen2"), &tensors)
1440            .expect("matching qwen2 slug + qwen2 tensors must pass");
1441    }
1442
1443    /// SPEC §86: matching family slug + Llama tensors — must PASS.
1444    #[test]
1445    fn inv_init_arch_match_001_accepts_matching_llama() {
1446        let tensors = llama_tensor_names();
1447        validate_init_arch_matches_tensor_evidence(Some("LlamaForCausalLM"), &tensors)
1448            .expect("matching llama + llama must pass");
1449        validate_init_arch_matches_tensor_evidence(Some("llama"), &tensors)
1450            .expect("matching llama slug + llama tensors must pass");
1451    }
1452
1453    /// SPEC §86: None metadata claim — skip the check (no false-positive
1454    /// on novel architectures).
1455    #[test]
1456    fn inv_init_arch_match_001_skips_when_metadata_absent() {
1457        let tensors = qwen2_tensor_names();
1458        validate_init_arch_matches_tensor_evidence(None, &tensors)
1459            .expect("absent metadata claim must skip check");
1460    }
1461
1462    /// SPEC §86: unknown family in metadata (e.g., "weird-novel-arch") —
1463    /// skip the check.
1464    #[test]
1465    fn inv_init_arch_match_001_skips_unmappable_metadata() {
1466        let tensors = qwen2_tensor_names();
1467        validate_init_arch_matches_tensor_evidence(Some("WeirdNovelArchForCausalLM"), &tensors)
1468            .expect("unmappable metadata MUST skip check (no false-positive on novel arch)");
1469    }
1470
1471    /// SPEC §86: GGUF-style tensor names (blk.*) — inference returns
1472    /// "unknown" and we trust the metadata claim. Must not fail.
1473    #[test]
1474    fn inv_init_arch_match_001_trusts_metadata_when_tensors_unknown() {
1475        let mut tensors = BTreeMap::new();
1476        tensors.insert("blk.0.attn_q.weight".to_string(), (vec![0.0_f32; 16], vec![4, 4]));
1477        // GGUF names can't disambiguate; we trust the metadata.
1478        validate_init_arch_matches_tensor_evidence(Some("LlamaForCausalLM"), &tensors)
1479            .expect("unknown tensor family must skip check (trust metadata)");
1480    }
1481
1482    /// Spec §86 helper test: family_from_tensor_names correctly
1483    /// distinguishes Qwen2 from Llama by the attention-bias signal.
1484    #[test]
1485    fn family_from_tensor_names_distinguishes_qwen2_from_llama() {
1486        let qwen2: Vec<&str> = vec![
1487            "model.layers.0.self_attn.q_proj.weight",
1488            "model.layers.0.self_attn.q_proj.bias", // bias = Qwen2 signature
1489        ];
1490        assert_eq!(family_from_tensor_names(qwen2.iter().copied()), "qwen2");
1491
1492        let llama: Vec<&str> =
1493            vec!["model.layers.0.self_attn.q_proj.weight", "model.layers.0.input_layernorm.weight"];
1494        assert_eq!(family_from_tensor_names(llama.iter().copied()), "llama");
1495    }
1496
1497    /// Spec §86: normalize_metadata_arch_family handles all three input forms.
1498    #[test]
1499    fn normalize_metadata_arch_family_handles_three_forms() {
1500        // Class name (P0-H fallback)
1501        assert_eq!(normalize_metadata_arch_family("Qwen2ForCausalLM"), Some("qwen2"));
1502        assert_eq!(normalize_metadata_arch_family("LlamaForCausalLM"), Some("llama"));
1503        // Family slug (canonical)
1504        assert_eq!(normalize_metadata_arch_family("qwen2"), Some("qwen2"));
1505        assert_eq!(normalize_metadata_arch_family("llama"), Some("llama"));
1506        // Capitalised legacy
1507        assert_eq!(normalize_metadata_arch_family("Qwen2"), Some("qwen2"));
1508        // Unknown
1509        assert_eq!(normalize_metadata_arch_family("unknown"), None);
1510        assert_eq!(normalize_metadata_arch_family("WeirdNovelArch"), None);
1511    }
1512}