// apr_cli/commands/pretrain.rs

1//! `apr pretrain` — pretraining loop driver for SHIP-TWO-001 MODEL-2.
2//!
3//! Wires `entrenar::train::pretrain::PretrainLoop` into the CLI. The
4//! loop shape is enforced by `contracts/training-loop-pretrain-v1.yaml`
5//! — specifically GATE-TRAIN-005 (divergence), GATE-TRAIN-007 (NaN),
6//! and GATE-TRAIN-008 (throughput range).
7//!
8//! For MODEL-2 specifically, the 370M model forward pass is still a
9//! scaffold (see `crates/aprender-train/src/models/llama_370m.rs`),
10//! so this command runs in **synthetic** mode by default: it drives
11//! the loop with a deterministic decreasing-loss step function so the
12//! contract gates are exercised end-to-end even before the 370M
13//! compute path is wired.
14
15use crate::error::{CliError, Result};
16use crate::output;
17use clap::ValueEnum;
18use colored::Colorize;
19use entrenar::models::llama_370m::{
20    assert_tokenizer_vocab_matches_model, assert_tokenizer_vocab_within_model_bound,
21    Llama370MConfig,
22};
23use entrenar::train::device::{resolve_device, Device};
24use entrenar::train::pretrain::{
25    CheckpointFn, LinearDecaySynthetic, PretrainAbort, PretrainConfig, PretrainLoop, RunStatus,
26    ScriptedVal, StepFn, TrainingRegime, ValFn,
27};
28use entrenar::train::pretrain_real::{
29    build_shared_trainer, build_shared_trainer_with_init, AprCheckpointFn, RealStepFn, RealValFn,
30};
31use entrenar::train::shard_reader::ShardBatchIter;
32use entrenar::train::transformer_trainer::LMBatch;
33use entrenar::transformer::TransformerConfig;
34use std::path::Path;
35
/// Number of LMBatches pulled off the head of the shard stream and
/// reserved as the held-out validation set.
///
/// This is an upper bound: `drive_real` stops reserving as soon as the
/// shard iterator returns `None`, and only errors if it got zero batches.
///
/// 2026-04-26: bumped from 2 → 16 to reduce val_loss measurement
/// noise on from-scratch runs. With batch=16 seq=512, the prior
/// 2-batch held-out covered just 16,384 tokens — single-batch
/// fluctuation was ~0.04 in val_loss, which is at the same scale
/// as epoch-over-epoch improvement signal during early training.
/// A 50K-step run early-stopped at epoch 5/24 even though
/// train_loss was monotonically decreasing (10.01 → 9.54). With 16
/// held-out batches (131,072 tokens, 8× more), the standard error of
/// the mean val_loss shrinks by roughly √8 ≈ 2.8 — not 8× — assuming
/// the batches are roughly independent (standard error scales as
/// 1/√n). That takes the noise from ~0.04 to ~0.014, which restores
/// early-stop signal-to-noise.
const HELD_OUT_BATCHES: usize = 16;
49
/// Drift-prevention constant pinned by `apr-pretrain-arch-polymorphic-v1`
/// v1.7.0 §FALSIFY-APR-PRETRAIN-INIT-CUDA-001.
///
/// History:
/// - **v1.4.0..v1.6.0 (pre-§50.4-step-5f.5):** this was the fail-fast
///   error returned when `--init <PATH>` and `--device cuda` were combined
///   before the CUDA wireup existed. A unit test checked that the
///   citation, the "not yet wired" phrase and the 5f.5 reference all
///   appeared in it.
/// - **v1.7.0 (5f.5 shipped):** the CUDA wireup landed via
///   `entrenar::train::pretrain_real_cuda::build_shared_cuda_trainer_with_init`,
///   which mirrors the CPU `build_shared_trainer_with_init`.
///
/// The const is kept, but its payload now records that the CUDA + `--init`
/// path IS wired. If a future refactor re-introduces a fail-fast on that
/// path, the test pinning this string fails and surfaces the regression.
///
/// No code path in `drive_real` emits this string; it exists only to anchor
/// the contract obligation. It is `pub(crate)` so that pinning test can read
/// it. Do not edit the literal without updating that test.
pub(crate) const FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG: &str =
    "FALSIFY-APR-PRETRAIN-INIT-CUDA-001: --init is wired for --device cuda \
     via build_shared_cuda_trainer_with_init (5f.5 SHIPPED); operator can pass \
     --init <PATH> --device cuda for end-to-end GPU fine-tune dispatch.";
72
/// CLI selector bound to training-loop-pretrain-v1 §hyperparameter_defaults.
/// Atomically flips the `(regime, lr_max, warmup_steps, target_val_loss)`
/// 4-tuple per INV-TRAIN-009. Explicit `--lr` / `--warmup-steps` /
/// `--target-val-loss` still win over the table row.
//
// Maintenance note: clap's `ValueEnum` derive uses the variant `///` docs
// below as per-value `--help` text. The authoritative numbers live in
// `mode_defaults`, so update both together when a table row changes.
#[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
pub enum PretrainMode {
    /// Post-divergence MODEL-1 remedy defaults (lr=5e-5, warmup=100, target=2.2).
    Finetune,
    /// 370M cold-start defaults (lr=3e-4, warmup=1000, target=3.0).
    FromScratch,
}
84
/// Resolved HP tuple from the contract's `hyperparameter_defaults` table.
/// Inputs are CLI-provided overrides (`None` means "inherit mode default").
/// Output binds INV-TRAIN-009: regime ALWAYS matches `mode`, and any field
/// the operator set explicitly passes through unchanged.
///
/// Produced by `mode_defaults`. `run` copies it into `PretrainConfig`.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ResolvedHp {
    /// Training regime. Always the one paired with the selected
    /// `PretrainMode`; there is no CLI override for it.
    pub regime: TrainingRegime,
    /// Maximum learning rate (`--lr`, or the mode default).
    pub lr_max: f32,
    /// Warmup length in training steps (`--warmup-steps`, or the mode default).
    pub warmup_steps: usize,
    /// Validation-loss target (`--target-val-loss`, or the mode default).
    pub target_val_loss: f32,
}
96
97pub(crate) fn mode_defaults(
98    mode: PretrainMode,
99    vocab_size: u32,
100    lr_override: Option<f32>,
101    warmup_override: Option<usize>,
102    target_override: Option<f32>,
103) -> ResolvedHp {
104    let (regime, lr_def, warmup_def, target_def) = match mode {
105        PretrainMode::Finetune => (TrainingRegime::Finetune, 5.0e-5, 100, 2.2),
106        PretrainMode::FromScratch => (
107            TrainingRegime::FromScratch { vocab_size },
108            3.0e-4,
109            1000,
110            3.0,
111        ),
112    };
113    ResolvedHp {
114        regime,
115        lr_max: lr_override.unwrap_or(lr_def),
116        warmup_steps: warmup_override.unwrap_or(warmup_def),
117        target_val_loss: target_override.unwrap_or(target_def),
118    }
119}
120
121/// Execute `apr pretrain`.
122#[allow(clippy::too_many_arguments)]
123pub(crate) fn run(
124    dataset: &Path,
125    tokenizer: &Path,
126    run_dir: &Path,
127    mode: PretrainMode,
128    lr: Option<f32>,
129    num_steps: usize,
130    warmup_steps: Option<usize>,
131    batch_size: usize,
132    seq_length: usize,
133    steps_per_epoch: usize,
134    seed: u64,
135    target_val_loss: Option<f32>,
136    vocab_size: u32,
137    synthetic: bool,
138    device: &str,
139    init: Option<&Path>,
140    json_output: bool,
141) -> Result<()> {
142    // Contract gpu-training-backend-v1 INV-GPUTRAIN-001 / GATE-GPUTRAIN-002:
143    // parse --device BEFORE any trainer allocation so an invalid spec
144    // or an explicit `cuda` on a CPU-only host fails fast with a clear
145    // diagnostic. Synthetic drive still honours --device (for parity
146    // with real compute) but the stub error surface is identical.
147    let resolved_device =
148        resolve_device(device).map_err(|e| CliError::ValidationFailed(e.to_string()))?;
149
150    // Contract apr-pretrain-from-init-v1 §init_load_semantics + §50.4 step 5f.4:
151    // when --init is present, (1) validate magic bytes, (2) extract
152    // TransformerConfig from the APR header metadata, (3) propagate the
153    // extracted arch through preflight + trainer construction.
154    // Per `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature,
155    // missing or unreadable architecture metadata is FAIL-FAST not silent-fallback.
156    let init_arch: Option<TransformerConfig> = if let Some(init_path) = init {
157        validate_init_apr_path(init_path)?;
158        Some(
159            crate::commands::model_config::read_apr_architecture(init_path).ok_or_else(|| {
160                CliError::ValidationFailed(format!(
161                    "FALSIFY-APR-PRETRAIN-INIT-005: --init APR file at {} has missing or invalid \
162                     architecture metadata (hidden_size, num_heads, num_layers, vocab_size, etc). \
163                     Cannot extract TransformerConfig per apr-pretrain-arch-polymorphic-v1 \
164                     §arch_extraction_signature.",
165                    init_path.display()
166                ))
167            })?,
168        )
169    } else {
170        None
171    };
172
173    let hp = mode_defaults(mode, vocab_size, lr, warmup_steps, target_val_loss);
174
175    // Validation: GATE-TRAIN-003 requires target_val_loss > 0.
176    if hp.target_val_loss <= 0.0 {
177        return Err(CliError::ValidationFailed(format!(
178            "target_val_loss must be positive, got {}",
179            hp.target_val_loss
180        )));
181    }
182    if num_steps == 0 {
183        return Err(CliError::ValidationFailed(
184            "num_steps must be > 0".to_string(),
185        ));
186    }
187    if steps_per_epoch == 0 {
188        return Err(CliError::ValidationFailed(
189            "steps_per_epoch must be > 0".to_string(),
190        ));
191    }
192
193    let config = PretrainConfig {
194        dataset_path: dataset.to_path_buf(),
195        tokenizer_dir: tokenizer.to_path_buf(),
196        run_dir: run_dir.to_path_buf(),
197        lr_max: hp.lr_max,
198        lr_min: (hp.lr_max * 1.0e-2).max(1.0e-7),
199        warmup_steps: hp.warmup_steps,
200        total_steps: num_steps,
201        batch_size,
202        seq_length,
203        steps_per_epoch,
204        seed,
205        grad_clip: 1.0,
206        weight_decay: 0.01,
207        target_val_loss: hp.target_val_loss,
208        // Patience widened from 2 → 5 epochs for from-scratch runs (2026-04-26).
209        // Rationale: a 50K-step run early-stopped at epoch 5/24 even though
210        // train_loss was monotonically decreasing 10.01 → 9.54 (Δ=−0.47);
211        // val_loss noise on 16k-token val set (now 131k) had stdev ~0.04,
212        // same scale as epoch-over-epoch improvement signal during early
213        // training. 5 patience epochs gives the optimizer time to push past
214        // local plateaus without ending an obviously-still-converging run.
215        patience_epochs: 5,
216        // Minimum epochs before early-stop. Bumped 1 → 3 so the warmup
217        // window (1000 steps = 1 epoch at 1000 steps_per_epoch, or 0.5
218        // epoch at 2000 steps_per_epoch) plus 1-2 initial epochs of post-
219        // warmup learning are guaranteed to complete before any early-stop
220        // signal is honoured.
221        min_epochs_before_early_stop: 3,
222        regime: hp.regime,
223    };
224
225    if !json_output {
226        print_header(&config);
227        // GATE-GPUTRAIN-002 visibility: print the resolved Device so the
228        // operator can confirm which backend was selected. `auto` is the
229        // only spec that may silently fall back, and this print makes
230        // the fall-back visible at startup.
231        output::kv("  Device", resolved_device.to_string());
232        println!();
233    }
234
235    let status = if synthetic {
236        drive_synthetic(
237            config.clone(),
238            num_steps,
239            steps_per_epoch,
240            hp.target_val_loss,
241            json_output,
242        )?
243    } else {
244        drive_real(
245            config.clone(),
246            dataset,
247            hp.lr_max,
248            seq_length,
249            batch_size,
250            seed,
251            resolved_device,
252            json_output,
253            init_arch.as_ref(),
254            init,
255        )?
256    };
257
258    // Contract: non-OK terminal statuses map to non-zero exit codes so
259    // operators can recognize divergence / NaN from shell `$?`.
260    match status {
261        RunStatus::Aborted(abort) => Err(abort_to_err(&abort)),
262        RunStatus::Ok { .. } | RunStatus::EarlyStop { .. } => Ok(()),
263    }
264}
265
266/// Synthetic drive: deterministic linear-decay `StepFn` and a scripted
267/// val-loss sequence so the full gate surface (GATE-TRAIN-005/007/008)
268/// is exercised end-to-end with no corpus I/O.
269fn drive_synthetic(
270    config: PretrainConfig,
271    num_steps: usize,
272    steps_per_epoch: usize,
273    target_val_loss: f32,
274    json_output: bool,
275) -> Result<RunStatus> {
276    let step_fn = LinearDecaySynthetic {
277        start_loss: (target_val_loss * 2.0).max(1.5),
278        decay_per_step: (target_val_loss * 0.01).max(1.0e-4),
279        grad_norm: 0.8,
280    };
281    let num_epochs = num_steps.div_ceil(steps_per_epoch);
282    let mut sequence = Vec::with_capacity(num_epochs + 2);
283    let start_val = (target_val_loss * 1.8).max(3.0);
284    for i in 0..(num_epochs + 2) {
285        let t = i as f32 / (num_epochs.max(1) as f32);
286        sequence.push(target_val_loss + (start_val - target_val_loss) * (1.0 - t).max(0.0));
287    }
288    let val_fn = ScriptedVal { sequence };
289    // Synthetic drive has no real weights to checkpoint.
290    run_and_report(config, step_fn, val_fn, None, json_output)
291}
292
293/// Contract apr-pretrain-from-init-v1 §init_load_semantics + §init_error_semantics:
294/// validate `--init <PATH>` BEFORE any trainer allocation. Falsifies
295/// FALSIFY-APR-PRETRAIN-INIT-003 (missing-file) + -004 (invalid-magic).
296///
297/// Returns Ok on a valid APR file (existence + magic bytes verified).
298/// Architecture extraction + weight load are §50.4 step 5f.4 — the
299/// caller (`run()`) extracts the config via `model_config::read_apr_architecture`
300/// and passes both to `build_shared_trainer_with_init` per
301/// `apr-pretrain-arch-polymorphic-v1` §init_load_semantics.
302fn validate_init_apr_path(path: &Path) -> Result<()> {
303    let mut file = std::fs::File::open(path).map_err(|e| {
304        CliError::ValidationFailed(format!(
305            "FALSIFY-APR-PRETRAIN-INIT-003: --init path does not exist or is unreadable: {} ({e})",
306            path.display()
307        ))
308    })?;
309    let mut magic = [0u8; 4];
310    use std::io::Read;
311    file.read_exact(&mut magic).map_err(|e| {
312        CliError::ValidationFailed(format!(
313            "FALSIFY-APR-PRETRAIN-INIT-004: --init file too short to contain APR magic bytes: {} ({e})",
314            path.display()
315        ))
316    })?;
317    // APR magic bytes per `crates/aprender-core/src/format/kani_proofs.rs`:
318    //   APR\0 = [0x41, 0x50, 0x52, 0x00] (v2)
319    //   APRN  = [0x41, 0x50, 0x52, 0x4E] (v1)
320    const APR_MAGIC_V2: [u8; 4] = [0x41, 0x50, 0x52, 0x00];
321    const APR_MAGIC_V1: [u8; 4] = [0x41, 0x50, 0x52, 0x4E];
322    if magic != APR_MAGIC_V2 && magic != APR_MAGIC_V1 {
323        return Err(CliError::ValidationFailed(format!(
324            "FALSIFY-APR-PRETRAIN-INIT-004: --init file is not a valid APR file (magic={:02X?}, expected {:02X?} or {:02X?}): {}",
325            magic, APR_MAGIC_V2, APR_MAGIC_V1, path.display()
326        )));
327    }
328    Ok(())
329}
330
331/// GATE-ARCH-370M-011 pre-flight: count the tokenizer's vocabulary entries
332/// from `vocab.json` and assert the count matches `target_vocab_size`
333/// before any trainer allocation.
334///
335/// Per `apr-pretrain-arch-polymorphic-v1` §qwen_tokenizer_vocab_compatibility
336/// (PR #1473), the target is now POLYMORPHIC — when `--init <PATH>` is set,
337/// the caller passes the extracted-arch's vocab_size (e.g., 151_936 for
338/// Qwen2.5-0.5B); otherwise `Llama370MConfig::VOCAB_SIZE` (50_257) for
339/// the §24/§25 from-scratch baseline.
340///
341/// Any mismatch aborts the dispatch with a clear error naming both values
342/// and the violated invariant — the N-09 OOB escape in `Embedding::forward`
343/// would otherwise silently corrupt training.
344///
345/// Discharges FALSIFY-APR-PRETRAIN-ARCH-005 (Qwen tokenizer passes with
346/// Qwen target) and FALSIFY-APR-PRETRAIN-ARCH-006 (Qwen tokenizer fails
347/// with Llama target).
348fn preflight_tokenizer_vocab_matches_target(
349    tokenizer_dir: &Path,
350    target_vocab_size: usize,
351    init_is_some: bool,
352) -> Result<()> {
353    let vocab_path = tokenizer_dir.join("vocab.json");
354    let vocab_json = std::fs::read_to_string(&vocab_path).map_err(|e| {
355        CliError::ValidationFailed(format!(
356            "GATE-ARCH-370M-011 pre-flight: cannot read {} ({e})",
357            vocab_path.display()
358        ))
359    })?;
360    let vocab: serde_json::Map<String, serde_json::Value> = serde_json::from_str(&vocab_json)
361        .map_err(|e| {
362            CliError::ValidationFailed(format!(
363                "GATE-ARCH-370M-011 pre-flight: {} is not a valid vocab.json: {e}",
364                vocab_path.display()
365            ))
366        })?;
367    // §55: when --init is set (polymorphic path with HF-distributed
368    // checkpoint), allow tokenizer_vocab ≤ model_vocab to admit Qwen-style
369    // reserved-slot vocabularies. When --init is absent (§24/§25 from-scratch
370    // baseline), enforce strict equality to preserve INV-ARCH-370M-006.
371    if init_is_some {
372        assert_tokenizer_vocab_within_model_bound(vocab.len(), target_vocab_size)
373            .map_err(CliError::ValidationFailed)
374    } else {
375        assert_tokenizer_vocab_matches_model(vocab.len(), target_vocab_size)
376            .map_err(CliError::ValidationFailed)
377    }
378}
379
380/// Real-corpus drive: build a shared 370M trainer (CPU or CUDA), split
381/// the shard stream head-off into a held-out validation set, and run a
382/// full forward + backward + AdamW step per training batch.
383///
384/// When `device.is_cuda()`, the `cuda` feature must be compiled in —
385/// otherwise this surfaces a clear error rather than silently falling
386/// back to CPU (GATE-GPUTRAIN-002, contract gpu-training-backend-v1).
387#[allow(clippy::too_many_arguments)]
388fn drive_real(
389    config: PretrainConfig,
390    dataset: &Path,
391    lr: f32,
392    seq_length: usize,
393    batch_size: usize,
394    seed: u64,
395    device: Device,
396    json_output: bool,
397    init_arch: Option<&TransformerConfig>,
398    init_path: Option<&Path>,
399) -> Result<RunStatus> {
400    // GATE-ARCH-370M-011 / INV-ARCH-370M-006 — refuse to dispatch a real
401    // training step when the tokenizer vocab_size and the model vocab_size
402    // disagree. The N-09 OOB escape guard in Embedding::forward masks the
403    // mismatch at runtime → silent garbage gradients otherwise. Synthetic
404    // drive skips this check because it never touches the real model.
405    // Per `apr-pretrain-arch-polymorphic-v1` §qwen_tokenizer_vocab_compatibility
406    // (§50.4 step 5d/5f.4): when --init is set, gate by the EXTRACTED arch's
407    // vocab_size; otherwise gate by the §24/§25 baseline Llama370MConfig::VOCAB_SIZE,
408    // preserving regression-free behavior (FALSIFY-002 + FALSIFY-005 + FALSIFY-006).
409    let target_vocab = init_arch
410        .map(|cfg| cfg.vocab_size)
411        .unwrap_or(Llama370MConfig::VOCAB_SIZE);
412    preflight_tokenizer_vocab_matches_target(
413        &config.tokenizer_dir,
414        target_vocab,
415        init_arch.is_some(),
416    )?;
417
418    // MVP: pad_id/eos_id both 0. All sequences are uniform length
419    // (seq_length + 1) so LMBatch::from_sequences takes the shared
420    // layout path and pad_id is never used for padding. The real
421    // tokenizer's special-token ids will plumb through in a follow-up.
422    //
423    // wrap_around=true: when the corpus shards are exhausted before
424    // --num-steps is reached, reset cursor to shard 0 and continue.
425    // This is standard ML-training behaviour (matches PyTorch /
426    // HuggingFace). Without it, an 18M-token corpus exhausts in ~2
427    // epochs of a 5K-step run with batch=16 seq=512, and the
428    // Cuda*StepFn falls back to placeholder loss `(1.0, 1.0)` — silently
429    // producing garbage gradients. See spec §22 (PR #1073) for the
430    // root-cause investigation.
431    let mut iter = ShardBatchIter::new(dataset, batch_size, seq_length, 0, 0)
432        .map_err(|e| {
433            CliError::ValidationFailed(format!(
434                "dataset shard iterator init failed: {e} (path={})",
435                dataset.display()
436            ))
437        })?
438        .with_wrap_around(true);
439
440    // Reserve the first `HELD_OUT_BATCHES` batches as the held-out val
441    // set; the remainder feeds RealStepFn.
442    let mut held_out: Vec<LMBatch> = Vec::with_capacity(HELD_OUT_BATCHES);
443    for _ in 0..HELD_OUT_BATCHES {
444        match iter.next() {
445            Some(b) => held_out.push(b),
446            None => break,
447        }
448    }
449    if held_out.is_empty() {
450        return Err(CliError::ValidationFailed(format!(
451            "dataset {} is too small to reserve any held-out batches",
452            dataset.display()
453        )));
454    }
455
456    if device.is_cuda() {
457        // §50.4 step 5f.5 SHIPPED (this PR): CUDA path with --init is now
458        // wired symmetric to the CPU path via
459        // `entrenar::train::pretrain_real_cuda::build_shared_cuda_trainer_with_init`.
460        // The same §50.4 step-5f machinery composes through both backends:
461        //   5c: build_transformer_config(init_arch)
462        //   5f.1: validate_pretrain_init_arch_compatible(init_arch) — encoder rejection
463        //   5f.2: load_init_tensors_from_apr(path) — read APR weights
464        //   5f.3: populate_trainer_from_init_tensors(transformer, &tensors) — populate CPU model
465        //   5f.5 (this PR): CudaTransformerTrainer::with_model uploads populated
466        //                   blocks / norm / lm_head to GPU.
467        //
468        // Per `apr-pretrain-arch-polymorphic-v1` v1.7.0 §FALSIFY-APR-PRETRAIN-INIT-CUDA-001,
469        // the const FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG is repurposed as a
470        // drift-prevention sentinel — if a future refactor re-introduces a
471        // fail-fast on the CUDA + --init path, the test that pins the const
472        // will fail and surface the regression.
473        drive_real_cuda(
474            config,
475            iter,
476            held_out,
477            lr,
478            seq_length,
479            seed,
480            json_output,
481            init_arch,
482            init_path,
483        )
484    } else {
485        drive_real_cpu(
486            config,
487            iter,
488            held_out,
489            lr,
490            seq_length,
491            seed,
492            json_output,
493            init_arch,
494            init_path,
495        )
496    }
497}
498
499/// CPU backend for `drive_real` — builds a `TransformerTrainer`
500/// (`aprender::Tensor` + trueno SIMD) and wires `RealStepFn` /
501/// `RealValFn` / `AprCheckpointFn`.
502#[allow(clippy::too_many_arguments)]
503fn drive_real_cpu(
504    config: PretrainConfig,
505    iter: entrenar::train::shard_reader::ShardBatchIter,
506    held_out: Vec<LMBatch>,
507    lr: f32,
508    seq_length: usize,
509    seed: u64,
510    json_output: bool,
511    init_arch: Option<&TransformerConfig>,
512    init_path: Option<&Path>,
513) -> Result<RunStatus> {
514    // §50.4 step 5f.4: when --init is set, build the trainer via the
515    // polymorphic builder (extracts arch + loads + populates init tensors).
516    // When --init is absent, use the existing from-scratch baseline builder
517    // so the §24/§25 evidence remains regression-free.
518    let trainer = if init_arch.is_some() || init_path.is_some() {
519        build_shared_trainer_with_init(lr, seq_length, seed, init_arch, init_path)
520            .map_err(CliError::ValidationFailed)?
521    } else {
522        build_shared_trainer(lr, seq_length, seed)
523    };
524    let step_fn = RealStepFn::new(trainer.clone(), Box::new(iter));
525    let val_fn = RealValFn::new(trainer.clone(), held_out);
526    let ckpt: Box<dyn CheckpointFn> = Box::new(AprCheckpointFn::new(
527        trainer,
528        "llama-370m-pretrain",
529        "LlamaForCausalLM",
530    ));
531    run_and_report(config, step_fn, val_fn, Some(ckpt), json_output)
532}
533
/// CUDA backend for `drive_real`. It builds a shared `CudaTransformerTrainer`
/// and wires `CudaRealStepFn` / `CudaRealValFn` / `CudaAprCheckpointFn`
/// (task #132 Phase 2, contract gpu-training-backend-v1).
///
/// This version only exists with the `cuda` feature. Without it, the
/// `not(feature = "cuda")` twin returns a runtime "rebuild with
/// --features cuda" error, so `--device cuda` never silently runs on CPU
/// (GATE-GPUTRAIN-002 / FM-GPUTRAIN-SILENT-CPU).
///
/// §50.4 step 5f.5: with `--init`, the polymorphic builder applies the
/// arch, loads and populates init tensors, then uploads them to the GPU.
/// Without it, the from-scratch baseline builder is used, which keeps the
/// §24/§25 evidence regression-free and INV-ARCH-370M-001 enforced.
///
/// # Errors
/// GATE-GPUTRAIN-002 `CliError::ValidationFailed` when trainer allocation
/// fails, plus anything `run_and_report` returns.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)] // same signature as drive_real_cpu
fn drive_real_cuda(
    config: PretrainConfig,
    iter: entrenar::train::shard_reader::ShardBatchIter,
    held_out: Vec<LMBatch>,
    lr: f32,
    seq_length: usize,
    seed: u64,
    json_output: bool,
    init_arch: Option<&TransformerConfig>,
    init_path: Option<&Path>,
) -> Result<RunStatus> {
    use entrenar::train::pretrain_real_cuda::{
        build_shared_cuda_trainer, build_shared_cuda_trainer_with_init, CudaAprCheckpointFn,
        CudaRealStepFn, CudaRealValFn,
    };

    let from_init = init_arch.is_some() || init_path.is_some();
    let trainer = if from_init {
        build_shared_cuda_trainer_with_init(lr, seq_length, seed, init_arch, init_path).map_err(
            |e| {
                CliError::ValidationFailed(format!(
                    "GATE-GPUTRAIN-002: CUDA trainer allocation (--init path) failed: {e}. \
                     See contracts/entrenar/gpu-training-backend-v1.yaml and \
                     contracts/apr-pretrain-arch-polymorphic-v1.yaml v1.7.0 \
                     §FALSIFY-APR-PRETRAIN-INIT-CUDA-001 — this path is only \
                     reachable when the binary was built with `--features cuda`.",
                ))
            },
        )?
    } else {
        build_shared_cuda_trainer(lr, seq_length, seed).map_err(|e| {
            CliError::ValidationFailed(format!(
                "GATE-GPUTRAIN-002: CUDA trainer allocation failed: {e}. \
                 See contracts/entrenar/gpu-training-backend-v1.yaml and \
                 memory/feedback_cuda_feature_footgun.md — this path is \
                 only reachable when the binary was built with `--features cuda`.",
            ))
        })?
    };

    let step = CudaRealStepFn::new(trainer.clone(), Box::new(iter));
    let val = CudaRealValFn::new(trainer.clone(), held_out);
    let checkpoint: Box<dyn CheckpointFn> = Box::new(CudaAprCheckpointFn::new(
        trainer,
        "llama-370m-pretrain",
        "LlamaForCausalLM",
    ));
    run_and_report(config, step, val, Some(checkpoint), json_output)
}
594
595/// CUDA backend stub when the `cuda` feature is NOT compiled in.
596///
597/// This is the load-bearing gate that prevents FM-GPUTRAIN-SILENT-CPU:
598/// if a user passes `--device cuda` on an apr binary built without
599/// CUDA support, they see a clear "rebuild with --features cuda" error
600/// rather than a 14-minute CPU run masquerading as GPU training
601/// (task #132 lambda-labs incident, 2026-04-21).
602#[cfg(not(feature = "cuda"))]
603#[allow(clippy::too_many_arguments)]
604fn drive_real_cuda(
605    _config: PretrainConfig,
606    _iter: entrenar::train::shard_reader::ShardBatchIter,
607    _held_out: Vec<LMBatch>,
608    _lr: f32,
609    _seq_length: usize,
610    _seed: u64,
611    _json_output: bool,
612    _init_arch: Option<&TransformerConfig>,
613    _init_path: Option<&Path>,
614) -> Result<RunStatus> {
615    Err(CliError::ValidationFailed(
616        "GATE-GPUTRAIN-002: --device cuda was requested but this `apr` \
617         binary was built WITHOUT the `cuda` feature. \
618         Rebuild with `cargo build --release --features cuda` or use \
619         `--device cpu`. See memory/feedback_cuda_feature_footgun.md \
620         (contract gpu-training-backend-v1 / task #132 Phase 2)."
621            .into(),
622    ))
623}
624
625/// Shared helper: construct the `PretrainLoop`, run it, print the
626/// terminal report, and bubble the `RunStatus` back for exit-code
627/// mapping. `checkpoint_fn` — when `Some` — writes an APR file per
628/// epoch that passes GATE-TRAIN-005.
629fn run_and_report<S: StepFn, V: ValFn>(
630    config: PretrainConfig,
631    step_fn: S,
632    val_fn: V,
633    checkpoint_fn: Option<Box<dyn CheckpointFn>>,
634    json_output: bool,
635) -> Result<RunStatus> {
636    let mut loop_ = PretrainLoop::new(config, step_fn, val_fn);
637    if let Some(ckpt) = checkpoint_fn {
638        loop_ = loop_.with_checkpoint_fn(ckpt);
639    }
640    let status = loop_.run();
641    report(&status, &loop_, json_output)?;
642    Ok(status)
643}
644
645fn abort_to_err(abort: &PretrainAbort) -> CliError {
646    match abort {
647        PretrainAbort::Divergence { .. } | PretrainAbort::DivergenceAtEpochZero { .. } => {
648            CliError::ValidationFailed(format!(
649                "GATE-TRAIN-005 ship-blocker fired: {abort}. See \
650                 contracts/training-loop-pretrain-v1.yaml and \
651                 memory/project_ship_two_001_model1_qlora_divergence.md"
652            ))
653        }
654        PretrainAbort::NumericalInstability { .. } => {
655            CliError::ValidationFailed(format!("GATE-TRAIN-007 NaN/Inf guard fired: {abort}"))
656        }
657        PretrainAbort::ThroughputOutOfRange { .. } => CliError::ValidationFailed(format!(
658            "GATE-TRAIN-008 throughput-range guard fired: {abort}"
659        )),
660    }
661}
662
663fn print_header(cfg: &PretrainConfig) {
664    output::header("apr pretrain — SHIP-TWO-001 MODEL-2 training loop");
665    println!();
666    output::section("Configuration");
667    output::kv("  Dataset", cfg.dataset_path.display().to_string());
668    output::kv("  Tokenizer", cfg.tokenizer_dir.display().to_string());
669    output::kv("  Run dir", cfg.run_dir.display().to_string());
670    output::kv("  LR max", format!("{:.2e}", cfg.lr_max));
671    output::kv("  Total steps", cfg.total_steps.to_string());
672    output::kv("  Warmup steps", cfg.warmup_steps.to_string());
673    output::kv(
674        "  Batch × seq",
675        format!("{} × {}", cfg.batch_size, cfg.seq_length),
676    );
677    output::kv("  Steps / epoch", cfg.steps_per_epoch.to_string());
678    output::kv("  Seed", cfg.seed.to_string());
679    output::kv("  Target val_loss", format!("{:.2}", cfg.target_val_loss));
680    println!();
681}
682
683fn report<S: entrenar::train::pretrain::StepFn, V: entrenar::train::pretrain::ValFn>(
684    status: &RunStatus,
685    loop_: &PretrainLoop<S, V>,
686    json_output: bool,
687) -> Result<()> {
688    if json_output {
689        let report = PretrainReport::from(status, loop_);
690        let json = serde_json::to_string_pretty(&report)
691            .map_err(|e| CliError::InvalidFormat(e.to_string()))?;
692        println!("{json}");
693        return Ok(());
694    }
695
696    output::section("Run Result");
697    match status {
698        RunStatus::Ok {
699            final_val_loss,
700            epochs_completed,
701        } => {
702            println!(
703                "  {} CONVERGED  final val_loss={:.4} after {} epoch(s)",
704                "OK".green().bold(),
705                final_val_loss,
706                epochs_completed
707            );
708        }
709        RunStatus::EarlyStop {
710            best_val_loss,
711            epochs_completed,
712        } => {
713            println!(
714                "  {} EARLY_STOP  best val_loss={:.4} after {} epoch(s)",
715                "OK".yellow().bold(),
716                best_val_loss,
717                epochs_completed
718            );
719        }
720        RunStatus::Aborted(abort) => {
721            println!("  {} ABORTED  {}", "FAIL".red().bold(), abort);
722        }
723    }
724    output::kv("  Steps recorded", loop_.step_metrics().len().to_string());
725    output::kv(
726        "  Epochs recorded",
727        loop_.epoch_artifacts().len().to_string(),
728    );
729    println!();
730    Ok(())
731}
732
/// Machine-readable summary of a pretraining run, printed by `report` as
/// pretty-printed JSON when `json_output` is set.
///
/// Built by `PretrainReport::from`. Serde's derive serializes fields in
/// declaration order, so reordering them changes the JSON key order seen by
/// downstream consumers. `Option` fields that are `None` appear as JSON
/// `null` because no `skip_serializing_if` is applied.
#[derive(serde::Serialize)]
struct PretrainReport {
    /// Run outcome label: `"OK"`, `"EARLY_STOP"`, or `"ABORTED"`.
    status: String,
    /// `Display` text of the abort reason for `"ABORTED"` runs; `None` otherwise.
    detail: Option<String>,
    /// Final val_loss for `"OK"` runs and the *best* val_loss for
    /// `"EARLY_STOP"` runs; `None` for `"ABORTED"` runs.
    final_val_loss: Option<f32>,
    /// Epoch count carried by the run status. For aborted runs, this is the
    /// number of epoch artifacts the loop had recorded.
    epochs_completed: usize,
    /// Number of per-step metric entries recorded by the loop.
    steps_recorded: usize,
    /// Copy of the loop's validation-loss history.
    val_loss_history: Vec<f32>,
    /// Per-step `StepMetrics` captured by `PretrainLoop` (GATE-TRAIN-001
    /// contract `training-loop-pretrain-v1.yaml::per_step_metrics.required`).
    ///
    /// Emitted so downstream consumers can discharge FALSIFY-GPUTRAIN-005
    /// (step-time < 500 ms on RTX 4090 for 370M) and FALSIFY-GPUTRAIN-006
    /// (same-seed reproducibility — two cuda:0 runs at seed=0 must match
    /// on every step's train_loss within `AC_GPUTRAIN_006_MAX_SEED_LOSS_DELTA`
    /// = 1e-5) directly from the `--json` output, rather than having to
    /// parse run-dir checkpoint metadata.
    per_step_metrics: Vec<entrenar::train::pretrain::StepMetrics>,
}
752
753impl PretrainReport {
754    fn from<S: entrenar::train::pretrain::StepFn, V: entrenar::train::pretrain::ValFn>(
755        status: &RunStatus,
756        loop_: &PretrainLoop<S, V>,
757    ) -> Self {
758        let (status_name, detail, final_val_loss, epochs_completed) = match status {
759            RunStatus::Ok {
760                final_val_loss,
761                epochs_completed,
762            } => (
763                "OK".to_string(),
764                None,
765                Some(*final_val_loss),
766                *epochs_completed,
767            ),
768            RunStatus::EarlyStop {
769                best_val_loss,
770                epochs_completed,
771            } => (
772                "EARLY_STOP".to_string(),
773                None,
774                Some(*best_val_loss),
775                *epochs_completed,
776            ),
777            RunStatus::Aborted(abort) => (
778                "ABORTED".to_string(),
779                Some(abort.to_string()),
780                None,
781                loop_.epoch_artifacts().len(),
782            ),
783        };
784        PretrainReport {
785            status: status_name,
786            detail,
787            final_val_loss,
788            epochs_completed,
789            steps_recorded: loop_.step_metrics().len(),
790            val_loss_history: loop_.val_loss_history().to_vec(),
791            per_step_metrics: loop_.step_metrics().to_vec(),
792        }
793    }
794}
795
796#[cfg(test)]
797mod tests {
798    use super::*;
799    use tempfile::TempDir;
800
801    /// Stage a `vocab.json` with exactly `n` distinct integer-string tokens at
802    /// `<dir>/vocab.json`. Used by pre-flight gate tests + by other tests that
803    /// need to get PAST the GATE-ARCH-370M-011 pre-flight to exercise a later
804    /// failure mode (e.g. empty dataset shards).
805    fn stage_vocab_json(dir: &std::path::Path, n: usize) {
806        std::fs::create_dir_all(dir).expect("mkdir tokenizer dir");
807        let mut obj = serde_json::Map::with_capacity(n);
808        for i in 0..n {
809            obj.insert(format!("t{i}"), serde_json::Value::from(i as u64));
810        }
811        let json = serde_json::to_string(&obj).expect("serialize");
812        std::fs::write(dir.join("vocab.json"), json).expect("write vocab.json");
813    }
814
815    #[test]
816    fn preflight_accepts_matching_vocab() {
817        // GATE-ARCH-370M-011 acceptance case: tokenizer vocab.json with
818        // exactly Llama370MConfig::VOCAB_SIZE entries must pass pre-flight.
819        let tmp = TempDir::new().expect("tempdir");
820        stage_vocab_json(tmp.path(), Llama370MConfig::VOCAB_SIZE);
821        preflight_tokenizer_vocab_matches_target(tmp.path(), Llama370MConfig::VOCAB_SIZE, false)
822            .expect("matching vocab must pass GATE-ARCH-370M-011");
823    }
824
825    #[test]
826    fn preflight_rejects_tokenizer_vocab_mismatch() {
827        // FALSIFY-ARCH-370M-011: a tokenizer whose vocab size drifts from
828        // the model's pinned VOCAB_SIZE MUST abort dispatch with an error
829        // message that names both values and the gate id, so the operator
830        // can see the mismatch without stepping through code. Task #131
831        // bumped VOCAB_SIZE to 50_257 (Option A) — the counter-example
832        // below now exercises a tokenizer one token short of contract.
833        let tmp = TempDir::new().expect("tempdir");
834        let mismatch = Llama370MConfig::VOCAB_SIZE - 1;
835        stage_vocab_json(tmp.path(), mismatch);
836        let err = preflight_tokenizer_vocab_matches_target(
837            tmp.path(),
838            Llama370MConfig::VOCAB_SIZE,
839            false,
840        )
841        .expect_err("tokenizer/model vocab mismatch must be rejected");
842        match err {
843            CliError::ValidationFailed(msg) => {
844                assert!(
845                    msg.contains("GATE-ARCH-370M-011"),
846                    "msg must cite gate: {msg}"
847                );
848                assert!(
849                    msg.contains(&mismatch.to_string()),
850                    "msg must name tokenizer vocab: {msg}"
851                );
852                assert!(
853                    msg.contains(&Llama370MConfig::VOCAB_SIZE.to_string()),
854                    "msg must name model vocab: {msg}"
855                );
856            }
857            other => panic!("unexpected error: {other:?}"),
858        }
859    }
860
861    #[test]
862    fn preflight_rejects_missing_vocab_json() {
863        // Missing vocab.json is a pre-flight failure (not a later shard
864        // error) — the operator should know the tokenizer layout is
865        // wrong, not that the dataset is empty.
866        let tmp = TempDir::new().expect("tempdir");
867        let err = preflight_tokenizer_vocab_matches_target(
868            tmp.path(),
869            Llama370MConfig::VOCAB_SIZE,
870            false,
871        )
872        .expect_err("missing vocab.json must be rejected");
873        match err {
874            CliError::ValidationFailed(msg) => {
875                assert!(
876                    msg.contains("GATE-ARCH-370M-011"),
877                    "msg must cite gate: {msg}"
878                );
879                assert!(
880                    msg.contains("cannot read"),
881                    "msg must name I/O failure: {msg}"
882                );
883            }
884            other => panic!("unexpected error: {other:?}"),
885        }
886    }
887
888    /// FALSIFY-APR-PRETRAIN-ARCH-005 — a Qwen tokenizer (vocab=151_936) MUST
889    /// pass preflight when the target_vocab_size is the Qwen extracted-arch
890    /// (151_936). Falsifies a regression where preflight would still gate
891    /// against the hardcoded Llama370M vocab.
892    ///
893    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5d.
894    #[test]
895    fn preflight_qwen_vocab_passes_with_qwen_target() {
896        const QWEN2_VOCAB_SIZE: usize = 151_936;
897        let tmp = TempDir::new().expect("tempdir");
898        stage_vocab_json(tmp.path(), QWEN2_VOCAB_SIZE);
899        // §50.4 step 5d called this with init=Some semantic (the polymorphic path). Use
900        // init_is_some=true here per §55 relaxed-bound semantics; vocab.len() == target
901        // is still acceptable under <=.
902        preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN2_VOCAB_SIZE, true).expect(
903            "Qwen tokenizer (151_936) MUST pass preflight when target is Qwen-shaped — \
904             this is the load-bearing claim of §49 fine-tune from a Qwen2.5 init checkpoint",
905        );
906    }
907
908    /// FALSIFY-APR-PRETRAIN-ARCH-006 — a Qwen tokenizer (vocab=151_936) MUST
909    /// FAIL preflight when target_vocab_size is the Llama370M baseline
910    /// (50_257). Falsifies the silent-pass class where an operator would
911    /// accidentally pair a Qwen tokenizer with the from-scratch trainer.
912    ///
913    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5d.
914    #[test]
915    fn preflight_qwen_vocab_fails_with_llama_target() {
916        const QWEN2_VOCAB_SIZE: usize = 151_936;
917        let tmp = TempDir::new().expect("tempdir");
918        stage_vocab_json(tmp.path(), QWEN2_VOCAB_SIZE);
919        // §55: this is the from-scratch path (init absent), so init_is_some=false.
920        // Strict equality applies; tokenizer (151_936) ≠ target (50_257) MUST fail.
921        let err = preflight_tokenizer_vocab_matches_target(
922            tmp.path(),
923            Llama370MConfig::VOCAB_SIZE,
924            false,
925        )
926        .expect_err(
927            "Qwen tokenizer (151_936) MUST FAIL preflight when target is Llama370M (50_257) — \
928             silent-pass would corrupt training",
929        );
930        match err {
931            CliError::ValidationFailed(msg) => {
932                assert!(
933                    msg.contains(&QWEN2_VOCAB_SIZE.to_string()),
934                    "msg must name Qwen vocab size 151_936: {msg}"
935                );
936                assert!(
937                    msg.contains(&Llama370MConfig::VOCAB_SIZE.to_string()),
938                    "msg must name target Llama vocab size 50_257: {msg}"
939                );
940            }
941            other => panic!("unexpected error: {other:?}"),
942        }
943    }
944
945    /// FALSIFY-APR-PRETRAIN-ARCH-009 (§55) — at preflight level, an HF
946    /// tokenizer with vocab.json count = 151665 (BPE+added, the §54 LIVE
947    /// smoke shape) MUST PASS preflight when target is Qwen 151936 AND
948    /// init_is_some=true (the polymorphic path).
949    #[test]
950    fn preflight_qwen_reserved_slots_pass_under_polymorphic_init() {
951        const QWEN_TOKENIZER_EFFECTIVE: usize = 151_665;
952        const QWEN_DECLARED_VOCAB: usize = 151_936;
953        let tmp = TempDir::new().expect("tempdir");
954        stage_vocab_json(tmp.path(), QWEN_TOKENIZER_EFFECTIVE);
955
956        // init_is_some=true: relaxed bound applies; 151665 ≤ 151936 PASSES.
957        preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN_DECLARED_VOCAB, true).expect(
958            "FALSIFY-APR-PRETRAIN-ARCH-009: HF reserved-slot tokenizer (151_665 ≤ 151_936) \
959             MUST pass preflight under polymorphic init path (§55 relaxed bound)",
960        );
961
962        // init_is_some=false: strict equality applies; 151665 ≠ 151936 FAILS.
963        let err = preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN_DECLARED_VOCAB, false)
964            .expect_err(
965                "FALSIFY-APR-PRETRAIN-ARCH-009 dual: from-scratch path MUST keep strict ==",
966            );
967        match err {
968            CliError::ValidationFailed(msg) => {
969                assert!(
970                    msg.contains("GATE-ARCH-370M-011")
971                        && msg.contains(&QWEN_TOKENIZER_EFFECTIVE.to_string())
972                        && msg.contains(&QWEN_DECLARED_VOCAB.to_string()),
973                    "strict-mode error must name gate + both sizes: {msg}"
974                );
975            }
976            other => panic!("unexpected error: {other:?}"),
977        }
978    }
979
980    /// FALSIFY-APR-PRETRAIN-ARCH-010 (§55) — at preflight level, a tokenizer
981    /// with MORE entries than the model declares MUST FAIL even under the
982    /// polymorphic init path. This is the OOB-safety guard: such a tokenizer
983    /// could emit ids ≥ model_vocab → silent embedding-lookup garbage.
984    #[test]
985    fn preflight_oversized_tokenizer_rejected_even_under_polymorphic_init() {
986        const QWEN_DECLARED_VOCAB: usize = 151_936;
987        let oversized = QWEN_DECLARED_VOCAB + 100;
988        let tmp = TempDir::new().expect("tempdir");
989        stage_vocab_json(tmp.path(), oversized);
990
991        let err = preflight_tokenizer_vocab_matches_target(
992            tmp.path(),
993            QWEN_DECLARED_VOCAB,
994            true, // polymorphic path
995        )
996        .expect_err(
997            "FALSIFY-APR-PRETRAIN-ARCH-010: oversized tokenizer MUST fail-fast even under \
998             polymorphic init (OOB safety; relaxed bound is ≤ not <)",
999        );
1000        match err {
1001            CliError::ValidationFailed(msg) => {
1002                assert!(
1003                    msg.contains("RELAXED") && msg.contains("OOB"),
1004                    "polymorphic-mode error must cite RELAXED + OOB: {msg}"
1005                );
1006            }
1007            other => panic!("unexpected error: {other:?}"),
1008        }
1009    }
1010
1011    /// FALSIFY-APR-PRETRAIN-INIT-CUDA-001 (drift-prevention sentinel,
1012    /// post-5f.5): after §50.4 step 5f.5 SHIPPED, the const message
1013    /// pins the wireup-is-wired property. The string MUST contain
1014    /// (a) the falsifier id, (b) the canonical "is wired for --device
1015    /// cuda" phrase, (c) a reference to the symmetric builder
1016    /// `build_shared_cuda_trainer_with_init`, and (d) the "5f.5
1017    /// SHIPPED" status marker. If a future refactor accidentally
1018    /// reverts the wireup or renames the symmetric builder, this test
1019    /// catches the drift before the contract reference goes stale.
1020    ///
1021    /// Pinned via `pub(crate) const FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG`
1022    /// so this test fires on a CPU-only build (no `--features cuda` needed).
1023    /// The const itself is NOT emitted by any code path in `drive_real`;
1024    /// it survives only to anchor the contract obligation. The runtime
1025    /// behaviour (`drive_real_cuda` calling `build_shared_cuda_trainer_with_init`
1026    /// when `init_arch.is_some() || init_path.is_some()`) is exercised
1027    /// at the entrenar crate level where CUDA-feature builds can fire it.
1028    #[test]
1029    fn drive_real_cuda_init_path_wireup_sentinel_pinned() {
1030        let msg = FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG;
1031        assert!(
1032            msg.contains("FALSIFY-APR-PRETRAIN-INIT-CUDA-001"),
1033            "sentinel MUST cite the falsifier id (auditability): {msg}"
1034        );
1035        assert!(
1036            msg.contains("is wired for --device cuda"),
1037            "sentinel MUST contain the canonical 'is wired' phrase so \
1038             operators recognize §50.4 step 5f.5 SHIPPED: {msg}"
1039        );
1040        assert!(
1041            msg.contains("build_shared_cuda_trainer_with_init"),
1042            "sentinel MUST name the symmetric builder so future agents \
1043             know which symbol implements the wireup: {msg}"
1044        );
1045        assert!(
1046            msg.contains("5f.5 SHIPPED"),
1047            "sentinel MUST include the 5f.5 SHIPPED status marker so \
1048             grep over the codebase can find the discharge point: {msg}"
1049        );
1050    }
1051
1052    #[test]
1053    fn synthetic_pretrain_end_to_end_happy_path() {
1054        let tmp = TempDir::new().expect("tempdir");
1055        let dataset = tmp.path().join("data.jsonl");
1056        let tokenizer = tmp.path().join("tok");
1057        let run_dir = tmp.path().join("run");
1058
1059        let result = run(
1060            &dataset,
1061            &tokenizer,
1062            &run_dir,
1063            PretrainMode::Finetune,
1064            Some(5.0e-5),
1065            25,
1066            Some(5),
1067            2,
1068            4,
1069            5,
1070            42,
1071            Some(2.2),
1072            50257,
1073            true,
1074            "cpu",
1075            None,
1076            true,
1077        );
1078        assert!(
1079            result.is_ok(),
1080            "synthetic pretrain end-to-end must succeed: got {result:?}"
1081        );
1082    }
1083
1084    #[test]
1085    fn real_mode_empty_dataset_dir_errors() {
1086        // When --synthetic is off, the real-corpus branch must surface a
1087        // clear error if the dataset directory has no .bin shards. This
1088        // supersedes the old "non-synthetic is not implemented" guard.
1089        // Stage a valid vocab.json first so GATE-ARCH-370M-011 pre-flight
1090        // passes — otherwise the shard-iterator error below is never reached.
1091        let tmp = TempDir::new().expect("tempdir");
1092        let tok_dir = tmp.path().join("tok");
1093        stage_vocab_json(&tok_dir, Llama370MConfig::VOCAB_SIZE);
1094        let err = run(
1095            tmp.path(),
1096            &tok_dir,
1097            tmp.path(),
1098            PretrainMode::Finetune,
1099            Some(5.0e-5),
1100            10,
1101            Some(2),
1102            2,
1103            4,
1104            5,
1105            42,
1106            Some(2.2),
1107            50257,
1108            false,
1109            "cpu",
1110            None,
1111            true,
1112        )
1113        .expect_err("empty dataset dir must fail to initialise the shard iterator");
1114        match err {
1115            CliError::ValidationFailed(msg) => {
1116                assert!(
1117                    msg.contains("shard iterator init failed"),
1118                    "unexpected message: {msg}"
1119                );
1120            }
1121            other => panic!("unexpected error: {other:?}"),
1122        }
1123    }
1124
1125    #[test]
1126    fn invalid_target_val_loss_rejected() {
1127        let tmp = TempDir::new().expect("tempdir");
1128        let err = run(
1129            tmp.path(),
1130            tmp.path(),
1131            tmp.path(),
1132            PretrainMode::Finetune,
1133            Some(5.0e-5),
1134            10,
1135            Some(2),
1136            2,
1137            4,
1138            5,
1139            42,
1140            Some(-1.0),
1141            50257,
1142            true,
1143            "cpu",
1144            None,
1145            true,
1146        )
1147        .expect_err("negative target_val_loss must be rejected");
1148        assert!(matches!(err, CliError::ValidationFailed(_)));
1149    }
1150
1151    // ── GATE-TRAIN-009 / INV-TRAIN-009 falsifiers ──────────────────────
1152    // Contract: training-loop-pretrain-v1 v1.3.0 §hyperparameter_defaults
1153    //
1154    // These tests bind the CLI's `mode_defaults` resolver to the
1155    // hyperparameter_defaults YAML table. If the table is ever edited
1156    // without also updating this resolver (or vice versa), the tests
1157    // fail. That is exactly the drift INV-TRAIN-009 forbids.
1158
1159    #[test]
1160    fn mode_finetune_is_default_and_matches_contract() {
1161        // No overrides → resolved HP matches the `finetune` YAML row
1162        // (lr_max=5e-5, warmup_steps=100, target_val_loss=2.2) AND the
1163        // regime is Finetune so INV-TRAIN-005 epoch-zero cap = 10.0.
1164        let hp = mode_defaults(PretrainMode::Finetune, 50257, None, None, None);
1165        assert_eq!(hp.regime, TrainingRegime::Finetune);
1166        assert!(
1167            (hp.lr_max - 5.0e-5).abs() < 1.0e-12,
1168            "lr_max={} must equal finetune default 5e-5",
1169            hp.lr_max
1170        );
1171        assert_eq!(hp.warmup_steps, 100);
1172        assert!(
1173            (hp.target_val_loss - 2.2).abs() < 1.0e-6,
1174            "target_val_loss={} must equal finetune default 2.2",
1175            hp.target_val_loss
1176        );
1177    }
1178
1179    #[test]
1180    fn mode_from_scratch_applies_all_four_defaults() {
1181        // `--mode from-scratch` with no HP overrides MUST yield the full
1182        // cold-start 4-tuple atomically — regime=FromScratch, lr=3e-4,
1183        // warmup=1000, target=3.0. INV-TRAIN-009 falsifier (a).
1184        let hp = mode_defaults(PretrainMode::FromScratch, 50257, None, None, None);
1185        assert_eq!(hp.regime, TrainingRegime::FromScratch { vocab_size: 50257 });
1186        assert!(
1187            (hp.lr_max - 3.0e-4).abs() < 1.0e-12,
1188            "lr_max={} must equal from_scratch default 3e-4",
1189            hp.lr_max
1190        );
1191        assert_eq!(hp.warmup_steps, 1000);
1192        assert!(
1193            (hp.target_val_loss - 3.0).abs() < 1.0e-6,
1194            "target_val_loss={} must equal from_scratch default 3.0",
1195            hp.target_val_loss
1196        );
1197    }
1198
1199    #[test]
1200    fn mode_from_scratch_honors_explicit_lr_override() {
1201        // `--mode from-scratch --lr 1e-4` → regime still flips to
1202        // FromScratch AND warmup/target keep the from_scratch defaults,
1203        // but lr_max is the operator-supplied 1e-4. INV-TRAIN-009
1204        // falsifier (b): overrides win, regime still moves.
1205        let hp = mode_defaults(PretrainMode::FromScratch, 50257, Some(1.0e-4), None, None);
1206        assert_eq!(hp.regime, TrainingRegime::FromScratch { vocab_size: 50257 });
1207        assert!(
1208            (hp.lr_max - 1.0e-4).abs() < 1.0e-12,
1209            "lr_max={} must equal explicit override 1e-4",
1210            hp.lr_max
1211        );
1212        // Remaining two fields retained their mode defaults.
1213        assert_eq!(hp.warmup_steps, 1000);
1214        assert!((hp.target_val_loss - 3.0).abs() < 1.0e-6);
1215    }
1216
1217    // ── GATE-TRAIN-010 / INV-TRAIN-010 falsifiers ──────────────────────
1218    // Contract: training-loop-pretrain-v1 v1.4.0 §INV-TRAIN-010
1219    //
1220    // Task #105's original wiring shipped `synthetic: bool` with
1221    // `default_value = "true"`. The `--synthetic` flag had no
1222    // companion to turn it off, so every invocation of `apr pretrain`
1223    // silently routed to drive_synthetic. Tasks #119 / #124 / #125
1224    // all captured scripted-loss output and mis-labeled it real
1225    // compute. These two tests parse actual argv through clap and
1226    // assert the routing discriminator byte-for-byte.
1227
1228    fn parse_pretrain_synthetic(extra: &[&str]) -> bool {
1229        // The `Commands` enum is large enough in debug builds to overflow
1230        // the default 2 MiB test-thread stack during clap's recursive
1231        // destructuring. Run the parse on a worker thread with a 16 MiB
1232        // stack so this falsifier passes in both debug and release.
1233        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1234        std::thread::Builder::new()
1235            .stack_size(16 * 1024 * 1024)
1236            .spawn(move || {
1237                use clap::Parser;
1238                let mut argv: Vec<String> = vec![
1239                    "apr".to_string(),
1240                    "pretrain".to_string(),
1241                    "--dataset".to_string(),
1242                    "/tmp/_gate_train_010/ds".to_string(),
1243                    "--tokenizer".to_string(),
1244                    "/tmp/_gate_train_010/tok".to_string(),
1245                    "--run-dir".to_string(),
1246                    "/tmp/_gate_train_010/run".to_string(),
1247                ];
1248                argv.extend(extra);
1249                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1250                match *cli.command {
1251                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1252                        synthetic,
1253                        ..
1254                    }) => synthetic,
1255                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1256                }
1257            })
1258            .expect("spawn parse thread")
1259            .join()
1260            .expect("parse thread must not panic")
1261    }
1262
1263    #[test]
1264    fn cli_pretrain_defaults_to_real_compute() {
1265        // Absent `--synthetic` MUST parse to synthetic=false so the
1266        // dispatcher routes through drive_real.
1267        assert!(
1268            !parse_pretrain_synthetic(&[]),
1269            "INV-TRAIN-010: `apr pretrain` (no --synthetic) must parse to synthetic=false"
1270        );
1271    }
1272
1273    #[test]
1274    fn cli_pretrain_synthetic_flag_routes_to_synthetic() {
1275        // `--synthetic` present MUST parse to synthetic=true.
1276        assert!(
1277            parse_pretrain_synthetic(&["--synthetic"]),
1278            "INV-TRAIN-010: `apr pretrain --synthetic` must parse to synthetic=true"
1279        );
1280    }
1281
1282    // ── FALSIFY-GPUTRAIN-001 / 002 CLI surface (contract phase 1) ────
1283    // Contract: gpu-training-backend-v1 §device_dispatch
1284    //
1285    // These tests parse actual `apr pretrain --device …` argv through
1286    // clap and assert the string is surfaced byte-for-byte to the
1287    // dispatcher. `resolve_device()` itself is exercised by
1288    // `aprender-train::train::device::tests` — these tests verify that
1289    // the CLI flag exists and that its default is `auto` (the only
1290    // spec allowed to fall back).
1291
1292    fn parse_pretrain_device(extra: &[&str]) -> String {
1293        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1294        std::thread::Builder::new()
1295            .stack_size(16 * 1024 * 1024)
1296            .spawn(move || {
1297                use clap::Parser;
1298                let mut argv: Vec<String> = vec![
1299                    "apr".to_string(),
1300                    "pretrain".to_string(),
1301                    "--dataset".to_string(),
1302                    "/tmp/_gputrain_device/ds".to_string(),
1303                    "--tokenizer".to_string(),
1304                    "/tmp/_gputrain_device/tok".to_string(),
1305                    "--run-dir".to_string(),
1306                    "/tmp/_gputrain_device/run".to_string(),
1307                ];
1308                argv.extend(extra);
1309                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1310                match *cli.command {
1311                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1312                        device, ..
1313                    }) => device,
1314                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1315                }
1316            })
1317            .expect("spawn parse thread")
1318            .join()
1319            .expect("parse thread must not panic")
1320    }
1321
1322    #[test]
1323    fn cli_pretrain_device_defaults_to_auto() {
1324        // Absent `--device`, the flag MUST parse to `"auto"` — the only
1325        // spec allowed to silently fall back to CPU when CUDA is not
1326        // available. Any other default would violate the contract's
1327        // "explicit request → hard-fail" invariant.
1328        assert_eq!(
1329            parse_pretrain_device(&[]),
1330            "auto",
1331            "gpu-training-backend-v1 INV-GPUTRAIN-002: default --device must be `auto`",
1332        );
1333    }
1334
1335    #[test]
1336    fn cli_pretrain_device_accepts_cpu() {
1337        // `--device cpu` MUST round-trip through clap unchanged.
1338        assert_eq!(parse_pretrain_device(&["--device", "cpu"]), "cpu");
1339    }
1340
1341    #[test]
1342    fn cli_pretrain_device_accepts_cuda_index() {
1343        // `--device cuda:7` MUST round-trip unchanged; grammar
1344        // enforcement happens in `resolve_device`, not at clap.
1345        assert_eq!(parse_pretrain_device(&["--device", "cuda:7"]), "cuda:7");
1346    }
1347
1348    // ── apr-pretrain-from-init-v1 falsifiers ────────────────────────────
1349    // Contract: contracts/apr-pretrain-from-init-v1.yaml v1.0.0 PROPOSED
1350    // Spec: SPEC-SHIP-TWO-001 §49 step 4 — wire `apr pretrain --init`
1351    //
1352    // PARTIAL_ALGORITHM_LEVEL: file-existence + magic-byte checks bind
1353    // FALSIFY-APR-PRETRAIN-INIT-003 / -004; the clap surface binds
1354    // FALSIFY-001 / -007. FALSIFY-005 (arch mismatch), -006 (init_loss
1355    // signal), -009 (optimizer state), -010 (idempotent load) are gated
1356    // on the §49 step 5 weight-load impl. The "valid APR returns
1357    // not-yet-wired" test pins the no-silent-fallback contract: a
1358    // recognised APR cannot be silently ignored.
1359
1360    fn parse_pretrain_init(extra: &[&str]) -> Option<std::path::PathBuf> {
1361        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1362        std::thread::Builder::new()
1363            .stack_size(16 * 1024 * 1024)
1364            .spawn(move || {
1365                use clap::Parser;
1366                let mut argv: Vec<String> = vec![
1367                    "apr".to_string(),
1368                    "pretrain".to_string(),
1369                    "--dataset".to_string(),
1370                    "/tmp/_init_flag/ds".to_string(),
1371                    "--tokenizer".to_string(),
1372                    "/tmp/_init_flag/tok".to_string(),
1373                    "--run-dir".to_string(),
1374                    "/tmp/_init_flag/run".to_string(),
1375                ];
1376                argv.extend(extra);
1377                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1378                match *cli.command {
1379                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1380                        init, ..
1381                    }) => init,
1382                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1383                }
1384            })
1385            .expect("spawn parse thread")
1386            .join()
1387            .expect("parse thread must not panic")
1388    }
1389
1390    /// FALSIFY-APR-PRETRAIN-INIT-001: --init flag exists in clap surface.
1391    #[test]
1392    fn pretrain_init_flag_absent_parses_to_none() {
1393        // Absent --init MUST parse to None. Falsifies a regression where a
1394        // default value silently injects a path the operator never typed.
1395        assert_eq!(
1396            parse_pretrain_init(&[]),
1397            None,
1398            "FALSIFY-APR-PRETRAIN-INIT-001/002: default --init must be None (no silent default)"
1399        );
1400    }
1401
1402    /// FALSIFY-APR-PRETRAIN-INIT-001: --init <PATH> parses to Some(PathBuf).
1403    #[test]
1404    fn pretrain_init_flag_parses_path() {
1405        let parsed = parse_pretrain_init(&["--init", "/tmp/foo.apr"]);
1406        assert_eq!(
1407            parsed.as_deref().and_then(|p| p.to_str()),
1408            Some("/tmp/foo.apr"),
1409            "FALSIFY-APR-PRETRAIN-INIT-001: --init <PATH> must round-trip through clap"
1410        );
1411    }
1412
1413    /// FALSIFY-APR-PRETRAIN-INIT-003: --init <missing-file> fails fast
1414    /// before any trainer allocation; stderr names the path.
1415    #[test]
1416    fn pretrain_init_missing_file_errors() {
1417        let tmp = TempDir::new().expect("tempdir");
1418        let missing = tmp.path().join("does-not-exist.apr");
1419        let err = run(
1420            tmp.path(),
1421            tmp.path(),
1422            tmp.path(),
1423            PretrainMode::Finetune,
1424            Some(5.0e-5),
1425            10,
1426            Some(2),
1427            2,
1428            4,
1429            5,
1430            42,
1431            Some(2.2),
1432            50257,
1433            true,
1434            "cpu",
1435            Some(&missing),
1436            true,
1437        )
1438        .expect_err("missing --init file must be rejected");
1439        match err {
1440            CliError::ValidationFailed(msg) => {
1441                assert!(
1442                    msg.contains("FALSIFY-APR-PRETRAIN-INIT-003"),
1443                    "msg must cite falsifier id: {msg}"
1444                );
1445                assert!(
1446                    msg.contains("does-not-exist.apr"),
1447                    "msg must name the missing path: {msg}"
1448                );
1449            }
1450            other => panic!("unexpected error: {other:?}"),
1451        }
1452    }
1453
1454    /// FALSIFY-APR-PRETRAIN-INIT-004: --init with wrong magic bytes fails fast.
1455    #[test]
1456    fn pretrain_init_bad_magic_errors() {
1457        let tmp = TempDir::new().expect("tempdir");
1458        let bad = tmp.path().join("not-an-apr.bin");
1459        std::fs::write(&bad, b"GGUF\x00\x00\x00\x00\x00\x00\x00\x00").expect("write fixture file");
1460        let err = run(
1461            tmp.path(),
1462            tmp.path(),
1463            tmp.path(),
1464            PretrainMode::Finetune,
1465            Some(5.0e-5),
1466            10,
1467            Some(2),
1468            2,
1469            4,
1470            5,
1471            42,
1472            Some(2.2),
1473            50257,
1474            true,
1475            "cpu",
1476            Some(&bad),
1477            true,
1478        )
1479        .expect_err("invalid magic bytes must be rejected");
1480        match err {
1481            CliError::ValidationFailed(msg) => {
1482                assert!(
1483                    msg.contains("FALSIFY-APR-PRETRAIN-INIT-004"),
1484                    "msg must cite falsifier id: {msg}"
1485                );
1486                assert!(
1487                    msg.contains("not a valid APR file"),
1488                    "msg must describe magic mismatch: {msg}"
1489                );
1490            }
1491            other => panic!("unexpected error: {other:?}"),
1492        }
1493    }
1494
1495    /// FALSIFY-APR-PRETRAIN-INIT-004: empty file (read_exact fails on 4 bytes).
1496    #[test]
1497    fn pretrain_init_empty_file_errors() {
1498        let tmp = TempDir::new().expect("tempdir");
1499        let empty = tmp.path().join("empty.apr");
1500        std::fs::write(&empty, b"").expect("write empty fixture");
1501        let err = run(
1502            tmp.path(),
1503            tmp.path(),
1504            tmp.path(),
1505            PretrainMode::Finetune,
1506            Some(5.0e-5),
1507            10,
1508            Some(2),
1509            2,
1510            4,
1511            5,
1512            42,
1513            Some(2.2),
1514            50257,
1515            true,
1516            "cpu",
1517            Some(&empty),
1518            true,
1519        )
1520        .expect_err("empty file must be rejected (cannot contain magic bytes)");
1521        assert!(matches!(err, CliError::ValidationFailed(_)));
1522    }
1523
1524    /// §50.4 step 5f.4: a magic-byte-valid but metadata-bogus APR file
1525    /// MUST be rejected at the architecture-extraction step, not silently
1526    /// fall back to random init. The error must clearly cite the
1527    /// architecture-extraction failure (not the legacy "not yet wired"
1528    /// guard, which was retired when the wireup landed). This drift-prevention
1529    /// pins the new fail-closed semantic.
1530    #[test]
1531    fn pretrain_init_valid_magic_but_bogus_metadata_fails_at_arch_extraction() {
1532        let tmp = TempDir::new().expect("tempdir");
1533        let valid = tmp.path().join("v2-valid-magic-bogus-metadata.apr");
1534        // APR\0 magic + padding; passes validate_init_apr_path but
1535        // read_apr_architecture (which reads the v2 header) will return None.
1536        std::fs::write(&valid, b"APR\x00\x00\x00\x00\x00\x00\x00\x00\x00")
1537            .expect("write fixture file");
1538        let err = run(
1539            tmp.path(),
1540            tmp.path(),
1541            tmp.path(),
1542            PretrainMode::Finetune,
1543            Some(5.0e-5),
1544            10,
1545            Some(2),
1546            2,
1547            4,
1548            5,
1549            42,
1550            Some(2.2),
1551            50257,
1552            true,
1553            "cpu",
1554            Some(&valid),
1555            true,
1556        )
1557        .expect_err("bogus metadata must NOT silently random-init");
1558        match err {
1559            CliError::ValidationFailed(msg) => {
1560                assert!(
1561                    !msg.contains("not yet wired"),
1562                    "the legacy step-5-partial guard must be retired: {msg}"
1563                );
1564                // The actual error from read_apr_architecture failure or
1565                // downstream layer; both are acceptable as long as we DON'T
1566                // silently load random init.
1567            }
1568            other => panic!("unexpected error: {other:?}"),
1569        }
1570    }
1571
1572    /// Pin v1 magic (APRN) acceptance — `validate_init_apr_path` alone
1573    /// (decoupled from architecture extraction) returns Ok for both APR\0
1574    /// and APRN magic bytes. Architecture extraction is a separate step.
1575    #[test]
1576    fn pretrain_init_v1_magic_aprn_passes_validate_init_apr_path() {
1577        let tmp = TempDir::new().expect("tempdir");
1578        let v1 = tmp.path().join("v1-aprn.apr");
1579        std::fs::write(&v1, b"APRN\x00\x00\x00\x00").expect("write fixture file");
1580        let result = validate_init_apr_path(&v1);
1581        assert!(
1582            result.is_ok(),
1583            "APRN magic must pass validate_init_apr_path; got {result:?}"
1584        );
1585    }
1586}