Skip to main content

apr_cli/commands/
pretrain.rs

1//! `apr pretrain` — pretraining loop driver for SHIP-TWO-001 MODEL-2.
2//!
3//! Wires `entrenar::train::pretrain::PretrainLoop` into the CLI. The
4//! loop shape is enforced by `contracts/training-loop-pretrain-v1.yaml`
5//! — specifically GATE-TRAIN-005 (divergence), GATE-TRAIN-007 (NaN),
6//! and GATE-TRAIN-008 (throughput range).
7//!
8//! For MODEL-2 specifically, the 370M model forward pass is still a
9//! scaffold (see `crates/aprender-train/src/models/llama_370m.rs`),
10//! so this command runs in **synthetic** mode by default: it drives
11//! the loop with a deterministic decreasing-loss step function so the
12//! contract gates are exercised end-to-end even before the 370M
13//! compute path is wired.
14
15use crate::error::{CliError, Result};
16use crate::output;
17use clap::ValueEnum;
18use colored::Colorize;
19use entrenar::models::llama_370m::{
20    assert_tokenizer_vocab_matches_model, assert_tokenizer_vocab_within_model_bound,
21    Llama370MConfig,
22};
23use entrenar::train::device::{resolve_device, Device};
24use entrenar::train::pretrain::{
25    CheckpointFn, LinearDecaySynthetic, PretrainAbort, PretrainConfig, PretrainLoop, RunStatus,
26    ScriptedVal, StepFn, TrainingRegime, ValFn,
27};
28use entrenar::train::pretrain_real::{
29    build_shared_trainer, build_shared_trainer_with_init, AprCheckpointFn, RealStepFn, RealValFn,
30};
31use entrenar::train::shard_reader::ShardBatchIter;
32use entrenar::train::transformer_trainer::LMBatch;
33use entrenar::transformer::TransformerConfig;
34use std::path::Path;
35
/// Number of `LMBatch`es reserved as the held-out validation set.
///
/// Without `--val-shard`, these are pulled off the head of the training shard
/// stream. With `--val-shard <DIR>`, up to this many are drawn from a separate
/// validation iterator instead, and the training stream is left untouched.
///
/// 2026-04-26: bumped from 2 → 16 to reduce val_loss measurement noise on
/// from-scratch runs. With batch=16 seq=512, the prior 2-batch held-out set
/// covered just 16,384 tokens. Single-batch fluctuation was ~0.04 in val_loss,
/// the same scale as the epoch-over-epoch improvement signal during early
/// training. A 50K-step run early-stopped at epoch 5/24 even though train_loss
/// was monotonically decreasing (10.01 → 9.54).
///
/// With 16 held-out batches (131,072 tokens, 8× more), the noise of the mean
/// val_loss shrinks by roughly √8 ≈ 2.8×, not 8×. This is standard-error
/// scaling and assumes approximately independent batches. The result is
/// ~0.014, which restores early-stop signal-to-noise.
const HELD_OUT_BATCHES: usize = 16;
49
/// Drift-prevention sentinel for `apr-pretrain-arch-polymorphic-v1`
/// v1.7.0 §FALSIFY-APR-PRETRAIN-INIT-CUDA-001.
///
/// History: in v1.4.0..v1.6.0 (before §50.4 step 5f.5) this text was the
/// fail-fast error returned when `--init <PATH>` was combined with
/// `--device cuda`, because the CUDA wireup did not exist yet. A unit test
/// pinned the citation, the "not yet wired" phrase, and the 5f.5 reference.
///
/// Since v1.7.0 the CUDA path goes through
/// `entrenar::train::pretrain_real_cuda::build_shared_cuda_trainer_with_init`,
/// the CUDA counterpart of the CPU `build_shared_trainer_with_init`. No code
/// path in `drive_real` emits this message any more. It is kept only as a
/// contract anchor: the test that pins this text fails if a refactor
/// re-introduces a CUDA + `--init` fail-fast.
pub(crate) const FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG: &str = concat!(
    "FALSIFY-APR-PRETRAIN-INIT-CUDA-001: --init is wired for --device cuda ",
    "via build_shared_cuda_trainer_with_init (5f.5 SHIPPED); operator can pass ",
    "--init <PATH> --device cuda for end-to-end GPU fine-tune dispatch."
);
72
73/// CLI selector bound to training-loop-pretrain-v1 §hyperparameter_defaults.
74/// Atomically flips the `(regime, lr_max, warmup_steps, target_val_loss)`
75/// 4-tuple per INV-TRAIN-009. Explicit `--lr` / `--warmup-steps` /
76/// `--target-val-loss` still win over the table row.
77#[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
78pub enum PretrainMode {
79    /// Post-divergence MODEL-1 remedy defaults (lr=5e-5, warmup=100, target=2.2).
80    Finetune,
81    /// 370M cold-start defaults (lr=3e-4, warmup=1000, target=3.0).
82    FromScratch,
83}
84
85/// Resolved HP tuple from the contract's `hyperparameter_defaults` table.
86/// Inputs are CLI-provided overrides (`None` means "inherit mode default").
87/// Output binds INV-TRAIN-009: regime ALWAYS matches `mode`, and any field
88/// the operator set explicitly passes through unchanged.
89#[derive(Clone, Debug, PartialEq)]
90pub(crate) struct ResolvedHp {
91    pub regime: TrainingRegime,
92    pub lr_max: f32,
93    pub warmup_steps: usize,
94    pub target_val_loss: f32,
95}
96
97/// SPEC §82 P0-H: derive APR checkpoint `general.name` and `architecture`
98/// metadata from the `--init` model's TransformerConfig. Without this, the
99/// trainer hardcoded `("llama-370m-pretrain", "LlamaForCausalLM")` even when
100/// fine-tuning a Qwen2 model — which silently produced GGUF exports that
101/// llama.cpp could not load because the 72 Qwen2 bias tensors (q_proj_bias,
102/// k_proj_bias, v_proj_bias per layer × 24 layers) leaked through the
103/// llama-family GGUF mapper as unrecognized passthrough names. The fix
104/// stamps `Qwen2ForCausalLM` so the qwen2 family mapper handles biases
105/// correctly.
106///
107/// Falls back to the pre-§82 defaults when `--init` is not provided (a
108/// from-scratch llama-370m pretrain).
109fn checkpoint_name_and_arch(init_arch: Option<&TransformerConfig>) -> (String, String) {
110    match init_arch {
111        Some(arch) => {
112            let hf_arch = arch
113                .hf_architecture
114                .clone()
115                .unwrap_or_else(|| "LlamaForCausalLM".to_string());
116            // Use the lowercase hf_model_type for the name suffix when
117            // available (e.g. "qwen2-pretrain"), else fall back to a
118            // generic name.
119            let name = arch
120                .hf_model_type
121                .as_deref()
122                .map_or_else(|| "model-pretrain".to_string(), |t| format!("{t}-pretrain"));
123            (name, hf_arch)
124        }
125        None => (
126            "llama-370m-pretrain".to_string(),
127            "LlamaForCausalLM".to_string(),
128        ),
129    }
130}
131
132/// SPEC §82 P1-A: Estimate transformer parameter count from arch dims.
133///
134/// Formula (decoder-only, tied or untied embedding):
135///   N ≈ vocab × hidden                              (embedding)
136///     + L × (4·hidden² + 3·hidden·intermediate)     (per-layer attn + ffn)
137///     + hidden                                       (final norm)
138///
139/// Embedding is counted once (assumes tied lm_head; for untied add a 2nd
140/// `vocab × hidden`). This is a coarse estimate suitable for Chinchilla
141/// scaling sanity checks, not a precise param report — for that, use
142/// `apr inspect --json | jq .parameters`.
143fn estimate_param_count(arch: &TransformerConfig) -> u64 {
144    let vocab = arch.vocab_size as u64;
145    let hidden = arch.hidden_size as u64;
146    let inter = arch.intermediate_size as u64;
147    let layers = arch.num_hidden_layers as u64;
148    let embed = vocab.saturating_mul(hidden);
149    let attn_per_layer = 4u64.saturating_mul(hidden).saturating_mul(hidden);
150    let ffn_per_layer = 3u64.saturating_mul(hidden).saturating_mul(inter);
151    let per_layer = attn_per_layer.saturating_add(ffn_per_layer);
152    let layer_total = layers.saturating_mul(per_layer);
153    embed.saturating_add(layer_total).saturating_add(hidden)
154}
155
156pub(crate) fn mode_defaults(
157    mode: PretrainMode,
158    vocab_size: u32,
159    lr_override: Option<f32>,
160    warmup_override: Option<usize>,
161    target_override: Option<f32>,
162) -> ResolvedHp {
163    let (regime, lr_def, warmup_def, target_def) = match mode {
164        PretrainMode::Finetune => (TrainingRegime::Finetune, 5.0e-5, 100, 2.2),
165        PretrainMode::FromScratch => (
166            TrainingRegime::FromScratch { vocab_size },
167            3.0e-4,
168            1000,
169            3.0,
170        ),
171    };
172    ResolvedHp {
173        regime,
174        lr_max: lr_override.unwrap_or(lr_def),
175        warmup_steps: warmup_override.unwrap_or(warmup_def),
176        target_val_loss: target_override.unwrap_or(target_def),
177    }
178}
179
180/// Execute `apr pretrain`.
181#[allow(clippy::too_many_arguments)]
182pub(crate) fn run(
183    dataset: &Path,
184    tokenizer: &Path,
185    run_dir: &Path,
186    mode: PretrainMode,
187    lr: Option<f32>,
188    num_steps: usize,
189    warmup_steps: Option<usize>,
190    batch_size: usize,
191    seq_length: usize,
192    steps_per_epoch: usize,
193    seed: u64,
194    target_val_loss: Option<f32>,
195    vocab_size: u32,
196    synthetic: bool,
197    device: &str,
198    init: Option<&Path>,
199    force_under_provisioned: bool,
200    val_shard: Option<&Path>,
201    json_output: bool,
202) -> Result<()> {
203    // Contract gpu-training-backend-v1 INV-GPUTRAIN-001 / GATE-GPUTRAIN-002:
204    // parse --device BEFORE any trainer allocation so an invalid spec
205    // or an explicit `cuda` on a CPU-only host fails fast with a clear
206    // diagnostic. Synthetic drive still honours --device (for parity
207    // with real compute) but the stub error surface is identical.
208    let resolved_device =
209        resolve_device(device).map_err(|e| CliError::ValidationFailed(e.to_string()))?;
210
211    // Contract apr-pretrain-from-init-v1 §init_load_semantics + §50.4 step 5f.4:
212    // when --init is present, (1) validate magic bytes, (2) extract
213    // TransformerConfig from the APR header metadata, (3) propagate the
214    // extracted arch through preflight + trainer construction.
215    // Per `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature,
216    // missing or unreadable architecture metadata is FAIL-FAST not silent-fallback.
217    let init_arch: Option<TransformerConfig> = if let Some(init_path) = init {
218        validate_init_apr_path(init_path)?;
219        Some(
220            crate::commands::model_config::read_apr_architecture(init_path).ok_or_else(|| {
221                CliError::ValidationFailed(format!(
222                    "FALSIFY-APR-PRETRAIN-INIT-005: --init APR file at {} has missing or invalid \
223                     architecture metadata (hidden_size, num_heads, num_layers, vocab_size, etc). \
224                     Cannot extract TransformerConfig per apr-pretrain-arch-polymorphic-v1 \
225                     §arch_extraction_signature.",
226                    init_path.display()
227                ))
228            })?,
229        )
230    } else {
231        None
232    };
233
234    let hp = mode_defaults(mode, vocab_size, lr, warmup_steps, target_val_loss);
235
236    // SPEC §82 P1-A + SPEC §83 P0-J: Chinchilla compute-optimal gate
237    // (Hoffmann et al. 2022, arXiv:2203.15556). Compute-optimal pretraining
238    // requires train tokens D ≈ 20·N where N is the parameter count.
239    //
240    // P0-J upgrade (post-audit, 2026-05-16, audit Rec #2): D/N < 10× is
241    // now a HARD BLOCKER (fail-fast) unless `--force-under-provisioned`
242    // is passed. 10× ≤ D/N < 20× is a strong warning. Triggered only on
243    // `--init` runs where arch dims allow N estimation; from-scratch
244    // runs are exempt.
245    //
246    // Audit motivation: §82 P2-A's 0.04× ratio + repetitive token
247    // gibberish at val_loss=4.71 (Holtzman et al. 2019 degeneration)
248    // proved that 30 min of theoretical falsification saves 8h+ GPU.
249    // Contract: contracts/chinchilla-gate-v1.yaml.
250    if let Some(arch) = init_arch.as_ref() {
251        let n_params = estimate_param_count(arch);
252        let d_tokens = (num_steps as u64)
253            .saturating_mul(batch_size as u64)
254            .saturating_mul(seq_length as u64);
255        let ratio = d_tokens as f64 / n_params as f64;
256        let suggested_steps = if batch_size > 0 && seq_length > 0 {
257            (20 * n_params) / (batch_size as u64 * seq_length as u64)
258        } else {
259            0
260        };
261
262        if ratio < 10.0 && !force_under_provisioned {
263            return Err(CliError::ValidationFailed(format!(
264                "[P0-J] Chinchilla hard gate (chinchilla-gate-v1): \
265                 train tokens D = {} ({:.1}M) is {:.3}× param count N = {} ({:.1}M); \
266                 Chinchilla compute-optimal target is D ≈ 20·N (Hoffmann et al. 2022, arXiv:2203.15556). \
267                 Run REJECTED: D/N < 10× will produce mode collapse / repetitive degeneration \
268                 (Holtzman et al. 2019, arXiv:1904.09751). \
269                 Increase --num-steps to ~{} OR widen --dataset corpus OR reduce model size. \
270                 To bypass anyway (e.g. ablation studies, resumed runs), pass --force-under-provisioned.",
271                d_tokens,
272                d_tokens as f64 / 1e6,
273                ratio,
274                n_params,
275                n_params as f64 / 1e6,
276                suggested_steps,
277            )));
278        }
279
280        if ratio < 10.0 {
281            // Bypassed via --force-under-provisioned: emit a loud warning
282            // so the override is captured in the log.
283            eprintln!(
284                "[P0-J] Chinchilla gate BYPASSED via --force-under-provisioned: \
285                 D = {} ({:.1}M) is {:.3}× N = {} ({:.1}M). \
286                 Run will likely produce repetitive/degenerate output. \
287                 You explicitly opted in.",
288                d_tokens,
289                d_tokens as f64 / 1e6,
290                ratio,
291                n_params,
292                n_params as f64 / 1e6,
293            );
294        } else if ratio < 20.0 {
295            // 10× ≤ D/N < 20× — below compute-optimal but training will
296            // still progress meaningfully. Warning, not error.
297            eprintln!(
298                "[P1-A] Chinchilla gate WARNING: D = {} ({:.1}M) is {:.1}× N = {} ({:.1}M); \
299                 below compute-optimal 20·N target — model has room for more training. \
300                 Suggested --num-steps for 20·N: ~{}.",
301                d_tokens,
302                d_tokens as f64 / 1e6,
303                ratio,
304                n_params,
305                n_params as f64 / 1e6,
306                suggested_steps,
307            );
308        }
309    }
310
311    // Validation: GATE-TRAIN-003 requires target_val_loss > 0.
312    if hp.target_val_loss <= 0.0 {
313        return Err(CliError::ValidationFailed(format!(
314            "target_val_loss must be positive, got {}",
315            hp.target_val_loss
316        )));
317    }
318    if num_steps == 0 {
319        return Err(CliError::ValidationFailed(
320            "num_steps must be > 0".to_string(),
321        ));
322    }
323    if steps_per_epoch == 0 {
324        return Err(CliError::ValidationFailed(
325            "steps_per_epoch must be > 0".to_string(),
326        ));
327    }
328
329    let config = PretrainConfig {
330        dataset_path: dataset.to_path_buf(),
331        tokenizer_dir: tokenizer.to_path_buf(),
332        run_dir: run_dir.to_path_buf(),
333        lr_max: hp.lr_max,
334        lr_min: (hp.lr_max * 1.0e-2).max(1.0e-7),
335        warmup_steps: hp.warmup_steps,
336        total_steps: num_steps,
337        batch_size,
338        seq_length,
339        steps_per_epoch,
340        seed,
341        grad_clip: 1.0,
342        weight_decay: 0.01,
343        target_val_loss: hp.target_val_loss,
344        // Patience widened from 2 → 5 epochs for from-scratch runs (2026-04-26).
345        // Rationale: a 50K-step run early-stopped at epoch 5/24 even though
346        // train_loss was monotonically decreasing 10.01 → 9.54 (Δ=−0.47);
347        // val_loss noise on 16k-token val set (now 131k) had stdev ~0.04,
348        // same scale as epoch-over-epoch improvement signal during early
349        // training. 5 patience epochs gives the optimizer time to push past
350        // local plateaus without ending an obviously-still-converging run.
351        patience_epochs: 5,
352        // Minimum epochs before early-stop. Bumped 1 → 3 so the warmup
353        // window (1000 steps = 1 epoch at 1000 steps_per_epoch, or 0.5
354        // epoch at 2000 steps_per_epoch) plus 1-2 initial epochs of post-
355        // warmup learning are guaranteed to complete before any early-stop
356        // signal is honoured.
357        min_epochs_before_early_stop: 3,
358        regime: hp.regime,
359    };
360
361    if !json_output {
362        print_header(&config);
363        // GATE-GPUTRAIN-002 visibility: print the resolved Device so the
364        // operator can confirm which backend was selected. `auto` is the
365        // only spec that may silently fall back, and this print makes
366        // the fall-back visible at startup.
367        output::kv("  Device", resolved_device.to_string());
368        println!();
369    }
370
371    let status = if synthetic {
372        drive_synthetic(
373            config.clone(),
374            num_steps,
375            steps_per_epoch,
376            hp.target_val_loss,
377            json_output,
378        )?
379    } else {
380        drive_real(
381            config.clone(),
382            dataset,
383            hp.lr_max,
384            seq_length,
385            batch_size,
386            seed,
387            resolved_device,
388            json_output,
389            init_arch.as_ref(),
390            init,
391            val_shard,
392        )?
393    };
394
395    // Contract: non-OK terminal statuses map to non-zero exit codes so
396    // operators can recognize divergence / NaN from shell `$?`.
397    match status {
398        RunStatus::Aborted(abort) => Err(abort_to_err(&abort)),
399        RunStatus::Ok { .. } | RunStatus::EarlyStop { .. } => Ok(()),
400    }
401}
402
403/// Synthetic drive: deterministic linear-decay `StepFn` and a scripted
404/// val-loss sequence so the full gate surface (GATE-TRAIN-005/007/008)
405/// is exercised end-to-end with no corpus I/O.
406fn drive_synthetic(
407    config: PretrainConfig,
408    num_steps: usize,
409    steps_per_epoch: usize,
410    target_val_loss: f32,
411    json_output: bool,
412) -> Result<RunStatus> {
413    let step_fn = LinearDecaySynthetic {
414        start_loss: (target_val_loss * 2.0).max(1.5),
415        decay_per_step: (target_val_loss * 0.01).max(1.0e-4),
416        grad_norm: 0.8,
417    };
418    let num_epochs = num_steps.div_ceil(steps_per_epoch);
419    let mut sequence = Vec::with_capacity(num_epochs + 2);
420    let start_val = (target_val_loss * 1.8).max(3.0);
421    for i in 0..(num_epochs + 2) {
422        let t = i as f32 / (num_epochs.max(1) as f32);
423        sequence.push(target_val_loss + (start_val - target_val_loss) * (1.0 - t).max(0.0));
424    }
425    let val_fn = ScriptedVal { sequence };
426    // Synthetic drive has no real weights to checkpoint.
427    run_and_report(config, step_fn, val_fn, None, json_output)
428}
429
430/// Contract apr-pretrain-from-init-v1 §init_load_semantics + §init_error_semantics:
431/// validate `--init <PATH>` BEFORE any trainer allocation. Falsifies
432/// FALSIFY-APR-PRETRAIN-INIT-003 (missing-file) + -004 (invalid-magic).
433///
434/// Returns Ok on a valid APR file (existence + magic bytes verified).
435/// Architecture extraction + weight load are §50.4 step 5f.4 — the
436/// caller (`run()`) extracts the config via `model_config::read_apr_architecture`
437/// and passes both to `build_shared_trainer_with_init` per
438/// `apr-pretrain-arch-polymorphic-v1` §init_load_semantics.
439fn validate_init_apr_path(path: &Path) -> Result<()> {
440    let mut file = std::fs::File::open(path).map_err(|e| {
441        CliError::ValidationFailed(format!(
442            "FALSIFY-APR-PRETRAIN-INIT-003: --init path does not exist or is unreadable: {} ({e})",
443            path.display()
444        ))
445    })?;
446    let mut magic = [0u8; 4];
447    use std::io::Read;
448    file.read_exact(&mut magic).map_err(|e| {
449        CliError::ValidationFailed(format!(
450            "FALSIFY-APR-PRETRAIN-INIT-004: --init file too short to contain APR magic bytes: {} ({e})",
451            path.display()
452        ))
453    })?;
454    // APR magic bytes per `crates/aprender-core/src/format/kani_proofs.rs`:
455    //   APR\0 = [0x41, 0x50, 0x52, 0x00] (v2)
456    //   APRN  = [0x41, 0x50, 0x52, 0x4E] (v1)
457    const APR_MAGIC_V2: [u8; 4] = [0x41, 0x50, 0x52, 0x00];
458    const APR_MAGIC_V1: [u8; 4] = [0x41, 0x50, 0x52, 0x4E];
459    if magic != APR_MAGIC_V2 && magic != APR_MAGIC_V1 {
460        return Err(CliError::ValidationFailed(format!(
461            "FALSIFY-APR-PRETRAIN-INIT-004: --init file is not a valid APR file (magic={:02X?}, expected {:02X?} or {:02X?}): {}",
462            magic, APR_MAGIC_V2, APR_MAGIC_V1, path.display()
463        )));
464    }
465    Ok(())
466}
467
468/// GATE-ARCH-370M-011 pre-flight: count the tokenizer's vocabulary entries
469/// from `vocab.json` and assert the count matches `target_vocab_size`
470/// before any trainer allocation.
471///
472/// Per `apr-pretrain-arch-polymorphic-v1` §qwen_tokenizer_vocab_compatibility
473/// (PR #1473), the target is now POLYMORPHIC — when `--init <PATH>` is set,
474/// the caller passes the extracted-arch's vocab_size (e.g., 151_936 for
475/// Qwen2.5-0.5B); otherwise `Llama370MConfig::VOCAB_SIZE` (50_257) for
476/// the §24/§25 from-scratch baseline.
477///
478/// Any mismatch aborts the dispatch with a clear error naming both values
479/// and the violated invariant — the N-09 OOB escape in `Embedding::forward`
480/// would otherwise silently corrupt training.
481///
482/// Discharges FALSIFY-APR-PRETRAIN-ARCH-005 (Qwen tokenizer passes with
483/// Qwen target) and FALSIFY-APR-PRETRAIN-ARCH-006 (Qwen tokenizer fails
484/// with Llama target).
485fn preflight_tokenizer_vocab_matches_target(
486    tokenizer_dir: &Path,
487    target_vocab_size: usize,
488    init_is_some: bool,
489) -> Result<()> {
490    let vocab_path = tokenizer_dir.join("vocab.json");
491    let vocab_json = std::fs::read_to_string(&vocab_path).map_err(|e| {
492        CliError::ValidationFailed(format!(
493            "GATE-ARCH-370M-011 pre-flight: cannot read {} ({e})",
494            vocab_path.display()
495        ))
496    })?;
497    let vocab: serde_json::Map<String, serde_json::Value> = serde_json::from_str(&vocab_json)
498        .map_err(|e| {
499            CliError::ValidationFailed(format!(
500                "GATE-ARCH-370M-011 pre-flight: {} is not a valid vocab.json: {e}",
501                vocab_path.display()
502            ))
503        })?;
504    // §55: when --init is set (polymorphic path with HF-distributed
505    // checkpoint), allow tokenizer_vocab ≤ model_vocab to admit Qwen-style
506    // reserved-slot vocabularies. When --init is absent (§24/§25 from-scratch
507    // baseline), enforce strict equality to preserve INV-ARCH-370M-006.
508    if init_is_some {
509        assert_tokenizer_vocab_within_model_bound(vocab.len(), target_vocab_size)
510            .map_err(CliError::ValidationFailed)
511    } else {
512        assert_tokenizer_vocab_matches_model(vocab.len(), target_vocab_size)
513            .map_err(CliError::ValidationFailed)
514    }
515}
516
517/// Real-corpus drive: build a shared 370M trainer (CPU or CUDA), split
518/// the shard stream head-off into a held-out validation set, and run a
519/// full forward + backward + AdamW step per training batch.
520///
521/// When `device.is_cuda()`, the `cuda` feature must be compiled in —
522/// otherwise this surfaces a clear error rather than silently falling
523/// back to CPU (GATE-GPUTRAIN-002, contract gpu-training-backend-v1).
524#[allow(clippy::too_many_arguments)]
525fn drive_real(
526    config: PretrainConfig,
527    dataset: &Path,
528    lr: f32,
529    seq_length: usize,
530    batch_size: usize,
531    seed: u64,
532    device: Device,
533    json_output: bool,
534    init_arch: Option<&TransformerConfig>,
535    init_path: Option<&Path>,
536    val_shard: Option<&Path>,
537) -> Result<RunStatus> {
538    // GATE-ARCH-370M-011 / INV-ARCH-370M-006 — refuse to dispatch a real
539    // training step when the tokenizer vocab_size and the model vocab_size
540    // disagree. The N-09 OOB escape guard in Embedding::forward masks the
541    // mismatch at runtime → silent garbage gradients otherwise. Synthetic
542    // drive skips this check because it never touches the real model.
543    // Per `apr-pretrain-arch-polymorphic-v1` §qwen_tokenizer_vocab_compatibility
544    // (§50.4 step 5d/5f.4): when --init is set, gate by the EXTRACTED arch's
545    // vocab_size; otherwise gate by the §24/§25 baseline Llama370MConfig::VOCAB_SIZE,
546    // preserving regression-free behavior (FALSIFY-002 + FALSIFY-005 + FALSIFY-006).
547    let target_vocab = init_arch
548        .map(|cfg| cfg.vocab_size)
549        .unwrap_or(Llama370MConfig::VOCAB_SIZE);
550    preflight_tokenizer_vocab_matches_target(
551        &config.tokenizer_dir,
552        target_vocab,
553        init_arch.is_some(),
554    )?;
555
556    // MVP: pad_id/eos_id both 0. All sequences are uniform length
557    // (seq_length + 1) so LMBatch::from_sequences takes the shared
558    // layout path and pad_id is never used for padding. The real
559    // tokenizer's special-token ids will plumb through in a follow-up.
560    //
561    // wrap_around=true: when the corpus shards are exhausted before
562    // --num-steps is reached, reset cursor to shard 0 and continue.
563    // This is standard ML-training behaviour (matches PyTorch /
564    // HuggingFace). Without it, an 18M-token corpus exhausts in ~2
565    // epochs of a 5K-step run with batch=16 seq=512, and the
566    // Cuda*StepFn falls back to placeholder loss `(1.0, 1.0)` — silently
567    // producing garbage gradients. See spec §22 (PR #1073) for the
568    // root-cause investigation.
569    let mut iter = ShardBatchIter::new(dataset, batch_size, seq_length, 0, 0)
570        .map_err(|e| {
571            CliError::ValidationFailed(format!(
572                "dataset shard iterator init failed: {e} (path={})",
573                dataset.display()
574            ))
575        })?
576        .with_wrap_around(true)
577        // SPEC §82 P2-B: surface data starvation. When the corpus cycles
578        // mid-run, emit a stderr line so operators can detect that the
579        // step budget exceeds the corpus capacity (per Chinchilla, train
580        // tokens D ≈ 20·N — if D is small, the corpus wraps repeatedly
581        // and the model memorizes instead of generalizing).
582        .with_warn_on_wrap_around(true);
583
584    // SPEC §84 P2-F (apr-pretrain-val-shard-v1): held-out val source.
585    //
586    // When --val-shard <DIR> is provided, drain HELD_OUT_BATCHES from a
587    // dedicated independent shard iterator over <DIR>; the training iter
588    // stays at offset 0 (no batch theft). This makes val_loss comparable
589    // across runs whose --dataset composition changes (the P2-C audit-
590    // falsified result was confounded by val sets drawn from different
591    // corpus distributions — see evidence/p2c-2026-05-17/findings.md).
592    //
593    // When --val-shard is None, the historical "first N batches of
594    // --dataset" behaviour is preserved.
595    let held_out: Vec<LMBatch> = if let Some(val_dir) = val_shard {
596        let mut val_iter = ShardBatchIter::new(val_dir, batch_size, seq_length, 0, 0)
597            .map_err(|e| {
598                CliError::ValidationFailed(format!(
599                    "FALSIFY-PRETRAIN-VAL-SHARD-001: --val-shard iterator init failed: {e} \
600                     (path={})",
601                    val_dir.display()
602                ))
603            })?
604            // Per INV-PRETRAIN-VAL-SHARD-002 — the val shard is NOT
605            // wrap-around. A short val corpus draws short held_out
606            // (potentially < HELD_OUT_BATCHES batches) and the run
607            // proceeds; we only fail if zero batches are drawn.
608            .with_wrap_around(false);
609        let mut batches: Vec<LMBatch> = Vec::with_capacity(HELD_OUT_BATCHES);
610        for _ in 0..HELD_OUT_BATCHES {
611            match val_iter.next() {
612                Some(b) => batches.push(b),
613                None => break,
614            }
615        }
616        if batches.is_empty() {
617            return Err(CliError::ValidationFailed(format!(
618                "FALSIFY-PRETRAIN-VAL-SHARD-003: --val-shard {} is too small to yield any \
619                 held-out batches at batch_size={} seq_length={}",
620                val_dir.display(),
621                batch_size,
622                seq_length
623            )));
624        }
625        if !json_output {
626            eprintln!(
627                "[P2-F] held-out val source = --val-shard {} ({} batches)",
628                val_dir.display(),
629                batches.len()
630            );
631        }
632        batches
633    } else {
634        // Reserve the first `HELD_OUT_BATCHES` batches as the held-out val
635        // set; the remainder feeds RealStepFn.
636        let mut batches: Vec<LMBatch> = Vec::with_capacity(HELD_OUT_BATCHES);
637        for _ in 0..HELD_OUT_BATCHES {
638            match iter.next() {
639                Some(b) => batches.push(b),
640                None => break,
641            }
642        }
643        if batches.is_empty() {
644            return Err(CliError::ValidationFailed(format!(
645                "dataset {} is too small to reserve any held-out batches",
646                dataset.display()
647            )));
648        }
649        batches
650    };
651
652    if device.is_cuda() {
653        // §50.4 step 5f.5 SHIPPED (this PR): CUDA path with --init is now
654        // wired symmetric to the CPU path via
655        // `entrenar::train::pretrain_real_cuda::build_shared_cuda_trainer_with_init`.
656        // The same §50.4 step-5f machinery composes through both backends:
657        //   5c: build_transformer_config(init_arch)
658        //   5f.1: validate_pretrain_init_arch_compatible(init_arch) — encoder rejection
659        //   5f.2: load_init_tensors_from_apr(path) — read APR weights
660        //   5f.3: populate_trainer_from_init_tensors(transformer, &tensors) — populate CPU model
661        //   5f.5 (this PR): CudaTransformerTrainer::with_model uploads populated
662        //                   blocks / norm / lm_head to GPU.
663        //
664        // Per `apr-pretrain-arch-polymorphic-v1` v1.7.0 §FALSIFY-APR-PRETRAIN-INIT-CUDA-001,
665        // the const FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG is repurposed as a
666        // drift-prevention sentinel — if a future refactor re-introduces a
667        // fail-fast on the CUDA + --init path, the test that pins the const
668        // will fail and surface the regression.
669        drive_real_cuda(
670            config,
671            iter,
672            held_out,
673            lr,
674            seq_length,
675            seed,
676            json_output,
677            init_arch,
678            init_path,
679        )
680    } else {
681        drive_real_cpu(
682            config,
683            iter,
684            held_out,
685            lr,
686            seq_length,
687            seed,
688            json_output,
689            init_arch,
690            init_path,
691        )
692    }
693}
694
695/// CPU backend for `drive_real` — builds a `TransformerTrainer`
696/// (`aprender::Tensor` + trueno SIMD) and wires `RealStepFn` /
697/// `RealValFn` / `AprCheckpointFn`.
698#[allow(clippy::too_many_arguments)]
699fn drive_real_cpu(
700    config: PretrainConfig,
701    iter: entrenar::train::shard_reader::ShardBatchIter,
702    held_out: Vec<LMBatch>,
703    lr: f32,
704    seq_length: usize,
705    seed: u64,
706    json_output: bool,
707    init_arch: Option<&TransformerConfig>,
708    init_path: Option<&Path>,
709) -> Result<RunStatus> {
710    // §50.4 step 5f.4: when --init is set, build the trainer via the
711    // polymorphic builder (extracts arch + loads + populates init tensors).
712    // When --init is absent, use the existing from-scratch baseline builder
713    // so the §24/§25 evidence remains regression-free.
714    let trainer = if init_arch.is_some() || init_path.is_some() {
715        build_shared_trainer_with_init(lr, seq_length, seed, init_arch, init_path)
716            .map_err(CliError::ValidationFailed)?
717    } else {
718        build_shared_trainer(lr, seq_length, seed)
719    };
720    let step_fn = RealStepFn::new(trainer.clone(), Box::new(iter));
721    let val_fn = RealValFn::new(trainer.clone(), held_out);
722    let (ckpt_name, ckpt_arch) = checkpoint_name_and_arch(init_arch);
723    let ckpt: Box<dyn CheckpointFn> =
724        Box::new(AprCheckpointFn::new(trainer, &ckpt_name, &ckpt_arch));
725    run_and_report(config, step_fn, val_fn, Some(ckpt), json_output)
726}
727
/// CUDA backend for `drive_real` — builds a `CudaTransformerTrainer`
/// and wires `CudaRealStepFn` / `CudaRealValFn` / `CudaAprCheckpointFn`
/// (task #132 Phase 2, contract gpu-training-backend-v1).
///
/// This definition is only compiled when the `cuda` feature is enabled.
/// Builds without that feature get a same-signature `drive_real_cuda`
/// stub. That stub returns a *runtime* GATE-GPUTRAIN-002 error, so operators
/// who asked for `--device cuda` never silently get the CPU path
/// (FM-GPUTRAIN-SILENT-CPU).
///
/// When either `init_arch` or `init_path` is set, the trainer is built by
/// `build_shared_cuda_trainer_with_init`. Otherwise it comes from the
/// from-scratch `build_shared_cuda_trainer`. `config.tokenizer_dir` is
/// passed to the checkpoint writer via `with_tokenizer_dir` (§81 P0-D).
///
/// # Errors
///
/// Returns `CliError::ValidationFailed` tagged GATE-GPUTRAIN-002 when
/// either CUDA trainer builder fails. It also propagates anything
/// `run_and_report` returns.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn drive_real_cuda(
    config: PretrainConfig,
    iter: entrenar::train::shard_reader::ShardBatchIter,
    held_out: Vec<LMBatch>,
    lr: f32,
    seq_length: usize,
    seed: u64,
    json_output: bool,
    init_arch: Option<&TransformerConfig>,
    init_path: Option<&Path>,
) -> Result<RunStatus> {
    use entrenar::train::pretrain_real_cuda::{
        build_shared_cuda_trainer, build_shared_cuda_trainer_with_init, CudaAprCheckpointFn,
        CudaRealStepFn, CudaRealValFn,
    };
    // §50.4 step 5f.5: when --init is set on the CUDA path, build via the
    // polymorphic builder (extracts arch + loads + populates init tensors,
    // then uploads to GPU). When --init is absent, use the existing
    // from-scratch baseline so the §24/§25 evidence remains regression-free
    // and INV-ARCH-370M-001 stays enforced on the from-scratch CUDA path.
    let trainer = if init_arch.is_some() || init_path.is_some() {
        build_shared_cuda_trainer_with_init(lr, seq_length, seed, init_arch, init_path).map_err(
            |e| {
                CliError::ValidationFailed(format!(
                    "GATE-GPUTRAIN-002: CUDA trainer allocation (--init path) failed: {e}. \
                     See contracts/entrenar/gpu-training-backend-v1.yaml and \
                     contracts/apr-pretrain-arch-polymorphic-v1.yaml v1.7.0 \
                     §FALSIFY-APR-PRETRAIN-INIT-CUDA-001 — this path is only \
                     reachable when the binary was built with `--features cuda`.",
                ))
            },
        )?
    } else {
        build_shared_cuda_trainer(lr, seq_length, seed).map_err(|e| {
            CliError::ValidationFailed(format!(
                "GATE-GPUTRAIN-002: CUDA trainer allocation failed: {e}. \
                 See contracts/entrenar/gpu-training-backend-v1.yaml and \
                 memory/feedback_cuda_feature_footgun.md — this path is \
                 only reachable when the binary was built with `--features cuda`.",
            ))
        })?
    };
    let step_fn = CudaRealStepFn::new(trainer.clone(), Box::new(iter));
    let val_fn = CudaRealValFn::new(trainer.clone(), held_out);
    // SPEC-SHIP-TWO-001 §81 P0-D: pass --tokenizer through so each
    // checkpoint embeds the tokenizer.json (apr qa requires this).
    let (ckpt_name, ckpt_arch) = checkpoint_name_and_arch(init_arch);
    let ckpt: Box<dyn CheckpointFn> = Box::new(
        CudaAprCheckpointFn::new(trainer, &ckpt_name, &ckpt_arch)
            .with_tokenizer_dir(&config.tokenizer_dir),
    );
    run_and_report(config, step_fn, val_fn, Some(ckpt), json_output)
}
790
791/// CUDA backend stub when the `cuda` feature is NOT compiled in.
792///
793/// This is the load-bearing gate that prevents FM-GPUTRAIN-SILENT-CPU:
794/// if a user passes `--device cuda` on an apr binary built without
795/// CUDA support, they see a clear "rebuild with --features cuda" error
796/// rather than a 14-minute CPU run masquerading as GPU training
797/// (task #132 lambda-labs incident, 2026-04-21).
798#[cfg(not(feature = "cuda"))]
799#[allow(clippy::too_many_arguments)]
800fn drive_real_cuda(
801    _config: PretrainConfig,
802    _iter: entrenar::train::shard_reader::ShardBatchIter,
803    _held_out: Vec<LMBatch>,
804    _lr: f32,
805    _seq_length: usize,
806    _seed: u64,
807    _json_output: bool,
808    _init_arch: Option<&TransformerConfig>,
809    _init_path: Option<&Path>,
810) -> Result<RunStatus> {
811    Err(CliError::ValidationFailed(
812        "GATE-GPUTRAIN-002: --device cuda was requested but this `apr` \
813         binary was built WITHOUT the `cuda` feature. \
814         Rebuild with `cargo build --release --features cuda` or use \
815         `--device cpu`. See memory/feedback_cuda_feature_footgun.md \
816         (contract gpu-training-backend-v1 / task #132 Phase 2)."
817            .into(),
818    ))
819}
820
821/// Shared helper: construct the `PretrainLoop`, run it, print the
822/// terminal report, and bubble the `RunStatus` back for exit-code
823/// mapping. `checkpoint_fn` — when `Some` — writes an APR file per
824/// epoch that passes GATE-TRAIN-005.
825fn run_and_report<S: StepFn, V: ValFn>(
826    config: PretrainConfig,
827    step_fn: S,
828    val_fn: V,
829    checkpoint_fn: Option<Box<dyn CheckpointFn>>,
830    json_output: bool,
831) -> Result<RunStatus> {
832    let mut loop_ = PretrainLoop::new(config, step_fn, val_fn);
833    if let Some(ckpt) = checkpoint_fn {
834        loop_ = loop_.with_checkpoint_fn(ckpt);
835    }
836    let status = loop_.run();
837    report(&status, &loop_, json_output)?;
838    Ok(status)
839}
840
841fn abort_to_err(abort: &PretrainAbort) -> CliError {
842    match abort {
843        PretrainAbort::Divergence { .. } | PretrainAbort::DivergenceAtEpochZero { .. } => {
844            CliError::ValidationFailed(format!(
845                "GATE-TRAIN-005 ship-blocker fired: {abort}. See \
846                 contracts/training-loop-pretrain-v1.yaml and \
847                 memory/project_ship_two_001_model1_qlora_divergence.md"
848            ))
849        }
850        PretrainAbort::NumericalInstability { .. } => {
851            CliError::ValidationFailed(format!("GATE-TRAIN-007 NaN/Inf guard fired: {abort}"))
852        }
853        PretrainAbort::ThroughputOutOfRange { .. } => CliError::ValidationFailed(format!(
854            "GATE-TRAIN-008 throughput-range guard fired: {abort}"
855        )),
856    }
857}
858
859fn print_header(cfg: &PretrainConfig) {
860    output::header("apr pretrain — SHIP-TWO-001 MODEL-2 training loop");
861    println!();
862    output::section("Configuration");
863    output::kv("  Dataset", cfg.dataset_path.display().to_string());
864    output::kv("  Tokenizer", cfg.tokenizer_dir.display().to_string());
865    output::kv("  Run dir", cfg.run_dir.display().to_string());
866    output::kv("  LR max", format!("{:.2e}", cfg.lr_max));
867    output::kv("  Total steps", cfg.total_steps.to_string());
868    output::kv("  Warmup steps", cfg.warmup_steps.to_string());
869    output::kv(
870        "  Batch × seq",
871        format!("{} × {}", cfg.batch_size, cfg.seq_length),
872    );
873    output::kv("  Steps / epoch", cfg.steps_per_epoch.to_string());
874    output::kv("  Seed", cfg.seed.to_string());
875    output::kv("  Target val_loss", format!("{:.2}", cfg.target_val_loss));
876    println!();
877}
878
879fn report<S: entrenar::train::pretrain::StepFn, V: entrenar::train::pretrain::ValFn>(
880    status: &RunStatus,
881    loop_: &PretrainLoop<S, V>,
882    json_output: bool,
883) -> Result<()> {
884    if json_output {
885        let report = PretrainReport::from(status, loop_);
886        let json = serde_json::to_string_pretty(&report)
887            .map_err(|e| CliError::InvalidFormat(e.to_string()))?;
888        println!("{json}");
889        return Ok(());
890    }
891
892    output::section("Run Result");
893    match status {
894        RunStatus::Ok {
895            final_val_loss,
896            epochs_completed,
897        } => {
898            println!(
899                "  {} CONVERGED  final val_loss={:.4} after {} epoch(s)",
900                "OK".green().bold(),
901                final_val_loss,
902                epochs_completed
903            );
904        }
905        RunStatus::EarlyStop {
906            best_val_loss,
907            epochs_completed,
908        } => {
909            println!(
910                "  {} EARLY_STOP  best val_loss={:.4} after {} epoch(s)",
911                "OK".yellow().bold(),
912                best_val_loss,
913                epochs_completed
914            );
915        }
916        RunStatus::Aborted(abort) => {
917            println!("  {} ABORTED  {}", "FAIL".red().bold(), abort);
918        }
919    }
920    output::kv("  Steps recorded", loop_.step_metrics().len().to_string());
921    output::kv(
922        "  Epochs recorded",
923        loop_.epoch_artifacts().len().to_string(),
924    );
925    println!();
926    Ok(())
927}
928
/// Machine-readable run summary printed by `report` when `--json` is set.
///
/// The field names below are the serialized JSON keys (no serde renames),
/// so renaming or reordering fields changes the emitted schema.
#[derive(serde::Serialize)]
struct PretrainReport {
    /// Run outcome label: `"OK"`, `"EARLY_STOP"`, or `"ABORTED"`.
    status: String,
    /// Display text of the `PretrainAbort` for `"ABORTED"` runs; `None`
    /// otherwise.
    detail: Option<String>,
    /// Final val_loss for `"OK"`, *best* val_loss for `"EARLY_STOP"`, and
    /// `None` for `"ABORTED"`.
    final_val_loss: Option<f32>,
    /// Epoch count from `RunStatus` for `"OK"` / `"EARLY_STOP"`; for
    /// `"ABORTED"`, the number of epoch artifacts the loop recorded.
    epochs_completed: usize,
    /// Number of per-step metric entries recorded by the loop.
    steps_recorded: usize,
    /// Validation-loss history copied from the loop.
    val_loss_history: Vec<f32>,
    /// Per-step `StepMetrics` captured by `PretrainLoop` (GATE-TRAIN-001
    /// contract `training-loop-pretrain-v1.yaml::per_step_metrics.required`).
    ///
    /// Emitted so downstream consumers can discharge FALSIFY-GPUTRAIN-005
    /// (step-time < 500 ms on RTX 4090 for 370M) and FALSIFY-GPUTRAIN-006
    /// (same-seed reproducibility — two cuda:0 runs at seed=0 must match
    /// on every step's train_loss within `AC_GPUTRAIN_006_MAX_SEED_LOSS_DELTA`
    /// = 1e-5) directly from the `--json` output, rather than having to
    /// parse run-dir checkpoint metadata.
    per_step_metrics: Vec<entrenar::train::pretrain::StepMetrics>,
}
948
949impl PretrainReport {
950    fn from<S: entrenar::train::pretrain::StepFn, V: entrenar::train::pretrain::ValFn>(
951        status: &RunStatus,
952        loop_: &PretrainLoop<S, V>,
953    ) -> Self {
954        let (status_name, detail, final_val_loss, epochs_completed) = match status {
955            RunStatus::Ok {
956                final_val_loss,
957                epochs_completed,
958            } => (
959                "OK".to_string(),
960                None,
961                Some(*final_val_loss),
962                *epochs_completed,
963            ),
964            RunStatus::EarlyStop {
965                best_val_loss,
966                epochs_completed,
967            } => (
968                "EARLY_STOP".to_string(),
969                None,
970                Some(*best_val_loss),
971                *epochs_completed,
972            ),
973            RunStatus::Aborted(abort) => (
974                "ABORTED".to_string(),
975                Some(abort.to_string()),
976                None,
977                loop_.epoch_artifacts().len(),
978            ),
979        };
980        PretrainReport {
981            status: status_name,
982            detail,
983            final_val_loss,
984            epochs_completed,
985            steps_recorded: loop_.step_metrics().len(),
986            val_loss_history: loop_.val_loss_history().to_vec(),
987            per_step_metrics: loop_.step_metrics().to_vec(),
988        }
989    }
990}
991
992#[cfg(test)]
993mod tests {
994    use super::*;
995    use tempfile::TempDir;
996
997    /// SPEC §82 P0-H: when `--init` is absent, fall back to historical defaults
998    /// so from-scratch 370M pretrain still produces `llama-370m-pretrain` /
999    /// `LlamaForCausalLM` stamps.
1000    #[test]
1001    fn checkpoint_name_and_arch_default_when_no_init() {
1002        let (name, arch) = checkpoint_name_and_arch(None);
1003        assert_eq!(name, "llama-370m-pretrain");
1004        assert_eq!(arch, "LlamaForCausalLM");
1005    }
1006
1007    /// SPEC §82 P0-H: when `--init` is a Qwen2 model, stamp `qwen2-pretrain`
1008    /// and `Qwen2ForCausalLM` so the qwen2 GGUF family mapper handles the
1009    /// 72 Qwen2 attn biases instead of leaving them as passthrough names.
1010    #[test]
1011    fn checkpoint_name_and_arch_qwen2_init() {
1012        let mut cfg = TransformerConfig::llama2_7b();
1013        cfg.hf_architecture = Some("Qwen2ForCausalLM".to_string());
1014        cfg.hf_model_type = Some("qwen2".to_string());
1015        let (name, arch) = checkpoint_name_and_arch(Some(&cfg));
1016        assert_eq!(name, "qwen2-pretrain");
1017        assert_eq!(arch, "Qwen2ForCausalLM");
1018    }
1019
1020    /// SPEC §82 P0-H: a `--init` model that lacks `hf_architecture` falls back
1021    /// to `LlamaForCausalLM` rather than silently emitting an empty arch
1022    /// string. (Belt-and-suspenders for older APR files written before the
1023    /// hf_architecture field existed.)
1024    #[test]
1025    fn checkpoint_name_and_arch_init_without_hf_fields() {
1026        let cfg = TransformerConfig::llama2_7b();
1027        // llama2_7b() leaves hf_architecture and hf_model_type as None.
1028        let (name, arch) = checkpoint_name_and_arch(Some(&cfg));
1029        assert_eq!(name, "model-pretrain");
1030        assert_eq!(arch, "LlamaForCausalLM");
1031    }
1032
1033    /// Stage a `vocab.json` with exactly `n` distinct integer-string tokens at
1034    /// `<dir>/vocab.json`. Used by pre-flight gate tests + by other tests that
1035    /// need to get PAST the GATE-ARCH-370M-011 pre-flight to exercise a later
1036    /// failure mode (e.g. empty dataset shards).
1037    fn stage_vocab_json(dir: &std::path::Path, n: usize) {
1038        std::fs::create_dir_all(dir).expect("mkdir tokenizer dir");
1039        let mut obj = serde_json::Map::with_capacity(n);
1040        for i in 0..n {
1041            obj.insert(format!("t{i}"), serde_json::Value::from(i as u64));
1042        }
1043        let json = serde_json::to_string(&obj).expect("serialize");
1044        std::fs::write(dir.join("vocab.json"), json).expect("write vocab.json");
1045    }
1046
1047    /// SPEC §82 P1-A: parameter count estimator should be order-of-magnitude
1048    /// correct for known reference models. Qwen2.5-0.5B has ~500M params;
1049    /// our coarse formula should be within 2× of that.
1050    #[test]
1051    fn estimate_param_count_qwen2_05b_within_2x() {
1052        let mut cfg = TransformerConfig::llama2_7b();
1053        cfg.hidden_size = 896;
1054        cfg.num_hidden_layers = 24;
1055        cfg.num_attention_heads = 14;
1056        cfg.num_kv_heads = 2;
1057        cfg.intermediate_size = 4864;
1058        cfg.vocab_size = 151936;
1059        let n = estimate_param_count(&cfg);
1060        // True Qwen2.5-0.5B = ~494M. Our estimate counts tied embedding once
1061        // and ignores GQA reduction; expect ~400-700M.
1062        let ref_params: u64 = 494_000_000;
1063        assert!(
1064            n > ref_params / 2 && n < ref_params * 2,
1065            "Qwen2.5-0.5B estimate {n} should be within 2× of 494M",
1066        );
1067    }
1068
1069    /// SPEC §82 P1-A: estimator should scale super-linearly with depth.
1070    #[test]
1071    fn estimate_param_count_scales_with_layers() {
1072        let mut cfg = TransformerConfig::llama2_7b();
1073        cfg.hidden_size = 512;
1074        cfg.num_hidden_layers = 1;
1075        cfg.intermediate_size = 2048;
1076        cfg.vocab_size = 32000;
1077        let n1 = estimate_param_count(&cfg);
1078        cfg.num_hidden_layers = 24;
1079        let n24 = estimate_param_count(&cfg);
1080        // 24× per-layer params + shared embedding ≈ 5-6× total for small models
1081        // where embedding dominates per-layer contribution.
1082        assert!(
1083            n24 > n1 * 4,
1084            "24-layer model {n24} should be at least 4× 1-layer model {n1}",
1085        );
1086    }
1087
1088    // ─── SPEC §83 P0-J: Chinchilla hard-gate behavior ──────────
1089    //
1090    // The gate logic itself lives inline in `run()` so a full unit
1091    // test requires either calling `run()` (heavy — needs dataset
1092    // path + tokenizer dir) or factoring the math into a helper.
1093    // Below we test the math in isolation via a local helper; the
1094    // end-to-end CLI behavior is covered by integration tests in
1095    // tests/chinchilla_gate_test.rs (FALSIFY-CHINCHILLA-001..003).
1096
1097    /// Mirror of the inline gate math in `run()` — kept in sync via
1098    /// review. Returns Some(error_message) if rejected, None if
1099    /// accepted (with or without bypass).
1100    fn chinchilla_gate_check(
1101        arch: &TransformerConfig,
1102        num_steps: usize,
1103        batch_size: usize,
1104        seq_length: usize,
1105        force_under_provisioned: bool,
1106    ) -> Option<f64> {
1107        let n_params = estimate_param_count(arch);
1108        let d_tokens = (num_steps as u64)
1109            .saturating_mul(batch_size as u64)
1110            .saturating_mul(seq_length as u64);
1111        let ratio = d_tokens as f64 / n_params as f64;
1112        if ratio < 10.0 && !force_under_provisioned {
1113            Some(ratio)
1114        } else {
1115            None
1116        }
1117    }
1118
1119    fn qwen_05b_config() -> TransformerConfig {
1120        let mut cfg = TransformerConfig::llama2_7b();
1121        cfg.hidden_size = 896;
1122        cfg.num_hidden_layers = 24;
1123        cfg.num_attention_heads = 14;
1124        cfg.num_kv_heads = 2;
1125        cfg.intermediate_size = 4864;
1126        cfg.vocab_size = 151936;
1127        cfg.hf_architecture = Some("Qwen2ForCausalLM".to_string());
1128        cfg.hf_model_type = Some("qwen2".to_string());
1129        cfg
1130    }
1131
1132    /// FALSIFY-CHINCHILLA-001 (unit): §82 P2-A reproducer — 5000
1133    /// steps × 16 × 512 = 40.96M tokens against Qwen-0.5B (~494M
1134    /// params) = ratio 0.083× → REJECTED.
1135    #[test]
1136    fn chinchilla_hard_gate_rejects_under_provisioned() {
1137        let cfg = qwen_05b_config();
1138        let verdict = chinchilla_gate_check(&cfg, 5000, 16, 512, false);
1139        assert!(verdict.is_some(), "0.083× should be rejected");
1140        let ratio = verdict.expect("ratio");
1141        assert!(ratio < 0.1, "expected ratio < 0.1, got {ratio}");
1142    }
1143
1144    /// FALSIFY-CHINCHILLA-002 (unit): same config with bypass flag
1145    /// → accepted (returns None despite low ratio).
1146    #[test]
1147    fn chinchilla_hard_gate_bypasses_with_force_flag() {
1148        let cfg = qwen_05b_config();
1149        let verdict = chinchilla_gate_check(&cfg, 5000, 16, 512, true);
1150        assert!(verdict.is_none(), "force_under_provisioned must bypass");
1151    }
1152
1153    /// FALSIFY-CHINCHILLA-004 (unit): boundary at exactly D/N = 10
1154    /// passes; just below fails. Uses ceiling division to ensure
1155    /// the "exact" case actually meets or exceeds 10·N (integer
1156    /// truncation on `target_d / (bs*sl)` would land slightly below).
1157    #[test]
1158    fn chinchilla_hard_gate_boundary_10x() {
1159        let cfg = qwen_05b_config();
1160        let n = estimate_param_count(&cfg);
1161        let bs = 16u64;
1162        let sl = 512u64;
1163        let target_d = 10 * n;
1164        let bs_sl = bs * sl;
1165        // Ceiling division so D ≥ 10·N exactly (passes the gate).
1166        let exact_steps = (target_d + bs_sl - 1) / bs_sl;
1167        let verdict_exact =
1168            chinchilla_gate_check(&cfg, exact_steps as usize, bs as usize, sl as usize, false);
1169        assert!(
1170            verdict_exact.is_none(),
1171            "ratio ≥ 10.0 should PASS, got verdict={verdict_exact:?}"
1172        );
1173        // One full step less → below 10·N → REJECTED.
1174        let verdict_below = chinchilla_gate_check(
1175            &cfg,
1176            (exact_steps - 1) as usize,
1177            bs as usize,
1178            sl as usize,
1179            false,
1180        );
1181        assert!(
1182            verdict_below.is_some(),
1183            "ratio just below 10× should be REJECTED"
1184        );
1185    }
1186
1187    /// FALSIFY-CHINCHILLA-005 (unit): generously-provisioned ratios
1188    /// (≥ 10×) pass without --force flag.
1189    #[test]
1190    fn chinchilla_hard_gate_accepts_well_provisioned() {
1191        let cfg = qwen_05b_config();
1192        let n = estimate_param_count(&cfg);
1193        // 25·N = generous (above 20× compute-optimal target).
1194        let bs = 16u64;
1195        let sl = 512u64;
1196        let steps_25x = ((25 * n) / (bs * sl)) as usize;
1197        let verdict = chinchilla_gate_check(&cfg, steps_25x, bs as usize, sl as usize, false);
1198        assert!(verdict.is_none(), "25× should pass");
1199    }
1200
1201    #[test]
1202    fn preflight_accepts_matching_vocab() {
1203        // GATE-ARCH-370M-011 acceptance case: tokenizer vocab.json with
1204        // exactly Llama370MConfig::VOCAB_SIZE entries must pass pre-flight.
1205        let tmp = TempDir::new().expect("tempdir");
1206        stage_vocab_json(tmp.path(), Llama370MConfig::VOCAB_SIZE);
1207        preflight_tokenizer_vocab_matches_target(tmp.path(), Llama370MConfig::VOCAB_SIZE, false)
1208            .expect("matching vocab must pass GATE-ARCH-370M-011");
1209    }
1210
1211    #[test]
1212    fn preflight_rejects_tokenizer_vocab_mismatch() {
1213        // FALSIFY-ARCH-370M-011: a tokenizer whose vocab size drifts from
1214        // the model's pinned VOCAB_SIZE MUST abort dispatch with an error
1215        // message that names both values and the gate id, so the operator
1216        // can see the mismatch without stepping through code. Task #131
1217        // bumped VOCAB_SIZE to 50_257 (Option A) — the counter-example
1218        // below now exercises a tokenizer one token short of contract.
1219        let tmp = TempDir::new().expect("tempdir");
1220        let mismatch = Llama370MConfig::VOCAB_SIZE - 1;
1221        stage_vocab_json(tmp.path(), mismatch);
1222        let err = preflight_tokenizer_vocab_matches_target(
1223            tmp.path(),
1224            Llama370MConfig::VOCAB_SIZE,
1225            false,
1226        )
1227        .expect_err("tokenizer/model vocab mismatch must be rejected");
1228        match err {
1229            CliError::ValidationFailed(msg) => {
1230                assert!(
1231                    msg.contains("GATE-ARCH-370M-011"),
1232                    "msg must cite gate: {msg}"
1233                );
1234                assert!(
1235                    msg.contains(&mismatch.to_string()),
1236                    "msg must name tokenizer vocab: {msg}"
1237                );
1238                assert!(
1239                    msg.contains(&Llama370MConfig::VOCAB_SIZE.to_string()),
1240                    "msg must name model vocab: {msg}"
1241                );
1242            }
1243            other => panic!("unexpected error: {other:?}"),
1244        }
1245    }
1246
1247    #[test]
1248    fn preflight_rejects_missing_vocab_json() {
1249        // Missing vocab.json is a pre-flight failure (not a later shard
1250        // error) — the operator should know the tokenizer layout is
1251        // wrong, not that the dataset is empty.
1252        let tmp = TempDir::new().expect("tempdir");
1253        let err = preflight_tokenizer_vocab_matches_target(
1254            tmp.path(),
1255            Llama370MConfig::VOCAB_SIZE,
1256            false,
1257        )
1258        .expect_err("missing vocab.json must be rejected");
1259        match err {
1260            CliError::ValidationFailed(msg) => {
1261                assert!(
1262                    msg.contains("GATE-ARCH-370M-011"),
1263                    "msg must cite gate: {msg}"
1264                );
1265                assert!(
1266                    msg.contains("cannot read"),
1267                    "msg must name I/O failure: {msg}"
1268                );
1269            }
1270            other => panic!("unexpected error: {other:?}"),
1271        }
1272    }
1273
1274    /// FALSIFY-APR-PRETRAIN-ARCH-005 — a Qwen tokenizer (vocab=151_936) MUST
1275    /// pass preflight when the target_vocab_size is the Qwen extracted-arch
1276    /// (151_936). Falsifies a regression where preflight would still gate
1277    /// against the hardcoded Llama370M vocab.
1278    ///
1279    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5d.
1280    #[test]
1281    fn preflight_qwen_vocab_passes_with_qwen_target() {
1282        const QWEN2_VOCAB_SIZE: usize = 151_936;
1283        let tmp = TempDir::new().expect("tempdir");
1284        stage_vocab_json(tmp.path(), QWEN2_VOCAB_SIZE);
1285        // §50.4 step 5d called this with init=Some semantic (the polymorphic path). Use
1286        // init_is_some=true here per §55 relaxed-bound semantics; vocab.len() == target
1287        // is still acceptable under <=.
1288        preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN2_VOCAB_SIZE, true).expect(
1289            "Qwen tokenizer (151_936) MUST pass preflight when target is Qwen-shaped — \
1290             this is the load-bearing claim of §49 fine-tune from a Qwen2.5 init checkpoint",
1291        );
1292    }
1293
1294    /// FALSIFY-APR-PRETRAIN-ARCH-006 — a Qwen tokenizer (vocab=151_936) MUST
1295    /// FAIL preflight when target_vocab_size is the Llama370M baseline
1296    /// (50_257). Falsifies the silent-pass class where an operator would
1297    /// accidentally pair a Qwen tokenizer with the from-scratch trainer.
1298    ///
1299    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5d.
1300    #[test]
1301    fn preflight_qwen_vocab_fails_with_llama_target() {
1302        const QWEN2_VOCAB_SIZE: usize = 151_936;
1303        let tmp = TempDir::new().expect("tempdir");
1304        stage_vocab_json(tmp.path(), QWEN2_VOCAB_SIZE);
1305        // §55: this is the from-scratch path (init absent), so init_is_some=false.
1306        // Strict equality applies; tokenizer (151_936) ≠ target (50_257) MUST fail.
1307        let err = preflight_tokenizer_vocab_matches_target(
1308            tmp.path(),
1309            Llama370MConfig::VOCAB_SIZE,
1310            false,
1311        )
1312        .expect_err(
1313            "Qwen tokenizer (151_936) MUST FAIL preflight when target is Llama370M (50_257) — \
1314             silent-pass would corrupt training",
1315        );
1316        match err {
1317            CliError::ValidationFailed(msg) => {
1318                assert!(
1319                    msg.contains(&QWEN2_VOCAB_SIZE.to_string()),
1320                    "msg must name Qwen vocab size 151_936: {msg}"
1321                );
1322                assert!(
1323                    msg.contains(&Llama370MConfig::VOCAB_SIZE.to_string()),
1324                    "msg must name target Llama vocab size 50_257: {msg}"
1325                );
1326            }
1327            other => panic!("unexpected error: {other:?}"),
1328        }
1329    }
1330
1331    /// FALSIFY-APR-PRETRAIN-ARCH-009 (§55) — at preflight level, an HF
1332    /// tokenizer with vocab.json count = 151665 (BPE+added, the §54 LIVE
1333    /// smoke shape) MUST PASS preflight when target is Qwen 151936 AND
1334    /// init_is_some=true (the polymorphic path).
1335    #[test]
1336    fn preflight_qwen_reserved_slots_pass_under_polymorphic_init() {
1337        const QWEN_TOKENIZER_EFFECTIVE: usize = 151_665;
1338        const QWEN_DECLARED_VOCAB: usize = 151_936;
1339        let tmp = TempDir::new().expect("tempdir");
1340        stage_vocab_json(tmp.path(), QWEN_TOKENIZER_EFFECTIVE);
1341
1342        // init_is_some=true: relaxed bound applies; 151665 ≤ 151936 PASSES.
1343        preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN_DECLARED_VOCAB, true).expect(
1344            "FALSIFY-APR-PRETRAIN-ARCH-009: HF reserved-slot tokenizer (151_665 ≤ 151_936) \
1345             MUST pass preflight under polymorphic init path (§55 relaxed bound)",
1346        );
1347
1348        // init_is_some=false: strict equality applies; 151665 ≠ 151936 FAILS.
1349        let err = preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN_DECLARED_VOCAB, false)
1350            .expect_err(
1351                "FALSIFY-APR-PRETRAIN-ARCH-009 dual: from-scratch path MUST keep strict ==",
1352            );
1353        match err {
1354            CliError::ValidationFailed(msg) => {
1355                assert!(
1356                    msg.contains("GATE-ARCH-370M-011")
1357                        && msg.contains(&QWEN_TOKENIZER_EFFECTIVE.to_string())
1358                        && msg.contains(&QWEN_DECLARED_VOCAB.to_string()),
1359                    "strict-mode error must name gate + both sizes: {msg}"
1360                );
1361            }
1362            other => panic!("unexpected error: {other:?}"),
1363        }
1364    }
1365
1366    /// FALSIFY-APR-PRETRAIN-ARCH-010 (§55) — at preflight level, a tokenizer
1367    /// with MORE entries than the model declares MUST FAIL even under the
1368    /// polymorphic init path. This is the OOB-safety guard: such a tokenizer
1369    /// could emit ids ≥ model_vocab → silent embedding-lookup garbage.
1370    #[test]
1371    fn preflight_oversized_tokenizer_rejected_even_under_polymorphic_init() {
1372        const QWEN_DECLARED_VOCAB: usize = 151_936;
1373        let oversized = QWEN_DECLARED_VOCAB + 100;
1374        let tmp = TempDir::new().expect("tempdir");
1375        stage_vocab_json(tmp.path(), oversized);
1376
1377        let err = preflight_tokenizer_vocab_matches_target(
1378            tmp.path(),
1379            QWEN_DECLARED_VOCAB,
1380            true, // polymorphic path
1381        )
1382        .expect_err(
1383            "FALSIFY-APR-PRETRAIN-ARCH-010: oversized tokenizer MUST fail-fast even under \
1384             polymorphic init (OOB safety; relaxed bound is ≤ not <)",
1385        );
1386        match err {
1387            CliError::ValidationFailed(msg) => {
1388                assert!(
1389                    msg.contains("RELAXED") && msg.contains("OOB"),
1390                    "polymorphic-mode error must cite RELAXED + OOB: {msg}"
1391                );
1392            }
1393            other => panic!("unexpected error: {other:?}"),
1394        }
1395    }
1396
1397    /// FALSIFY-APR-PRETRAIN-INIT-CUDA-001 (drift-prevention sentinel,
1398    /// post-5f.5): after §50.4 step 5f.5 SHIPPED, the const message
1399    /// pins the wireup-is-wired property. The string MUST contain
1400    /// (a) the falsifier id, (b) the canonical "is wired for --device
1401    /// cuda" phrase, (c) a reference to the symmetric builder
1402    /// `build_shared_cuda_trainer_with_init`, and (d) the "5f.5
1403    /// SHIPPED" status marker. If a future refactor accidentally
1404    /// reverts the wireup or renames the symmetric builder, this test
1405    /// catches the drift before the contract reference goes stale.
1406    ///
1407    /// Pinned via `pub(crate) const FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG`
1408    /// so this test fires on a CPU-only build (no `--features cuda` needed).
1409    /// The const itself is NOT emitted by any code path in `drive_real`;
1410    /// it survives only to anchor the contract obligation. The runtime
1411    /// behaviour (`drive_real_cuda` calling `build_shared_cuda_trainer_with_init`
1412    /// when `init_arch.is_some() || init_path.is_some()`) is exercised
1413    /// at the entrenar crate level where CUDA-feature builds can fire it.
1414    #[test]
1415    fn drive_real_cuda_init_path_wireup_sentinel_pinned() {
1416        let msg = FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG;
1417        assert!(
1418            msg.contains("FALSIFY-APR-PRETRAIN-INIT-CUDA-001"),
1419            "sentinel MUST cite the falsifier id (auditability): {msg}"
1420        );
1421        assert!(
1422            msg.contains("is wired for --device cuda"),
1423            "sentinel MUST contain the canonical 'is wired' phrase so \
1424             operators recognize §50.4 step 5f.5 SHIPPED: {msg}"
1425        );
1426        assert!(
1427            msg.contains("build_shared_cuda_trainer_with_init"),
1428            "sentinel MUST name the symmetric builder so future agents \
1429             know which symbol implements the wireup: {msg}"
1430        );
1431        assert!(
1432            msg.contains("5f.5 SHIPPED"),
1433            "sentinel MUST include the 5f.5 SHIPPED status marker so \
1434             grep over the codebase can find the discharge point: {msg}"
1435        );
1436    }
1437
1438    #[test]
1439    fn synthetic_pretrain_end_to_end_happy_path() {
1440        let tmp = TempDir::new().expect("tempdir");
1441        let dataset = tmp.path().join("data.jsonl");
1442        let tokenizer = tmp.path().join("tok");
1443        let run_dir = tmp.path().join("run");
1444
1445        let result = run(
1446            &dataset,
1447            &tokenizer,
1448            &run_dir,
1449            PretrainMode::Finetune,
1450            Some(5.0e-5),
1451            25,
1452            Some(5),
1453            2,
1454            4,
1455            5,
1456            42,
1457            Some(2.2),
1458            50257,
1459            true,
1460            "cpu",
1461            None,
1462            false,
1463            None,
1464            true,
1465        );
1466        assert!(
1467            result.is_ok(),
1468            "synthetic pretrain end-to-end must succeed: got {result:?}"
1469        );
1470    }
1471
1472    #[test]
1473    fn real_mode_empty_dataset_dir_errors() {
1474        // When --synthetic is off, the real-corpus branch must surface a
1475        // clear error if the dataset directory has no .bin shards. This
1476        // supersedes the old "non-synthetic is not implemented" guard.
1477        // Stage a valid vocab.json first so GATE-ARCH-370M-011 pre-flight
1478        // passes — otherwise the shard-iterator error below is never reached.
1479        let tmp = TempDir::new().expect("tempdir");
1480        let tok_dir = tmp.path().join("tok");
1481        stage_vocab_json(&tok_dir, Llama370MConfig::VOCAB_SIZE);
1482        let err = run(
1483            tmp.path(),
1484            &tok_dir,
1485            tmp.path(),
1486            PretrainMode::Finetune,
1487            Some(5.0e-5),
1488            10,
1489            Some(2),
1490            2,
1491            4,
1492            5,
1493            42,
1494            Some(2.2),
1495            50257,
1496            false,
1497            "cpu",
1498            None,
1499            false,
1500            None,
1501            true,
1502        )
1503        .expect_err("empty dataset dir must fail to initialise the shard iterator");
1504        match err {
1505            CliError::ValidationFailed(msg) => {
1506                assert!(
1507                    msg.contains("shard iterator init failed"),
1508                    "unexpected message: {msg}"
1509                );
1510            }
1511            other => panic!("unexpected error: {other:?}"),
1512        }
1513    }
1514
1515    #[test]
1516    fn invalid_target_val_loss_rejected() {
1517        let tmp = TempDir::new().expect("tempdir");
1518        let err = run(
1519            tmp.path(),
1520            tmp.path(),
1521            tmp.path(),
1522            PretrainMode::Finetune,
1523            Some(5.0e-5),
1524            10,
1525            Some(2),
1526            2,
1527            4,
1528            5,
1529            42,
1530            Some(-1.0),
1531            50257,
1532            true,
1533            "cpu",
1534            None,
1535            false,
1536            None,
1537            true,
1538        )
1539        .expect_err("negative target_val_loss must be rejected");
1540        assert!(matches!(err, CliError::ValidationFailed(_)));
1541    }
1542
1543    // ── GATE-TRAIN-009 / INV-TRAIN-009 falsifiers ──────────────────────
1544    // Contract: training-loop-pretrain-v1 v1.3.0 §hyperparameter_defaults
1545    //
1546    // These tests bind the CLI's `mode_defaults` resolver to the
1547    // hyperparameter_defaults YAML table. If the table is ever edited
1548    // without also updating this resolver (or vice versa), the tests
1549    // fail. That is exactly the drift INV-TRAIN-009 forbids.
1550
1551    #[test]
1552    fn mode_finetune_is_default_and_matches_contract() {
1553        // No overrides → resolved HP matches the `finetune` YAML row
1554        // (lr_max=5e-5, warmup_steps=100, target_val_loss=2.2) AND the
1555        // regime is Finetune so INV-TRAIN-005 epoch-zero cap = 10.0.
1556        let hp = mode_defaults(PretrainMode::Finetune, 50257, None, None, None);
1557        assert_eq!(hp.regime, TrainingRegime::Finetune);
1558        assert!(
1559            (hp.lr_max - 5.0e-5).abs() < 1.0e-12,
1560            "lr_max={} must equal finetune default 5e-5",
1561            hp.lr_max
1562        );
1563        assert_eq!(hp.warmup_steps, 100);
1564        assert!(
1565            (hp.target_val_loss - 2.2).abs() < 1.0e-6,
1566            "target_val_loss={} must equal finetune default 2.2",
1567            hp.target_val_loss
1568        );
1569    }
1570
1571    #[test]
1572    fn mode_from_scratch_applies_all_four_defaults() {
1573        // `--mode from-scratch` with no HP overrides MUST yield the full
1574        // cold-start 4-tuple atomically — regime=FromScratch, lr=3e-4,
1575        // warmup=1000, target=3.0. INV-TRAIN-009 falsifier (a).
1576        let hp = mode_defaults(PretrainMode::FromScratch, 50257, None, None, None);
1577        assert_eq!(hp.regime, TrainingRegime::FromScratch { vocab_size: 50257 });
1578        assert!(
1579            (hp.lr_max - 3.0e-4).abs() < 1.0e-12,
1580            "lr_max={} must equal from_scratch default 3e-4",
1581            hp.lr_max
1582        );
1583        assert_eq!(hp.warmup_steps, 1000);
1584        assert!(
1585            (hp.target_val_loss - 3.0).abs() < 1.0e-6,
1586            "target_val_loss={} must equal from_scratch default 3.0",
1587            hp.target_val_loss
1588        );
1589    }
1590
1591    #[test]
1592    fn mode_from_scratch_honors_explicit_lr_override() {
1593        // `--mode from-scratch --lr 1e-4` → regime still flips to
1594        // FromScratch AND warmup/target keep the from_scratch defaults,
1595        // but lr_max is the operator-supplied 1e-4. INV-TRAIN-009
1596        // falsifier (b): overrides win, regime still moves.
1597        let hp = mode_defaults(PretrainMode::FromScratch, 50257, Some(1.0e-4), None, None);
1598        assert_eq!(hp.regime, TrainingRegime::FromScratch { vocab_size: 50257 });
1599        assert!(
1600            (hp.lr_max - 1.0e-4).abs() < 1.0e-12,
1601            "lr_max={} must equal explicit override 1e-4",
1602            hp.lr_max
1603        );
1604        // Remaining two fields retained their mode defaults.
1605        assert_eq!(hp.warmup_steps, 1000);
1606        assert!((hp.target_val_loss - 3.0).abs() < 1.0e-6);
1607    }
1608
1609    // ── GATE-TRAIN-010 / INV-TRAIN-010 falsifiers ──────────────────────
1610    // Contract: training-loop-pretrain-v1 v1.4.0 §INV-TRAIN-010
1611    //
1612    // Task #105's original wiring shipped `synthetic: bool` with
1613    // `default_value = "true"`. The `--synthetic` flag had no
1614    // companion to turn it off, so every invocation of `apr pretrain`
1615    // silently routed to drive_synthetic. Tasks #119 / #124 / #125
1616    // all captured scripted-loss output and mis-labeled it real
1617    // compute. These two tests parse actual argv through clap and
1618    // assert the routing discriminator byte-for-byte.
1619
1620    fn parse_pretrain_synthetic(extra: &[&str]) -> bool {
1621        // The `Commands` enum is large enough in debug builds to overflow
1622        // the default 2 MiB test-thread stack during clap's recursive
1623        // destructuring. Run the parse on a worker thread with a 16 MiB
1624        // stack so this falsifier passes in both debug and release.
1625        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1626        std::thread::Builder::new()
1627            .stack_size(16 * 1024 * 1024)
1628            .spawn(move || {
1629                use clap::Parser;
1630                let mut argv: Vec<String> = vec![
1631                    "apr".to_string(),
1632                    "pretrain".to_string(),
1633                    "--dataset".to_string(),
1634                    "/tmp/_gate_train_010/ds".to_string(),
1635                    "--tokenizer".to_string(),
1636                    "/tmp/_gate_train_010/tok".to_string(),
1637                    "--run-dir".to_string(),
1638                    "/tmp/_gate_train_010/run".to_string(),
1639                ];
1640                argv.extend(extra);
1641                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1642                match *cli.command {
1643                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1644                        synthetic,
1645                        ..
1646                    }) => synthetic,
1647                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1648                }
1649            })
1650            .expect("spawn parse thread")
1651            .join()
1652            .expect("parse thread must not panic")
1653    }
1654
1655    #[test]
1656    fn cli_pretrain_defaults_to_real_compute() {
1657        // Absent `--synthetic` MUST parse to synthetic=false so the
1658        // dispatcher routes through drive_real.
1659        assert!(
1660            !parse_pretrain_synthetic(&[]),
1661            "INV-TRAIN-010: `apr pretrain` (no --synthetic) must parse to synthetic=false"
1662        );
1663    }
1664
1665    #[test]
1666    fn cli_pretrain_synthetic_flag_routes_to_synthetic() {
1667        // `--synthetic` present MUST parse to synthetic=true.
1668        assert!(
1669            parse_pretrain_synthetic(&["--synthetic"]),
1670            "INV-TRAIN-010: `apr pretrain --synthetic` must parse to synthetic=true"
1671        );
1672    }
1673
1674    // ── FALSIFY-GPUTRAIN-001 / 002 CLI surface (contract phase 1) ────
1675    // Contract: gpu-training-backend-v1 §device_dispatch
1676    //
1677    // These tests parse actual `apr pretrain --device …` argv through
1678    // clap and assert the string is surfaced byte-for-byte to the
1679    // dispatcher. `resolve_device()` itself is exercised by
1680    // `aprender-train::train::device::tests` — these tests verify that
1681    // the CLI flag exists and that its default is `auto` (the only
1682    // spec allowed to fall back).
1683
1684    fn parse_pretrain_device(extra: &[&str]) -> String {
1685        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1686        std::thread::Builder::new()
1687            .stack_size(16 * 1024 * 1024)
1688            .spawn(move || {
1689                use clap::Parser;
1690                let mut argv: Vec<String> = vec![
1691                    "apr".to_string(),
1692                    "pretrain".to_string(),
1693                    "--dataset".to_string(),
1694                    "/tmp/_gputrain_device/ds".to_string(),
1695                    "--tokenizer".to_string(),
1696                    "/tmp/_gputrain_device/tok".to_string(),
1697                    "--run-dir".to_string(),
1698                    "/tmp/_gputrain_device/run".to_string(),
1699                ];
1700                argv.extend(extra);
1701                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1702                match *cli.command {
1703                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1704                        device, ..
1705                    }) => device,
1706                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1707                }
1708            })
1709            .expect("spawn parse thread")
1710            .join()
1711            .expect("parse thread must not panic")
1712    }
1713
1714    #[test]
1715    fn cli_pretrain_device_defaults_to_auto() {
1716        // Absent `--device`, the flag MUST parse to `"auto"` — the only
1717        // spec allowed to silently fall back to CPU when CUDA is not
1718        // available. Any other default would violate the contract's
1719        // "explicit request → hard-fail" invariant.
1720        assert_eq!(
1721            parse_pretrain_device(&[]),
1722            "auto",
1723            "gpu-training-backend-v1 INV-GPUTRAIN-002: default --device must be `auto`",
1724        );
1725    }
1726
1727    #[test]
1728    fn cli_pretrain_device_accepts_cpu() {
1729        // `--device cpu` MUST round-trip through clap unchanged.
1730        assert_eq!(parse_pretrain_device(&["--device", "cpu"]), "cpu");
1731    }
1732
1733    #[test]
1734    fn cli_pretrain_device_accepts_cuda_index() {
1735        // `--device cuda:7` MUST round-trip unchanged; grammar
1736        // enforcement happens in `resolve_device`, not at clap.
1737        assert_eq!(parse_pretrain_device(&["--device", "cuda:7"]), "cuda:7");
1738    }
1739
1740    // ── apr-pretrain-from-init-v1 falsifiers ────────────────────────────
1741    // Contract: contracts/apr-pretrain-from-init-v1.yaml v1.0.0 PROPOSED
1742    // Spec: SPEC-SHIP-TWO-001 §49 step 4 — wire `apr pretrain --init`
1743    //
1744    // PARTIAL_ALGORITHM_LEVEL: file-existence + magic-byte checks bind
1745    // FALSIFY-APR-PRETRAIN-INIT-003 / -004; the clap surface binds
1746    // FALSIFY-001 / -007. FALSIFY-005 (arch mismatch), -006 (init_loss
    // signal), -009 (optimizer state), -010 (idempotent load) are gated
    // on the §49 step 5 weight-load impl. The "valid magic but bogus
    // metadata" test pins the no-silent-fallback contract: an APR that
    // passes the magic-byte check cannot be silently ignored or replaced by
    // random init (the legacy "not yet wired" guard has been retired).
1751
1752    fn parse_pretrain_init(extra: &[&str]) -> Option<std::path::PathBuf> {
1753        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1754        std::thread::Builder::new()
1755            .stack_size(16 * 1024 * 1024)
1756            .spawn(move || {
1757                use clap::Parser;
1758                let mut argv: Vec<String> = vec![
1759                    "apr".to_string(),
1760                    "pretrain".to_string(),
1761                    "--dataset".to_string(),
1762                    "/tmp/_init_flag/ds".to_string(),
1763                    "--tokenizer".to_string(),
1764                    "/tmp/_init_flag/tok".to_string(),
1765                    "--run-dir".to_string(),
1766                    "/tmp/_init_flag/run".to_string(),
1767                ];
1768                argv.extend(extra);
1769                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1770                match *cli.command {
1771                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1772                        init, ..
1773                    }) => init,
1774                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1775                }
1776            })
1777            .expect("spawn parse thread")
1778            .join()
1779            .expect("parse thread must not panic")
1780    }
1781
1782    /// FALSIFY-APR-PRETRAIN-INIT-001: --init flag exists in clap surface.
1783    #[test]
1784    fn pretrain_init_flag_absent_parses_to_none() {
1785        // Absent --init MUST parse to None. Falsifies a regression where a
1786        // default value silently injects a path the operator never typed.
1787        assert_eq!(
1788            parse_pretrain_init(&[]),
1789            None,
1790            "FALSIFY-APR-PRETRAIN-INIT-001/002: default --init must be None (no silent default)"
1791        );
1792    }
1793
1794    /// FALSIFY-APR-PRETRAIN-INIT-001: --init <PATH> parses to Some(PathBuf).
1795    #[test]
1796    fn pretrain_init_flag_parses_path() {
1797        let parsed = parse_pretrain_init(&["--init", "/tmp/foo.apr"]);
1798        assert_eq!(
1799            parsed.as_deref().and_then(|p| p.to_str()),
1800            Some("/tmp/foo.apr"),
1801            "FALSIFY-APR-PRETRAIN-INIT-001: --init <PATH> must round-trip through clap"
1802        );
1803    }
1804
1805    /// FALSIFY-APR-PRETRAIN-INIT-003: --init <missing-file> fails fast
1806    /// before any trainer allocation; stderr names the path.
1807    #[test]
1808    fn pretrain_init_missing_file_errors() {
1809        let tmp = TempDir::new().expect("tempdir");
1810        let missing = tmp.path().join("does-not-exist.apr");
1811        let err = run(
1812            tmp.path(),
1813            tmp.path(),
1814            tmp.path(),
1815            PretrainMode::Finetune,
1816            Some(5.0e-5),
1817            10,
1818            Some(2),
1819            2,
1820            4,
1821            5,
1822            42,
1823            Some(2.2),
1824            50257,
1825            true,
1826            "cpu",
1827            Some(&missing),
1828            false,
1829            None,
1830            true,
1831        )
1832        .expect_err("missing --init file must be rejected");
1833        match err {
1834            CliError::ValidationFailed(msg) => {
1835                assert!(
1836                    msg.contains("FALSIFY-APR-PRETRAIN-INIT-003"),
1837                    "msg must cite falsifier id: {msg}"
1838                );
1839                assert!(
1840                    msg.contains("does-not-exist.apr"),
1841                    "msg must name the missing path: {msg}"
1842                );
1843            }
1844            other => panic!("unexpected error: {other:?}"),
1845        }
1846    }
1847
1848    /// FALSIFY-APR-PRETRAIN-INIT-004: --init with wrong magic bytes fails fast.
1849    #[test]
1850    fn pretrain_init_bad_magic_errors() {
1851        let tmp = TempDir::new().expect("tempdir");
1852        let bad = tmp.path().join("not-an-apr.bin");
1853        std::fs::write(&bad, b"GGUF\x00\x00\x00\x00\x00\x00\x00\x00").expect("write fixture file");
1854        let err = run(
1855            tmp.path(),
1856            tmp.path(),
1857            tmp.path(),
1858            PretrainMode::Finetune,
1859            Some(5.0e-5),
1860            10,
1861            Some(2),
1862            2,
1863            4,
1864            5,
1865            42,
1866            Some(2.2),
1867            50257,
1868            true,
1869            "cpu",
1870            Some(&bad),
1871            false,
1872            None,
1873            true,
1874        )
1875        .expect_err("invalid magic bytes must be rejected");
1876        match err {
1877            CliError::ValidationFailed(msg) => {
1878                assert!(
1879                    msg.contains("FALSIFY-APR-PRETRAIN-INIT-004"),
1880                    "msg must cite falsifier id: {msg}"
1881                );
1882                assert!(
1883                    msg.contains("not a valid APR file"),
1884                    "msg must describe magic mismatch: {msg}"
1885                );
1886            }
1887            other => panic!("unexpected error: {other:?}"),
1888        }
1889    }
1890
1891    /// FALSIFY-APR-PRETRAIN-INIT-004: empty file (read_exact fails on 4 bytes).
1892    #[test]
1893    fn pretrain_init_empty_file_errors() {
1894        let tmp = TempDir::new().expect("tempdir");
1895        let empty = tmp.path().join("empty.apr");
1896        std::fs::write(&empty, b"").expect("write empty fixture");
1897        let err = run(
1898            tmp.path(),
1899            tmp.path(),
1900            tmp.path(),
1901            PretrainMode::Finetune,
1902            Some(5.0e-5),
1903            10,
1904            Some(2),
1905            2,
1906            4,
1907            5,
1908            42,
1909            Some(2.2),
1910            50257,
1911            true,
1912            "cpu",
1913            Some(&empty),
1914            false,
1915            None,
1916            true,
1917        )
1918        .expect_err("empty file must be rejected (cannot contain magic bytes)");
1919        assert!(matches!(err, CliError::ValidationFailed(_)));
1920    }
1921
1922    /// §50.4 step 5f.4: a magic-byte-valid but metadata-bogus APR file
1923    /// MUST be rejected at the architecture-extraction step, not silently
1924    /// fall back to random init. The error must clearly cite the
1925    /// architecture-extraction failure (not the legacy "not yet wired"
1926    /// guard, which was retired when the wireup landed). This drift-prevention
1927    /// pins the new fail-closed semantic.
1928    #[test]
1929    fn pretrain_init_valid_magic_but_bogus_metadata_fails_at_arch_extraction() {
1930        let tmp = TempDir::new().expect("tempdir");
1931        let valid = tmp.path().join("v2-valid-magic-bogus-metadata.apr");
1932        // APR\0 magic + padding; passes validate_init_apr_path but
1933        // read_apr_architecture (which reads the v2 header) will return None.
1934        std::fs::write(&valid, b"APR\x00\x00\x00\x00\x00\x00\x00\x00\x00")
1935            .expect("write fixture file");
1936        let err = run(
1937            tmp.path(),
1938            tmp.path(),
1939            tmp.path(),
1940            PretrainMode::Finetune,
1941            Some(5.0e-5),
1942            10,
1943            Some(2),
1944            2,
1945            4,
1946            5,
1947            42,
1948            Some(2.2),
1949            50257,
1950            true,
1951            "cpu",
1952            Some(&valid),
1953            false,
1954            None,
1955            true,
1956        )
1957        .expect_err("bogus metadata must NOT silently random-init");
1958        match err {
1959            CliError::ValidationFailed(msg) => {
1960                assert!(
1961                    !msg.contains("not yet wired"),
1962                    "the legacy step-5-partial guard must be retired: {msg}"
1963                );
1964                // The actual error from read_apr_architecture failure or
1965                // downstream layer; both are acceptable as long as we DON'T
1966                // silently load random init.
1967            }
1968            other => panic!("unexpected error: {other:?}"),
1969        }
1970    }
1971
1972    /// Pin v1 magic (APRN) acceptance — `validate_init_apr_path` alone
1973    /// (decoupled from architecture extraction) returns Ok for both APR\0
1974    /// and APRN magic bytes. Architecture extraction is a separate step.
1975    #[test]
1976    fn pretrain_init_v1_magic_aprn_passes_validate_init_apr_path() {
1977        let tmp = TempDir::new().expect("tempdir");
1978        let v1 = tmp.path().join("v1-aprn.apr");
1979        std::fs::write(&v1, b"APRN\x00\x00\x00\x00").expect("write fixture file");
1980        let result = validate_init_apr_path(&v1);
1981        assert!(
1982            result.is_ok(),
1983            "APRN magic must pass validate_init_apr_path; got {result:?}"
1984        );
1985    }
1986}