Skip to main content

entrenar/train/
pretrain_real.rs

//! Real-corpus `StepFn` / `ValFn` for MODEL-2 pretrain MVP (task #111).
//!
//! Bridges the model-agnostic `PretrainLoop` (`pretrain.rs`) to the
//! 370M Llama scaffold (`models/llama_370m.rs`) by wiring a real
//! `TransformerTrainer` through the `StepFn` and `ValFn` traits.
//!
//! This module replaces the `LinearDecaySynthetic` / `ScriptedVal`
//! pair used for GATE-TRAIN-005/007/008 wiring verification (task #105)
//! with a real forward + backward + optimizer step and a real held-out
//! validation forward pass.
//!
//! Contract obligations discharged:
//! - INV-ARCH-370M-001 (param count in [366M, 374M]) via `debug_assert!`
//!   in `build_shared_trainer` (debug builds only)
//! - INV-TRAIN-001 (per-step metrics — 6 fields via PretrainLoop)
//! - INV-TRAIN-003 (optimizer-state sha256 over the AdamW t/m/v buffers,
//!   via `RealStepFn::optimizer_state_sha256`)
//! - INV-TRAIN-007 (no NaN/Inf — the loop aborts on first non-finite)
//!
//! Deferred (task #111 follow-ups):
//! - Real grad_norm (currently reports a placeholder; needs
//!   TransformerTrainer extension to surface pre-clip norm)
21
22use crate::models::llama_370m::Llama370MConfig;
23use crate::train::pretrain::{CheckpointFn, EpochArtifact, StepFn, ValFn};
24use crate::train::transformer_trainer::{LMBatch, TransformerTrainConfig, TransformerTrainer};
25use crate::transformer::{ModelArchitecture, Transformer, TransformerConfig};
26use crate::Tensor;
27use std::cell::RefCell;
28use std::collections::BTreeMap;
29use std::path::Path;
30use std::rc::Rc;
31
32/// Shared mutable ownership of the `TransformerTrainer` — both
33/// `RealStepFn` (training steps) and `RealValFn` (forward-only
34/// validation) clone this `Rc`.
35pub type SharedTrainer = Rc<RefCell<TransformerTrainer>>;
36
37/// Load tensors from an APR file as the read-half of `apr pretrain --init`.
38///
39/// Per `apr-pretrain-arch-polymorphic-v1` §init_load_semantics (PR #1473),
40/// the loader is REUSED, not reimplemented — this function is a thin wrapper
41/// over `aprender::format::converter::convert_report::load_model_tensors`,
42/// which is the same machinery `apr export` and `apr inspect` use. No
43/// duplicate APR parser; one source of truth.
44///
45/// Returns a map of `tensor_name -> (flat_f32_data, shape)`. The HF naming
46/// convention is preserved (e.g., `model.embed_tokens.weight`); reconciling
47/// against the trainer's parameter names is step 5f.3 (the population step).
48///
49/// Discharges from `apr-pretrain-arch-polymorphic-v1`:
50///   - §init_load_semantics invariant: "Loader is reused, not reimplemented"
51///   - FALSIFY-006 (init_loss < 6.0) at READ-COMPILE-BIND level: this
52///     function is the read half. Full FALSIFY-006 discharge requires
53///     5f.3 (population) + 5g (LIVE 500-step fine-tune).
54///
55/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.2.
56///
57/// # Errors
58///
59/// Returns Err if the APR file:
60/// - Does not exist (filesystem I/O error)
61/// - Has invalid magic bytes (not APR\\0 or APRN)
62/// - Has a corrupted tensor index
63/// - Contains tensors with unsupported dtype (non-F32)
64pub fn load_init_tensors_from_apr(
65    path: impl AsRef<Path>,
66) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String> {
67    let path_ref = path.as_ref();
68    aprender::format::converter::load_model_tensors(path_ref).map_err(|e| {
69        format!(
70            "FALSIFY-APR-PRETRAIN-INIT-006: failed to load init tensors from APR file {}: {e}",
71            path_ref.display()
72        )
73    })
74}
75
76/// Reject an init `TransformerConfig` whose architecture family is incompatible
77/// with the pretrain target (decoder-only LM training).
78///
79/// Per `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature
80/// (PR #1473), wrong-arch APR (e.g., a CodeBERT/RoBERTa encoder model)
81/// MUST be FAIL-FAST not silent-truncate. Without this gate, an operator
82/// who points `--init` at e.g. `microsoft/codebert-base.apr` would silently
83/// load encoder weights into a decoder-shaped trainer, producing nonsense
84/// gradients that the divergence guard catches LATE (after multiple epochs).
85///
86/// Discharges FALSIFY-APR-PRETRAIN-ARCH-007 at PARTIAL_ALGORITHM_LEVEL.
87///
88/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.1.
89///
90/// # Errors
91///
92/// Returns Err with a clear architecture-family-mismatch message when:
93/// - `cfg.architecture` is `ModelArchitecture::Encoder` (BERT/RoBERTa/CodeBERT)
94///
95/// Future expansion can add other family checks (e.g., reject hybrid SSM
96/// architectures whose forward pass differs from the standard decoder loop).
97pub fn validate_pretrain_init_arch_compatible(cfg: &TransformerConfig) -> Result<(), String> {
98    match cfg.architecture {
99        ModelArchitecture::Decoder => Ok(()),
100        ModelArchitecture::Encoder => Err(format!(
101            "FALSIFY-APR-PRETRAIN-ARCH-007: --init checkpoint has architecture=Encoder \
102             (e.g., BERT/RoBERTa/CodeBERT) but the pretrain trainer is decoder-only \
103             (Llama/Qwen-class causal LMs). Loading encoder weights into a decoder \
104             trainer would produce nonsense gradients. Architectural details: \
105             hidden_size={}, num_layers={}, vocab_size={}, hf_architecture={:?}",
106            cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size, cfg.hf_architecture
107        )),
108    }
109}
110
111/// Populate a `Transformer`'s parameters from an `init_tensors` BTreeMap.
112///
113/// For each parameter the `Transformer` exposes via `named_parameters()`, look
114/// up the same HF-naming key in `init_tensors` and replace the parameter's
115/// `Tensor` with `Tensor::from_vec(data, requires_grad=true)`. The parameter
116/// remains trainable (gradients still flow) — this is fine-tune init, not
117/// frozen-encoder.
118///
119/// Strictness rules:
120/// - **Every** model parameter must have a matching entry in `init_tensors`.
121///   Missing entries return Err naming the unmatched parameter; this catches
122///   the case where the init APR was extracted from a different architecture.
123/// - **Length** of each init entry must match the model parameter's length
124///   (computed from the model's existing tensor `len()`). Mismatch returns Err.
125/// - **Extra** entries in `init_tensors` are silently ignored. This handles
126///   `tie_word_embeddings`: a Qwen2.5 APR may publish a separate `lm_head.weight`
127///   tensor that the model omits when ties are enabled.
128///
129/// Discharges from `apr-pretrain-arch-polymorphic-v1` §init_load_semantics:
130/// - Population invariant: "Init tensors populate trainer parameters
131///   byte-equivalent to source"
132/// - FALSIFY-APR-PRETRAIN-INIT-007 (population step) at PARTIAL_ALGORITHM_LEVEL.
133///
134/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.3.
135///
136/// # Errors
137///
138/// Returns Err if any model parameter is missing from `init_tensors` or if
139/// any matched entry has a wrong length. The error message lists up to the
140/// first 5 problem parameters and the total count.
141pub fn populate_trainer_from_init_tensors(
142    transformer: &mut Transformer,
143    init_tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
144) -> Result<usize, String> {
145    let expected: Vec<(String, usize)> = transformer
146        .named_parameters()
147        .into_iter()
148        .map(|(name, t)| (name, t.len()))
149        .collect();
150    let mut populated = 0usize;
151    let mut errors: Vec<String> = Vec::new();
152
153    for (name, expected_len) in &expected {
154        match init_tensors.get(name) {
155            Some((data, _shape)) => {
156                if data.len() != *expected_len {
157                    errors.push(format!(
158                        "{name}: init length {} != trainer expected {expected_len}",
159                        data.len()
160                    ));
161                    continue;
162                }
163                let tensor = Tensor::from_vec(data.clone(), true);
164                if !transformer.set_named_parameter(name, tensor) {
165                    errors.push(format!(
166                        "{name}: set_named_parameter rejected the assignment"
167                    ));
168                    continue;
169                }
170                populated += 1;
171            }
172            None => {
173                errors.push(format!("{name}: not present in init APR tensors"));
174            }
175        }
176    }
177
178    if !errors.is_empty() {
179        let total = errors.len();
180        let head = errors.iter().take(5).cloned().collect::<Vec<_>>().join("; ");
181        return Err(format!(
182            "FALSIFY-APR-PRETRAIN-INIT-007: populate_trainer_from_init_tensors \
183             failed for {total} parameter(s); first {} of {total}: {head}",
184            errors.len().min(5)
185        ));
186    }
187
188    Ok(populated)
189}
190
191/// Build a `TransformerConfig` field-for-field from `Llama370MConfig::*`
192/// constants (the contract-frozen MODEL-2 370M architecture).
193pub fn llama_370m_transformer_config() -> TransformerConfig {
194    TransformerConfig {
195        hidden_size: Llama370MConfig::HIDDEN_DIM,
196        num_attention_heads: Llama370MConfig::NUM_HEADS,
197        num_kv_heads: Llama370MConfig::NUM_KV_HEADS,
198        intermediate_size: Llama370MConfig::INTERMEDIATE_DIM,
199        num_hidden_layers: Llama370MConfig::NUM_LAYERS,
200        vocab_size: Llama370MConfig::VOCAB_SIZE,
201        max_position_embeddings: Llama370MConfig::MAX_POSITION_EMBEDDINGS,
202        rms_norm_eps: Llama370MConfig::RMS_NORM_EPS,
203        rope_theta: Llama370MConfig::ROPE_THETA,
204        use_bias: false,
205        head_dim_override: None,
206        architecture: ModelArchitecture::Decoder,
207        hf_architecture: Some("LlamaForCausalLM".into()),
208        hf_model_type: Some("llama".into()),
209        tie_word_embeddings: true,
210    }
211}
212
213/// Polymorphic builder per `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature.
214///
215/// Discharges FALSIFY-APR-PRETRAIN-ARCH-002 (init=None preserves Llama370M baseline)
216/// and FALSIFY-APR-PRETRAIN-ARCH-003 (init=Some passes through extracted config).
217///
218/// Behaviour:
219///   init = None  → return `llama_370m_transformer_config()`, the §24/§25
220///                  from-scratch baseline. NO regression.
221///   init = Some  → clone the caller-extracted `TransformerConfig` byte-for-byte.
222///                  No silent defaults, no field overrides.
223///
224/// The caller is responsible for actually reading the APR file and producing the
225/// `TransformerConfig` (typically via `TransformerConfig::from_apr_metadata` from
226/// `transformer::config`). Decoupling the dispatch from the file I/O keeps
227/// `aprender-train` free of `aprender-serve` (the APR loader) as a build dep.
228///
229/// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c.
230pub fn build_transformer_config(init: Option<&TransformerConfig>) -> TransformerConfig {
231    match init {
232        None => llama_370m_transformer_config(),
233        Some(cfg) => cfg.clone(),
234    }
235}
236
237/// Build a `TransformerTrainConfig` with MODEL-2 v2-remedy defaults
238/// (LR=5e-5, AdamW defaults, fp32, seed=42 set by caller).
239pub fn llama_370m_train_config(lr: f32, seq_length: usize, seed: u64) -> TransformerTrainConfig {
240    let model_cfg = llama_370m_transformer_config();
241    let mut cfg = TransformerTrainConfig::new(model_cfg);
242    cfg.lr = lr;
243    cfg.max_seq_len = seq_length;
244    cfg.seed = seed;
245    cfg
246}
247
248/// `StepFn` impl that pulls one `LMBatch` from an owned iterator and
249/// runs a real forward + backward + AdamW step through the shared
250/// `TransformerTrainer`.
251pub struct RealStepFn {
252    trainer: SharedTrainer,
253    batches: Box<dyn Iterator<Item = LMBatch>>,
254}
255
256impl RealStepFn {
257    pub fn new(trainer: SharedTrainer, batches: Box<dyn Iterator<Item = LMBatch>>) -> Self {
258        Self { trainer, batches }
259    }
260}
261
262impl StepFn for RealStepFn {
263    fn step(&mut self, _step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
264        // Pull one batch; if the shard stream is exhausted before the
265        // loop plans to stop, emit a tiny finite placeholder so
266        // GATE-TRAIN-007 (NaN/Inf guard) does not mis-fire — the
267        // divergence guard (GATE-TRAIN-005) will correctly not abort
268        // on a flat tail.
269        let Some(batch) = self.batches.next() else {
270            return (1.0, 1.0);
271        };
272        let mut trainer = self.trainer.borrow_mut();
273        let loss = trainer.train_batch(&batch);
274        // TODO(task #111 follow-up): expose AdamW pre-clip grad norm.
275        // Placeholder = 1.0 keeps INV-TRAIN-007 satisfied (finite) and
276        // INV-TRAIN-008 satisfied (≥ 0); the real grad norm is a
277        // downstream ticket that needs TransformerTrainer extension.
278        let grad_norm = 1.0_f32;
279        (loss, grad_norm)
280    }
281
282    /// INV-TRAIN-003 discharge: hash the real AdamW (t, m, v) buffers.
283    fn optimizer_state_sha256(&self) -> Option<String> {
284        Some(self.trainer.borrow().optimizer_state_sha256())
285    }
286}
287
288/// `ValFn` impl that runs forward-only across a pre-loaded set of
289/// held-out batches and returns mean cross-entropy loss.
290pub struct RealValFn {
291    trainer: SharedTrainer,
292    held_out: Vec<LMBatch>,
293}
294
295impl RealValFn {
296    pub fn new(trainer: SharedTrainer, held_out: Vec<LMBatch>) -> Self {
297        Self { trainer, held_out }
298    }
299}
300
301impl ValFn for RealValFn {
302    fn validate(&mut self, _epoch: usize) -> f32 {
303        if self.held_out.is_empty() {
304            return f32::NAN;
305        }
306        let trainer = self.trainer.borrow();
307        let mut total_loss = 0.0_f32;
308        let mut total_items = 0_usize;
309        for batch in &self.held_out {
310            for i in 0..batch.batch_size {
311                let Some(inp) = batch.get_input(i) else {
312                    continue;
313                };
314                let Some(tgt) = batch.get_target(i) else {
315                    continue;
316                };
317                let (loss_val, _loss_tensor, _logits) = trainer.forward_single(inp, tgt);
318                total_loss += loss_val;
319                total_items += 1;
320            }
321        }
322        if total_items == 0 {
323            f32::NAN
324        } else {
325            total_loss / total_items as f32
326        }
327    }
328}
329
330/// `CheckpointFn` impl that writes the 370M Llama weights to
331/// `artifact.checkpoint_path` in APR format (task #111 step 7).
332///
333/// Holds the `SharedTrainer` alongside `RealStepFn` / `RealValFn` so
334/// the three hooks see the same in-memory weights.
335pub struct AprCheckpointFn {
336    trainer: SharedTrainer,
337    model_name: String,
338    architecture: String,
339}
340
341impl AprCheckpointFn {
342    pub fn new(
343        trainer: SharedTrainer,
344        model_name: impl Into<String>,
345        architecture: impl Into<String>,
346    ) -> Self {
347        Self { trainer, model_name: model_name.into(), architecture: architecture.into() }
348    }
349}
350
351impl CheckpointFn for AprCheckpointFn {
352    fn save(&mut self, _epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
353        let trainer = self.trainer.borrow();
354        trainer
355            .save_apr(&artifact.checkpoint_path, &self.model_name, &self.architecture)
356            .map_err(|e| format!("save_apr failed: {e}"))
357    }
358}
359
360/// Shared-ownership helper so the CLI can hand the same trainer to
361/// both `RealStepFn` and `RealValFn`.
362pub fn build_shared_trainer(lr: f32, seq_length: usize, seed: u64) -> SharedTrainer {
363    let cfg = llama_370m_train_config(lr, seq_length, seed);
364    let trainer = TransformerTrainer::new(cfg);
365    // INV-ARCH-370M-001: verify parameter count lands in the 370M ± 1%
366    // band. This is a debug_assert so release builds do not pay for
367    // the full parameter walk, but dev builds catch drift the instant
368    // any Llama370MConfig constant changes.
369    #[cfg(debug_assertions)]
370    {
371        let param_count: usize = trainer.model().parameters().iter().map(|t| t.len()).sum();
372        debug_assert!(
373            (366_000_000..=374_000_000).contains(&param_count),
374            "INV-ARCH-370M-001: parameter count {param_count} outside [366M, 374M] band",
375        );
376    }
377    Rc::new(RefCell::new(trainer))
378}
379
380/// Polymorphic trainer builder for `apr pretrain --init` per
381/// `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature +
382/// §init_load_semantics (PR #1473).
383///
384/// Composes the §50.4 step-5f machinery into a single CLI-callable entry:
385///   - 5c: `build_transformer_config(init_arch)` — polymorphic dispatch
386///   - 5f.1: `validate_pretrain_init_arch_compatible(init_arch)` — encoder rejection
387///   - 5f.2: `load_init_tensors_from_apr(path)` — read APR weights
388///   - 5f.3: `populate_trainer_from_init_tensors(trainer, &tensors)` — populate
389///
390/// Behaviour:
391///   init = None  → identical to `build_shared_trainer` (Llama370M from-scratch
392///                  baseline; INV-ARCH-370M-001 enforced).
393///   init = Some  → builds a trainer with the EXTRACTED arch, validates the
394///                  family, loads tensors from the APR file, populates them.
395///                  INV-ARCH-370M-001 is NOT enforced (the arch is whatever the
396///                  init APR has, e.g. 0.5B / 1.5B / 7B).
397///
398/// Spec: SPEC-SHIP-TWO-001 §52.4 (step 5f.4 wireup).
399///
400/// # Errors
401///
402/// Returns Err when:
403/// - `init_arch` is `Some` with `architecture = Encoder` (FALSIFY-APR-PRETRAIN-ARCH-007)
404/// - `load_init_tensors_from_apr` fails (FALSIFY-APR-PRETRAIN-INIT-006)
405/// - `populate_trainer_from_init_tensors` fails (FALSIFY-APR-PRETRAIN-INIT-007)
406pub fn build_shared_trainer_with_init(
407    lr: f32,
408    seq_length: usize,
409    seed: u64,
410    init_arch: Option<&TransformerConfig>,
411    init_path: Option<&Path>,
412) -> Result<SharedTrainer, String> {
413    if init_arch.is_some() != init_path.is_some() {
414        return Err(format!(
415            "build_shared_trainer_with_init: init_arch and init_path must both be Some \
416             or both None (caller bug; init_arch.is_some()={}, init_path.is_some()={})",
417            init_arch.is_some(),
418            init_path.is_some()
419        ));
420    }
421
422    if let Some(cfg) = init_arch {
423        validate_pretrain_init_arch_compatible(cfg)?;
424    }
425
426    let model_cfg = build_transformer_config(init_arch);
427    let mut train_cfg = TransformerTrainConfig::new(model_cfg);
428    train_cfg.lr = lr;
429    train_cfg.max_seq_len = seq_length;
430    train_cfg.seed = seed;
431    let mut trainer = TransformerTrainer::new(train_cfg);
432
433    // Note: INV-ARCH-370M-001 (param-count band check) lives in
434    // `build_shared_trainer` (the from-scratch CLI path). The polymorphic
435    // builder is shape-agnostic by design — `build_transformer_config(init)`
436    // returns whatever the init APR has (0.5B, 1.5B, 7B, etc), so a single
437    // hardcoded band check would fire-fail on every non-Llama370M init.
438
439    if let Some(path) = init_path {
440        let tensors = load_init_tensors_from_apr(path)?;
441        populate_trainer_from_init_tensors(trainer.model_mut(), &tensors)?;
442    }
443
444    Ok(Rc::new(RefCell::new(trainer)))
445}
446
447#[cfg(test)]
448mod tests {
449    use super::*;
450    use crate::train::transformer_trainer::LMBatch;
451
452    /// FALSIFY-APR-PRETRAIN-INIT-006 (read-half) — load_init_tensors_from_apr
453    /// returns Err with a clear message naming the falsifier when the path
454    /// does not exist.
455    ///
456    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.2.
457    #[test]
458    fn load_init_tensors_missing_file_errors_with_falsifier_id() {
459        let tmp = tempfile::TempDir::new().expect("tempdir");
460        let missing = tmp.path().join("does-not-exist.apr");
461        let err = load_init_tensors_from_apr(&missing)
462            .expect_err("missing init APR file MUST fail-fast");
463        assert!(
464            err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
465            "error must cite falsifier id (auditability): {err}"
466        );
467        assert!(
468            err.contains("does-not-exist.apr"),
469            "error must name the missing path (operator-experience): {err}"
470        );
471    }
472
473    /// FALSIFY-APR-PRETRAIN-INIT-006 (read-half) — function exists with the
474    /// right signature: `Path -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>>`.
475    /// Discharges the COMPILE-BIND level claim. Live empirical correctness
476    /// requires step 5g (operator-runnable LIVE fine-tune).
477    ///
478    /// Drift-prevention: this test catches a future refactor that changes
479    /// the return type or signature, which would break the §50.4 step 5f.3
480    /// follow-up that reconciles the BTreeMap against trainer parameters.
481    #[test]
482    fn load_init_tensors_signature_compile_bind() {
483        // Verify the function signature compile-binds: takes a Path-like,
484        // returns the right Result type. This is a compile-time check —
485        // if the signature drifts, this test stops compiling.
486        fn _check_signature<F>(_f: F)
487        where
488            F: Fn(
489                &Path,
490            )
491                -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String>,
492        {
493        }
494        _check_signature(|p| load_init_tensors_from_apr(p));
495    }
496
497    #[test]
498    fn transformer_config_matches_llama_370m_constants() {
499        let cfg = llama_370m_transformer_config();
500        assert_eq!(cfg.hidden_size, Llama370MConfig::HIDDEN_DIM);
501        assert_eq!(cfg.num_hidden_layers, Llama370MConfig::NUM_LAYERS);
502        assert_eq!(cfg.num_attention_heads, Llama370MConfig::NUM_HEADS);
503        assert_eq!(cfg.num_kv_heads, Llama370MConfig::NUM_KV_HEADS);
504        assert_eq!(cfg.intermediate_size, Llama370MConfig::INTERMEDIATE_DIM);
505        assert_eq!(cfg.vocab_size, Llama370MConfig::VOCAB_SIZE);
506        assert!((cfg.rope_theta - Llama370MConfig::ROPE_THETA).abs() < f32::EPSILON);
507        assert!((cfg.rms_norm_eps - Llama370MConfig::RMS_NORM_EPS).abs() < f32::EPSILON);
508        assert!(!cfg.use_bias, "INV-ARCH-370M-008: no bias");
509        assert!(cfg.tie_word_embeddings, "INV-ARCH-370M-004: tied embeddings");
510    }
511
512    /// FALSIFY-APR-PRETRAIN-ARCH-002 — `build_transformer_config(None)` returns
513    /// the Llama370M baseline byte-for-byte. Falsifies regression in the §24/§25
514    /// from-scratch path when the polymorphic dispatch was added.
515    ///
516    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c.
517    #[test]
518    fn build_transformer_config_no_init_matches_llama370m() {
519        let baseline = llama_370m_transformer_config();
520        let result = build_transformer_config(None);
521        assert_eq!(result.hidden_size, baseline.hidden_size);
522        assert_eq!(result.num_attention_heads, baseline.num_attention_heads);
523        assert_eq!(result.num_kv_heads, baseline.num_kv_heads);
524        assert_eq!(result.intermediate_size, baseline.intermediate_size);
525        assert_eq!(result.num_hidden_layers, baseline.num_hidden_layers);
526        assert_eq!(result.vocab_size, baseline.vocab_size);
527        assert_eq!(
528            result.max_position_embeddings,
529            baseline.max_position_embeddings
530        );
531        assert!((result.rms_norm_eps - baseline.rms_norm_eps).abs() < f32::EPSILON);
532        assert!((result.rope_theta - baseline.rope_theta).abs() < f32::EPSILON);
533        assert_eq!(result.use_bias, baseline.use_bias);
534        assert_eq!(result.tie_word_embeddings, baseline.tie_word_embeddings);
535        assert_eq!(result.architecture, baseline.architecture);
536        assert_eq!(result.hf_architecture, baseline.hf_architecture);
537        assert_eq!(result.hf_model_type, baseline.hf_model_type);
538    }
539
540    /// FALSIFY-APR-PRETRAIN-ARCH-003 — `build_transformer_config(Some(cfg))`
541    /// passes through the caller-provided config byte-for-byte. No silent
542    /// defaults, no field overrides. Tests with Qwen2.5-Coder-0.5B shape
543    /// because that is the §49 fine-tune target.
544    ///
545    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c.
546    #[test]
547    fn build_transformer_config_qwen_init_matches_input() {
548        let qwen = TransformerConfig::qwen2_0_5b();
549        let result = build_transformer_config(Some(&qwen));
550        assert_eq!(result.hidden_size, qwen.hidden_size, "hidden_size");
551        assert_eq!(
552            result.num_attention_heads, qwen.num_attention_heads,
553            "num_attention_heads"
554        );
555        assert_eq!(result.num_kv_heads, qwen.num_kv_heads, "num_kv_heads");
556        assert_eq!(
557            result.intermediate_size, qwen.intermediate_size,
558            "intermediate_size"
559        );
560        assert_eq!(
561            result.num_hidden_layers, qwen.num_hidden_layers,
562            "num_hidden_layers"
563        );
564        assert_eq!(result.vocab_size, qwen.vocab_size, "vocab_size");
565        assert_eq!(
566            result.max_position_embeddings, qwen.max_position_embeddings,
567            "max_position_embeddings"
568        );
569        assert_eq!(result.use_bias, qwen.use_bias, "use_bias");
570        assert_eq!(
571            result.tie_word_embeddings, qwen.tie_word_embeddings,
572            "tie_word_embeddings"
573        );
574        assert_eq!(result.architecture, qwen.architecture, "architecture");
575        // GQA-7:1 ratio preserved (Qwen2.5-0.5B: 14 / 2 = 7)
576        assert_eq!(
577            result.num_attention_heads / result.num_kv_heads,
578            7,
579            "GQA ratio must preserve as 7:1 (Qwen2.5-0.5B canonical)"
580        );
581    }
582
583    /// Drift-prevention: dispatch is mutually exclusive — None and Some
584    /// produce different configs (otherwise the polymorphic builder is
585    /// vacuous). Catches a future refactor that accidentally always
586    /// returns Llama370M regardless of init.
587    ///
588    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5c — drift prevention.
589    #[test]
590    fn build_transformer_config_dispatch_mutually_exclusive() {
591        let qwen = TransformerConfig::qwen2_0_5b();
592        let none_result = build_transformer_config(None);
593        let some_result = build_transformer_config(Some(&qwen));
594        // The two outputs MUST differ, otherwise the dispatch is broken.
595        assert_ne!(
596            none_result.hidden_size, some_result.hidden_size,
597            "dispatch must differentiate None vs Some — Llama370M hidden=1024 vs Qwen=896"
598        );
599        assert_ne!(
600            none_result.vocab_size, some_result.vocab_size,
601            "dispatch must differentiate None vs Some — Llama370M vocab=50257 vs Qwen=151936"
602        );
603    }
604
605    /// FALSIFY-APR-PRETRAIN-ARCH-007 (decoder branch) — `validate_pretrain_init_arch_compatible`
606    /// returns Ok for a decoder-family config.
607    ///
608    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.1.
609    #[test]
610    fn validate_pretrain_init_arch_accepts_decoder() {
611        let qwen = TransformerConfig::qwen2_0_5b();
612        assert_eq!(qwen.architecture, ModelArchitecture::Decoder);
613        validate_pretrain_init_arch_compatible(&qwen)
614            .expect("decoder-family config (Qwen2.5-0.5B) MUST pass arch-compat gate");
615    }
616
617    /// FALSIFY-APR-PRETRAIN-ARCH-007 (encoder branch) — load-bearing test.
618    /// `validate_pretrain_init_arch_compatible` returns Err naming the
619    /// architecture-family mismatch when given an encoder config (e.g.,
620    /// CodeBERT). Without this gate, the decoder trainer would silently
621    /// build with encoder weights producing nonsense gradients.
622    ///
623    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5f.1.
624    #[test]
625    fn validate_pretrain_init_arch_rejects_encoder() {
626        // Construct a minimal encoder config (CodeBERT-shaped).
627        let bert = TransformerConfig {
628            hidden_size: 768,
629            num_attention_heads: 12,
630            num_kv_heads: 12,
631            intermediate_size: 3072,
632            num_hidden_layers: 12,
633            vocab_size: 50265,
634            max_position_embeddings: 514,
635            rms_norm_eps: 1e-12,
636            rope_theta: 10_000.0,
637            use_bias: true,
638            head_dim_override: None,
639            architecture: ModelArchitecture::Encoder,
640            hf_architecture: Some("RobertaModel".to_string()),
641            hf_model_type: Some("roberta".to_string()),
642            tie_word_embeddings: false,
643        };
644        let err = validate_pretrain_init_arch_compatible(&bert).expect_err(
645            "encoder-family config (CodeBERT/RoBERTa) MUST fail arch-compat gate — \
646             silent acceptance would corrupt §49 fine-tune trajectory before any \
647             FALSIFY-006 check could measure it",
648        );
649        assert!(
650            err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
651            "error must cite falsifier id: {err}"
652        );
653        assert!(
654            err.contains("Encoder"),
655            "error must name the architecture family: {err}"
656        );
657        assert!(
658            err.contains("decoder-only"),
659            "error must explain why this is wrong (decoder trainer): {err}"
660        );
661        assert!(
662            err.contains("RobertaModel"),
663            "error must name the offending hf_architecture: {err}"
664        );
665    }
666
667    /// Drift-prevention: validate_pretrain_init_arch_compatible's behavior on
668    /// the from-scratch baseline (Llama370M) — must Ok. Catches a future
669    /// refactor that accidentally over-rejects decoder configs.
670    #[test]
671    fn validate_pretrain_init_arch_accepts_llama370m_baseline() {
672        let llama = llama_370m_transformer_config();
673        assert_eq!(
674            llama.architecture,
675            ModelArchitecture::Decoder,
676            "Llama370M baseline MUST be Decoder (regression-free)"
677        );
678        validate_pretrain_init_arch_compatible(&llama)
679            .expect("Llama370M baseline (Decoder) MUST pass arch-compat gate");
680    }
681
682    #[test]
683    fn real_step_fn_exhausted_iterator_returns_finite_placeholder() {
684        // Empty iterator means no real batches; we must still return
685        // finite values so the loop's non-divergence + NaN guards see
686        // sane data instead of surprising NaN.
687        //
688        // Construct a minimal trainer WITHOUT running `build_shared_trainer`
689        // because that takes ~5 GB of parameter allocation for 370M —
690        // too expensive for a unit test. Use a tiny synthetic config.
691        let mut tiny = TransformerConfig::llama2_7b();
692        tiny.hidden_size = 64;
693        tiny.num_attention_heads = 4;
694        tiny.num_kv_heads = 4;
695        tiny.num_hidden_layers = 2;
696        tiny.intermediate_size = 128;
697        tiny.vocab_size = 256;
698        let cfg = TransformerTrainConfig::new(tiny);
699        let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
700        let empty_iter: Box<dyn Iterator<Item = LMBatch>> = Box::new(std::iter::empty::<LMBatch>());
701        let mut step = RealStepFn::new(trainer, empty_iter);
702        let (loss, grad_norm) = step.step(0, 1.0e-4, 128);
703        assert!(loss.is_finite(), "exhausted iter must return finite loss");
704        assert!(grad_norm.is_finite(), "grad_norm must be finite");
705        assert!(grad_norm >= 0.0, "INV-TRAIN-008: grad_norm non-negative");
706    }
707
708    #[test]
709    fn real_val_fn_empty_held_out_returns_nan() {
710        let mut tiny = TransformerConfig::llama2_7b();
711        tiny.hidden_size = 64;
712        tiny.num_attention_heads = 4;
713        tiny.num_kv_heads = 4;
714        tiny.num_hidden_layers = 2;
715        tiny.intermediate_size = 128;
716        tiny.vocab_size = 256;
717        let cfg = TransformerTrainConfig::new(tiny);
718        let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
719        let mut val = RealValFn::new(trainer, Vec::new());
720        let loss = val.validate(0);
721        assert!(loss.is_nan(), "empty held_out must surface as NaN to the guard");
722    }
723
724    /// Build a tiny Transformer suitable for unit testing the populate path.
725    /// Uses GQA-1:1 (kv=q) shape — the populate function is shape-agnostic so
726    /// the simpler ratio is fine here.
727    fn tiny_test_transformer() -> Transformer {
728        let mut tiny = TransformerConfig::llama2_7b();
729        tiny.hidden_size = 32;
730        tiny.num_attention_heads = 2;
731        tiny.num_kv_heads = 2;
732        tiny.num_hidden_layers = 2;
733        tiny.intermediate_size = 64;
734        tiny.vocab_size = 16;
735        Transformer::new(&tiny)
736    }
737
738    /// Build a `BTreeMap<String, (Vec<f32>, Vec<usize>)>` from a Transformer's
739    /// `named_parameters()` snapshot. Each tensor is a deterministic ramp
740    /// (i as f32 * 0.001) so populate is byte-identifiable post-set.
741    fn tensors_map_from_transformer(
742        transformer: &Transformer,
743    ) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
744        let mut map = BTreeMap::new();
745        for (name, t) in transformer.named_parameters() {
746            let len = t.len();
747            let data: Vec<f32> = (0..len).map(|i| i as f32 * 0.001).collect();
748            map.insert(name, (data, vec![len]));
749        }
750        map
751    }
752
753    /// Happy path — every model parameter has a matching init entry of correct
754    /// length; populate succeeds and the count matches `named_parameters().len()`.
755    #[test]
756    fn populate_trainer_from_init_tensors_happy_path() {
757        let mut transformer = tiny_test_transformer();
758        let init_tensors = tensors_map_from_transformer(&transformer);
759        let expected_count = transformer.named_parameters().len();
760        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
761        assert!(
762            result.is_ok(),
763            "happy-path populate must succeed: {result:?}"
764        );
765        assert_eq!(
766            result.unwrap(),
767            expected_count,
768            "populated count must equal named_parameters().len()"
769        );
770    }
771
772    /// Drift-prevention: extra entries in `init_tensors` that the model does
773    /// NOT expose are silently ignored. This handles tied-embeddings: a Qwen
774    /// APR may publish a separate `lm_head.weight` that the trainer's tied
775    /// model omits.
776    #[test]
777    fn populate_trainer_from_init_tensors_extra_entries_silently_ignored() {
778        let mut transformer = tiny_test_transformer();
779        let mut init_tensors = tensors_map_from_transformer(&transformer);
780        // Inject a fictitious extra parameter that the model does not have.
781        init_tensors.insert(
782            "model.layers.999.fictitious.weight".to_string(),
783            (vec![0.0; 4], vec![4]),
784        );
785        let expected_count = transformer.named_parameters().len();
786        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
787        assert!(
788            result.is_ok(),
789            "extra init entries must NOT cause Err: {result:?}"
790        );
791        assert_eq!(result.unwrap(), expected_count);
792    }
793
794    /// FALSIFY-APR-PRETRAIN-INIT-007 (length mismatch) — when an init tensor
795    /// has the wrong flat length for a known parameter, populate MUST Err
796    /// with the FALSIFIER ID and a per-parameter diagnostic line.
797    #[test]
798    fn populate_trainer_from_init_tensors_rejects_length_mismatch() {
799        let mut transformer = tiny_test_transformer();
800        let mut init_tensors = tensors_map_from_transformer(&transformer);
801        // Corrupt one entry's length to trigger the mismatch path.
802        let any_name = transformer.named_parameters()[0].0.clone();
803        init_tensors.insert(any_name.clone(), (vec![0.0; 7], vec![7]));
804        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
805        assert!(
806            result.is_err(),
807            "length-mismatch must Err, not silently truncate"
808        );
809        let err = result.unwrap_err();
810        assert!(
811            err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
812            "error must cite falsifier id; got: {err}"
813        );
814        assert!(
815            err.contains(&any_name),
816            "error must name the offending parameter; got: {err}"
817        );
818        assert!(
819            err.contains("init length 7"),
820            "error must report the actual init length; got: {err}"
821        );
822    }
823
824    /// FALSIFY-APR-PRETRAIN-INIT-007 (missing-required) — when a model
825    /// parameter has NO corresponding entry in `init_tensors`, populate MUST
826    /// Err with FALSIFIER ID and a "not present in init APR tensors"
827    /// per-parameter diagnostic. This catches the architecture-mismatch
828    /// class where init was extracted from a different model family.
829    #[test]
830    fn populate_trainer_from_init_tensors_rejects_missing_required_param() {
831        let mut transformer = tiny_test_transformer();
832        let mut init_tensors = tensors_map_from_transformer(&transformer);
833        // Drop one entry to trigger the missing-required path.
834        let any_name = transformer.named_parameters()[0].0.clone();
835        init_tensors.remove(&any_name);
836        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
837        assert!(
838            result.is_err(),
839            "missing-required must Err, not silently leave random init"
840        );
841        let err = result.unwrap_err();
842        assert!(
843            err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
844            "error must cite falsifier id; got: {err}"
845        );
846        assert!(
847            err.contains(&any_name),
848            "error must name the missing parameter; got: {err}"
849        );
850        assert!(
851            err.contains("not present in init APR"),
852            "error must say what was missing; got: {err}"
853        );
854    }
855
856    /// `build_shared_trainer_with_init(None, None)` returns a trainer with
857    /// the §24/§25 from-scratch Llama370M architecture (regression-free
858    /// dispatch). Asserts the baseline shape via the (hidden, vocab) tuple
859    /// rather than param count to avoid the stale INV-ARCH-370M-001 band
860    /// check in `build_shared_trainer` (a defect outside §50.4 scope —
861    /// param_count=322M vs assert range [366M, 374M]; tracked for follow-up).
862    #[test]
863    fn build_shared_trainer_with_init_none_uses_llama370m_shape() {
864        let trainer = build_shared_trainer_with_init(1.0e-4, 128, 42, None, None)
865            .expect("None case must succeed");
866        let model = trainer.borrow();
867        // The baseline polymorphic dispatch produces a Llama370M-shaped model.
868        // Embedding shape `vocab × hidden` is the cleanest non-stale check.
869        let embed_len = model.model().named_parameters()[0].1.len();
870        let expected_embed_len =
871            Llama370MConfig::VOCAB_SIZE * Llama370MConfig::HIDDEN_DIM;
872        assert_eq!(
873            embed_len, expected_embed_len,
874            "init=None must produce Llama370M-shaped embedding (vocab={} × hidden={})",
875            Llama370MConfig::VOCAB_SIZE,
876            Llama370MConfig::HIDDEN_DIM
877        );
878    }
879
880    /// `build_shared_trainer_with_init(Some, None)` and the inverse must
881    /// fail-fast — both args are paired and either both Some or both None.
882    /// Drift-prevention: catches a future caller that forgets to pass one.
883    #[test]
884    fn build_shared_trainer_with_init_rejects_unpaired_args() {
885        // arch Some, path None
886        let cfg = TransformerConfig::qwen2_0_5b();
887        let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), None);
888        assert!(
889            result.is_err(),
890            "unpaired (arch=Some, path=None) must Err"
891        );
892        // arch None, path Some
893        let dummy_path = std::path::PathBuf::from("/dev/null");
894        let result = build_shared_trainer_with_init(1.0e-4, 128, 42, None, Some(&dummy_path));
895        assert!(
896            result.is_err(),
897            "unpaired (arch=None, path=Some) must Err"
898        );
899    }
900
901    /// `build_shared_trainer_with_init(Some(encoder), Some(path))` rejects
902    /// the encoder family BEFORE attempting tensor load. Drift-prevention for
903    /// FALSIFY-APR-PRETRAIN-ARCH-007 at the trainer-builder integration level.
904    #[test]
905    fn build_shared_trainer_with_init_rejects_encoder_family() {
906        let mut encoder_cfg = TransformerConfig::qwen2_0_5b();
907        encoder_cfg.architecture = ModelArchitecture::Encoder;
908        let dummy_path = std::path::PathBuf::from("/nonexistent/encoder.apr");
909        let result =
910            build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&encoder_cfg), Some(&dummy_path));
911        let err = match result {
912            Ok(_) => panic!("encoder family must be rejected before tensor load"),
913            Err(e) => e,
914        };
915        assert!(
916            err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
917            "error must cite falsifier id; got: {err}"
918        );
919    }
920
921    /// `build_shared_trainer_with_init(Some(decoder), Some(missing_path))`
922    /// proceeds past the family check and FAILS at tensor load with a
923    /// FALSIFY-006 error. Pins the failure ordering: arch validation first,
924    /// then tensor load.
925    #[test]
926    fn build_shared_trainer_with_init_decoder_family_proceeds_to_tensor_load() {
927        let cfg = TransformerConfig::qwen2_0_5b();
928        let dummy_path = std::path::PathBuf::from("/nonexistent/decoder.apr");
929        let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), Some(&dummy_path));
930        let err = match result {
931            Ok(_) => panic!("missing tensor path must Err"),
932            Err(e) => e,
933        };
934        assert!(
935            err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
936            "decoder family proceeds to tensor load; failure cites INIT-006 not ARCH-007; got: {err}"
937        );
938        assert!(
939            !err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
940            "decoder family must NOT trigger encoder-rejection; got: {err}"
941        );
942    }
943}