apr_cli/commands/pretrain.rs

1//! `apr pretrain` — pretraining loop driver for SHIP-TWO-001 MODEL-2.
2//!
3//! Wires `entrenar::train::pretrain::PretrainLoop` into the CLI. The
4//! loop shape is enforced by `contracts/training-loop-pretrain-v1.yaml`
5//! — specifically GATE-TRAIN-005 (divergence), GATE-TRAIN-007 (NaN),
6//! and GATE-TRAIN-008 (throughput range).
7//!
8//! For MODEL-2 specifically, real-corpus training is the default
9//! (INV-TRAIN-010); the `--synthetic` flag instead drives the loop
10//! with a deterministic decreasing-loss step function so the full
11//! contract gate surface (GATE-TRAIN-005/007/008) can be exercised
12//! end-to-end with no corpus I/O or model compute (see
13//! `drive_synthetic` below).
14
15use crate::error::{CliError, Result};
16use crate::output;
17use clap::ValueEnum;
18use colored::Colorize;
19use entrenar::models::llama_370m::{
20    assert_tokenizer_vocab_matches_model, assert_tokenizer_vocab_within_model_bound,
21    Llama370MConfig,
22};
23use entrenar::train::device::{resolve_device, Device};
24use entrenar::train::pretrain::{
25    CheckpointFn, LinearDecaySynthetic, PretrainAbort, PretrainConfig, PretrainLoop, RunStatus,
26    ScriptedVal, StepFn, TrainingRegime, ValFn,
27};
28use entrenar::train::pretrain_real::{
29    build_shared_trainer, build_shared_trainer_with_init, AprCheckpointFn, RealStepFn, RealValFn,
30};
31use entrenar::transformer::TransformerConfig;
32use entrenar::train::shard_reader::ShardBatchIter;
33use entrenar::train::transformer_trainer::LMBatch;
34use std::path::Path;
35
36/// Number of LMBatches pulled off the head of the shard stream and
37/// reserved as the held-out validation set.
38///
39/// 2026-04-26: bumped from 2 → 16 to reduce val_loss measurement
40/// noise on from-scratch runs. With batch=16 seq=512, the prior
41/// 2-batch held-out covered just 16,384 tokens — single-batch
42/// fluctuation was ~0.04 in val_loss, which is at the same scale
43/// as epoch-over-epoch improvement signal during early training.
44/// A 50K-step run early-stopped at epoch 5/24 even though
45/// train_loss was monotonically decreasing (10.01 → 9.54). With 16
46/// held-out batches (131K tokens), the val_loss noise floor drops by
47/// roughly 1/√8 to ~0.01, restoring early-stop signal-to-noise.
48const HELD_OUT_BATCHES: usize = 16;
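// Worked sizing for the numbers cited above (batch=16, seq=512):
//   old:  2 batches × 16 × 512 =  16,384 held-out tokens
//   new: 16 batches × 16 × 512 = 131,072 held-out tokens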
49
50/// Drift-prevention constant pinned by `apr-pretrain-arch-polymorphic-v1`
51/// v1.4.0 §FALSIFY-APR-PRETRAIN-INIT-CUDA-001.
52///
53/// The fail-fast error returned when an operator passes both `--init <PATH>`
54/// AND `--device cuda` while the §50.4 step 5f.5 CUDA wireup is not yet
55/// implemented. Extracted into a `pub(crate) const` so that a unit test
56/// can verify (a) the falsifier id appears, (b) the "not yet wired" phrase
57/// appears, (c) the 5f.5 follow-up reference appears — without needing a
58/// `--features cuda` build to fire the runtime path.
59pub(crate) const FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG: &str =
60    "FALSIFY-APR-PRETRAIN-INIT-CUDA-001: --init is not yet wired for --device cuda \
61     (step 5f.5 follow-up); use --device cpu OR omit --init for from-scratch CUDA training.";
62
63/// CLI selector bound to training-loop-pretrain-v1 §hyperparameter_defaults.
64/// Atomically flips the `(regime, lr_max, warmup_steps, target_val_loss)`
65/// 4-tuple per INV-TRAIN-009. Explicit `--lr` / `--warmup-steps` /
66/// `--target-val-loss` still win over the table row.
67#[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
68pub enum PretrainMode {
69    /// Post-divergence MODEL-1 remedy defaults (lr=5e-5, warmup=100, target=2.2).
70    Finetune,
71    /// 370M cold-start defaults (lr=3e-4, warmup=1000, target=3.0).
72    FromScratch,
73}
74
75/// Resolved HP tuple from the contract's `hyperparameter_defaults` table.
76/// Inputs are CLI-provided overrides (`None` means "inherit mode default").
77/// Output binds INV-TRAIN-009: regime ALWAYS matches `mode`, and any field
78/// the operator set explicitly passes through unchanged.
79#[derive(Clone, Debug, PartialEq)]
80pub(crate) struct ResolvedHp {
81    pub regime: TrainingRegime,
82    pub lr_max: f32,
83    pub warmup_steps: usize,
84    pub target_val_loss: f32,
85}
86
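/// Illustrative resolution, mirroring the `mode_*` unit tests below
/// (marked `ignore`: items here are `pub(crate)`, so this is a sketch
/// rather than a compiled doctest):
///
/// ```ignore
/// // `--mode from-scratch --lr 1e-4`: lr is the explicit override,
/// // the other fields inherit the from-scratch row.
/// let hp = mode_defaults(PretrainMode::FromScratch, 50_257, Some(1.0e-4), None, None);
/// assert_eq!(hp.regime, TrainingRegime::FromScratch { vocab_size: 50_257 });
/// assert!((hp.lr_max - 1.0e-4).abs() < 1e-12);
/// assert_eq!(hp.warmup_steps, 1000);
/// assert!((hp.target_val_loss - 3.0).abs() < 1e-6);
/// ```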
87pub(crate) fn mode_defaults(
88    mode: PretrainMode,
89    vocab_size: u32,
90    lr_override: Option<f32>,
91    warmup_override: Option<usize>,
92    target_override: Option<f32>,
93) -> ResolvedHp {
94    let (regime, lr_def, warmup_def, target_def) = match mode {
95        PretrainMode::Finetune => (TrainingRegime::Finetune, 5.0e-5, 100, 2.2),
96        PretrainMode::FromScratch => (
97            TrainingRegime::FromScratch { vocab_size },
98            3.0e-4,
99            1000,
100            3.0,
101        ),
102    };
103    ResolvedHp {
104        regime,
105        lr_max: lr_override.unwrap_or(lr_def),
106        warmup_steps: warmup_override.unwrap_or(warmup_def),
107        target_val_loss: target_override.unwrap_or(target_def),
108    }
109}
110
111/// Execute `apr pretrain`.
112#[allow(clippy::too_many_arguments)]
113pub(crate) fn run(
114    dataset: &Path,
115    tokenizer: &Path,
116    run_dir: &Path,
117    mode: PretrainMode,
118    lr: Option<f32>,
119    num_steps: usize,
120    warmup_steps: Option<usize>,
121    batch_size: usize,
122    seq_length: usize,
123    steps_per_epoch: usize,
124    seed: u64,
125    target_val_loss: Option<f32>,
126    vocab_size: u32,
127    synthetic: bool,
128    device: &str,
129    init: Option<&Path>,
130    json_output: bool,
131) -> Result<()> {
132    // Contract gpu-training-backend-v1 INV-GPUTRAIN-001 / GATE-GPUTRAIN-002:
133    // parse --device BEFORE any trainer allocation so an invalid spec
134    // or an explicit `cuda` on a CPU-only host fails fast with a clear
135    // diagnostic. Synthetic drive still honours --device (for parity
136    // with real compute), so device errors surface identically in both modes.
137    let resolved_device =
138        resolve_device(device).map_err(|e| CliError::ValidationFailed(e.to_string()))?;
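    // Device spec grammar lives in `resolve_device`, not in clap: "auto"
    // (the default, and the only spec allowed to fall back to CPU), "cpu",
    // and "cuda"/"cuda:<index>", as exercised by the --device tests below.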
139
140    // Contract apr-pretrain-from-init-v1 §init_load_semantics + §50.4 step 5f.4:
141    // when --init is present, (1) validate magic bytes, (2) extract
142    // TransformerConfig from the APR header metadata, (3) propagate the
143    // extracted arch through preflight + trainer construction.
144    // Per `apr-pretrain-arch-polymorphic-v1` §arch_extraction_signature,
145    // missing or unreadable architecture metadata is FAIL-FAST not silent-fallback.
146    let init_arch: Option<TransformerConfig> = if let Some(init_path) = init {
147        validate_init_apr_path(init_path)?;
148        Some(crate::commands::model_config::read_apr_architecture(init_path).ok_or_else(
149            || {
150                CliError::ValidationFailed(format!(
151                    "FALSIFY-APR-PRETRAIN-INIT-005: --init APR file at {} has missing or invalid \
152                     architecture metadata (hidden_size, num_heads, num_layers, vocab_size, etc). \
153                     Cannot extract TransformerConfig per apr-pretrain-arch-polymorphic-v1 \
154                     §arch_extraction_signature.",
155                    init_path.display()
156                ))
157            },
158        )?)
159    } else {
160        None
161    };
162
163    let hp = mode_defaults(mode, vocab_size, lr, warmup_steps, target_val_loss);
164
165    // Validation: GATE-TRAIN-003 requires target_val_loss > 0.
166    if hp.target_val_loss <= 0.0 {
167        return Err(CliError::ValidationFailed(format!(
168            "target_val_loss must be positive, got {}",
169            hp.target_val_loss
170        )));
171    }
172    if num_steps == 0 {
173        return Err(CliError::ValidationFailed(
174            "num_steps must be > 0".to_string(),
175        ));
176    }
177    if steps_per_epoch == 0 {
178        return Err(CliError::ValidationFailed(
179            "steps_per_epoch must be > 0".to_string(),
180        ));
181    }
182
183    let config = PretrainConfig {
184        dataset_path: dataset.to_path_buf(),
185        tokenizer_dir: tokenizer.to_path_buf(),
186        run_dir: run_dir.to_path_buf(),
187        lr_max: hp.lr_max,
188        lr_min: (hp.lr_max * 1.0e-2).max(1.0e-7),
189        warmup_steps: hp.warmup_steps,
190        total_steps: num_steps,
191        batch_size,
192        seq_length,
193        steps_per_epoch,
194        seed,
195        grad_clip: 1.0,
196        weight_decay: 0.01,
197        target_val_loss: hp.target_val_loss,
198        // Patience widened from 2 → 5 epochs for from-scratch runs (2026-04-26).
199        // Rationale: a 50K-step run early-stopped at epoch 5/24 even though
200        // train_loss was monotonically decreasing 10.01 → 9.54 (Δ=−0.47);
201        // val_loss noise on 16k-token val set (now 131k) had stdev ~0.04,
202        // same scale as epoch-over-epoch improvement signal during early
203        // training. 5 patience epochs gives the optimizer time to push past
204        // local plateaus without ending an obviously-still-converging run.
205        patience_epochs: 5,
206        // Minimum epochs before early-stop. Bumped 1 → 3 so the warmup
207        // window (1000 steps = 1 epoch at 1000 steps_per_epoch, or 0.5
208        // epoch at 2000 steps_per_epoch) plus 1-2 initial epochs of post-
209        // warmup learning are guaranteed to complete before any early-stop
210        // signal is honoured.
211        min_epochs_before_early_stop: 3,
212        regime: hp.regime,
213    };
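    // lr_min example: with the from-scratch default lr_max = 3e-4,
    // lr_min = max(3e-4 × 1e-2, 1e-7) = 3e-6, i.e. 1% of peak with a 1e-7 floor.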
214
215    if !json_output {
216        print_header(&config);
217        // GATE-GPUTRAIN-002 visibility: print the resolved Device so the
218        // operator can confirm which backend was selected. `auto` is the
219        // only spec that may silently fall back, and this print makes
220        // the fall-back visible at startup.
221        output::kv("  Device", resolved_device.to_string());
222        println!();
223    }
224
225    let status = if synthetic {
226        drive_synthetic(
227            config.clone(),
228            num_steps,
229            steps_per_epoch,
230            hp.target_val_loss,
231            json_output,
232        )?
233    } else {
234        drive_real(
235            config.clone(),
236            dataset,
237            hp.lr_max,
238            seq_length,
239            batch_size,
240            seed,
241            resolved_device,
242            json_output,
243            init_arch.as_ref(),
244            init,
245        )?
246    };
247
248    // Contract: non-OK terminal statuses map to non-zero exit codes so
249    // operators can recognize divergence / NaN from shell `$?`.
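    // Illustrative shell usage of that mapping:
    //   apr pretrain --dataset ds/ --tokenizer tok/ --run-dir run/ || echo "train aborted: $?"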
250    match status {
251        RunStatus::Aborted(abort) => Err(abort_to_err(&abort)),
252        RunStatus::Ok { .. } | RunStatus::EarlyStop { .. } => Ok(()),
253    }
254}
255
256/// Synthetic drive: deterministic linear-decay `StepFn` and a scripted
257/// val-loss sequence so the full gate surface (GATE-TRAIN-005/007/008)
258/// is exercised end-to-end with no corpus I/O.
259fn drive_synthetic(
260    config: PretrainConfig,
261    num_steps: usize,
262    steps_per_epoch: usize,
263    target_val_loss: f32,
264    json_output: bool,
265) -> Result<RunStatus> {
266    let step_fn = LinearDecaySynthetic {
267        start_loss: (target_val_loss * 2.0).max(1.5),
268        decay_per_step: (target_val_loss * 0.01).max(1.0e-4),
269        grad_norm: 0.8,
270    };
271    let num_epochs = num_steps.div_ceil(steps_per_epoch);
272    let mut sequence = Vec::with_capacity(num_epochs + 2);
273    let start_val = (target_val_loss * 1.8).max(3.0);
274    for i in 0..(num_epochs + 2) {
275        let t = i as f32 / (num_epochs.max(1) as f32);
276        sequence.push(target_val_loss + (start_val - target_val_loss) * (1.0 - t).max(0.0));
277    }
278    let val_fn = ScriptedVal { sequence };
279    // Synthetic drive has no real weights to checkpoint.
280    run_and_report(config, step_fn, val_fn, None, json_output)
281}
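// Worked example of the synthetic schedule, assuming the finetune default
// target_val_loss = 2.2: start_loss = max(2.2 × 2.0, 1.5) = 4.4,
// decay_per_step = max(2.2 × 0.01, 1e-4) = 0.022 per step, and the scripted
// val sequence ramps linearly from max(2.2 × 1.8, 3.0) = 3.96 down to 2.2.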
282
283/// Contract apr-pretrain-from-init-v1 §init_load_semantics + §init_error_semantics:
284/// validate `--init <PATH>` BEFORE any trainer allocation. Falsifies
285/// FALSIFY-APR-PRETRAIN-INIT-003 (missing-file) + -004 (invalid-magic).
286///
287/// Returns Ok on a valid APR file (existence + magic bytes verified).
288/// Architecture extraction + weight load are §50.4 step 5f.4 — the
289/// caller (`run()`) extracts the config via `model_config::read_apr_architecture`
290/// and passes both to `build_shared_trainer_with_init` per
291/// `apr-pretrain-arch-polymorphic-v1` §init_load_semantics.
292fn validate_init_apr_path(path: &Path) -> Result<()> {
293    let mut file = std::fs::File::open(path).map_err(|e| {
294        CliError::ValidationFailed(format!(
295            "FALSIFY-APR-PRETRAIN-INIT-003: --init path does not exist or is unreadable: {} ({e})",
296            path.display()
297        ))
298    })?;
299    let mut magic = [0u8; 4];
300    use std::io::Read;
301    file.read_exact(&mut magic).map_err(|e| {
302        CliError::ValidationFailed(format!(
303            "FALSIFY-APR-PRETRAIN-INIT-004: --init file too short to contain APR magic bytes: {} ({e})",
304            path.display()
305        ))
306    })?;
307    // APR magic bytes per `crates/aprender-core/src/format/kani_proofs.rs`:
308    //   APR\0 = [0x41, 0x50, 0x52, 0x00] (v2)
309    //   APRN  = [0x41, 0x50, 0x52, 0x4E] (v1)
310    const APR_MAGIC_V2: [u8; 4] = [0x41, 0x50, 0x52, 0x00];
311    const APR_MAGIC_V1: [u8; 4] = [0x41, 0x50, 0x52, 0x4E];
312    if magic != APR_MAGIC_V2 && magic != APR_MAGIC_V1 {
313        return Err(CliError::ValidationFailed(format!(
314            "FALSIFY-APR-PRETRAIN-INIT-004: --init file is not a valid APR file (magic={:02X?}, expected {:02X?} or {:02X?}): {}",
315            magic, APR_MAGIC_V2, APR_MAGIC_V1, path.display()
316        )));
317    }
318    Ok(())
319}
320
321/// GATE-ARCH-370M-011 pre-flight: count the tokenizer's vocabulary entries
322/// from `vocab.json` and assert the count matches `target_vocab_size`
323/// before any trainer allocation.
324///
325/// Per `apr-pretrain-arch-polymorphic-v1` §qwen_tokenizer_vocab_compatibility
326/// (PR #1473), the target is now POLYMORPHIC — when `--init <PATH>` is set,
327/// the caller passes the extracted-arch's vocab_size (e.g., 151_936 for
328/// Qwen2.5-0.5B); otherwise `Llama370MConfig::VOCAB_SIZE` (50_257) for
329/// the §24/§25 from-scratch baseline.
330///
331/// Any mismatch aborts the dispatch with a clear error naming both values
332/// and the violated invariant — the N-09 OOB escape in `Embedding::forward`
333/// would otherwise silently corrupt training.
334///
335/// Discharges FALSIFY-APR-PRETRAIN-ARCH-005 (Qwen tokenizer passes with
336/// Qwen target) and FALSIFY-APR-PRETRAIN-ARCH-006 (Qwen tokenizer fails
337/// with Llama target).
338fn preflight_tokenizer_vocab_matches_target(
339    tokenizer_dir: &Path,
340    target_vocab_size: usize,
341    init_is_some: bool,
342) -> Result<()> {
343    let vocab_path = tokenizer_dir.join("vocab.json");
344    let vocab_json = std::fs::read_to_string(&vocab_path).map_err(|e| {
345        CliError::ValidationFailed(format!(
346            "GATE-ARCH-370M-011 pre-flight: cannot read {} ({e})",
347            vocab_path.display()
348        ))
349    })?;
350    let vocab: serde_json::Map<String, serde_json::Value> = serde_json::from_str(&vocab_json)
351        .map_err(|e| {
352            CliError::ValidationFailed(format!(
353                "GATE-ARCH-370M-011 pre-flight: {} is not a valid vocab.json: {e}",
354                vocab_path.display()
355            ))
356        })?;
357    // §55: when --init is set (polymorphic path with HF-distributed
358    // checkpoint), allow tokenizer_vocab ≤ model_vocab to admit Qwen-style
359    // reserved-slot vocabularies. When --init is absent (§24/§25 from-scratch
360    // baseline), enforce strict equality to preserve INV-ARCH-370M-006.
361    if init_is_some {
362        assert_tokenizer_vocab_within_model_bound(vocab.len(), target_vocab_size)
363            .map_err(CliError::ValidationFailed)
364    } else {
365        assert_tokenizer_vocab_matches_model(vocab.len(), target_vocab_size)
366            .map_err(CliError::ValidationFailed)
367    }
368}
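// §55 semantics in numbers (the Qwen2.5 shapes exercised by the tests below):
// a vocab.json with 151,665 entries against a model vocab of 151,936 passes
// with --init (151,665 ≤ 151,936) and fails from-scratch (strict equality).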
369
370/// Real-corpus drive: build a shared 370M trainer (CPU or CUDA), split
371/// the shard stream head-off into a held-out validation set, and run a
372/// full forward + backward + AdamW step per training batch.
373///
374/// When `device.is_cuda()`, the `cuda` feature must be compiled in —
375/// otherwise this surfaces a clear error rather than silently falling
376/// back to CPU (GATE-GPUTRAIN-002, contract gpu-training-backend-v1).
377#[allow(clippy::too_many_arguments)]
378fn drive_real(
379    config: PretrainConfig,
380    dataset: &Path,
381    lr: f32,
382    seq_length: usize,
383    batch_size: usize,
384    seed: u64,
385    device: Device,
386    json_output: bool,
387    init_arch: Option<&TransformerConfig>,
388    init_path: Option<&Path>,
389) -> Result<RunStatus> {
390    // GATE-ARCH-370M-011 / INV-ARCH-370M-006 — refuse to dispatch a real
391    // training step when the tokenizer vocab_size and the model vocab_size
392    // disagree. The N-09 OOB escape guard in Embedding::forward masks the
393    // mismatch at runtime → silent garbage gradients otherwise. Synthetic
394    // drive skips this check because it never touches the real model.
395    // Per `apr-pretrain-arch-polymorphic-v1` §qwen_tokenizer_vocab_compatibility
396    // (§50.4 step 5d/5f.4): when --init is set, gate by the EXTRACTED arch's
397    // vocab_size; otherwise gate by the §24/§25 baseline Llama370MConfig::VOCAB_SIZE,
398    // preserving regression-free behavior (FALSIFY-002 + FALSIFY-005 + FALSIFY-006).
399    let target_vocab = init_arch
400        .map(|cfg| cfg.vocab_size)
401        .unwrap_or(Llama370MConfig::VOCAB_SIZE);
402    preflight_tokenizer_vocab_matches_target(
403        &config.tokenizer_dir,
404        target_vocab,
405        init_arch.is_some(),
406    )?;
407
408    // MVP: pad_id/eos_id both 0. All sequences are uniform length
409    // (seq_length + 1) so LMBatch::from_sequences takes the shared
410    // layout path and pad_id is never used for padding. The real
411    // tokenizer's special-token ids will plumb through in a follow-up.
412    //
413    // wrap_around=true: when the corpus shards are exhausted before
414    // --num-steps is reached, reset cursor to shard 0 and continue.
415    // This is standard ML-training behaviour (matches PyTorch /
416    // HuggingFace). Without it, an 18M-token corpus exhausts in ~2
417    // epochs of a 5K-step run with batch=16 seq=512, and the
418    // Cuda*StepFn falls back to placeholder loss `(1.0, 1.0)` — silently
419    // producing garbage gradients. See spec §22 (PR #1073) for the
420    // root-cause investigation.
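    // Rough arithmetic behind that example (assuming steps_per_epoch ≈ 1000):
    // 16 × 512 = 8,192 tokens per step, so 18M tokens cover ~2,200 steps,
    // i.e. about 2 epochs, before the iterator wraps.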
421    let mut iter = ShardBatchIter::new(dataset, batch_size, seq_length, 0, 0)
422        .map_err(|e| {
423            CliError::ValidationFailed(format!(
424                "dataset shard iterator init failed: {e} (path={})",
425                dataset.display()
426            ))
427        })?
428        .with_wrap_around(true);
429
430    // Reserve the first `HELD_OUT_BATCHES` batches as the held-out val
431    // set; the remainder feeds RealStepFn.
432    let mut held_out: Vec<LMBatch> = Vec::with_capacity(HELD_OUT_BATCHES);
433    for _ in 0..HELD_OUT_BATCHES {
434        match iter.next() {
435            Some(b) => held_out.push(b),
436            None => break,
437        }
438    }
439    if held_out.is_empty() {
440        return Err(CliError::ValidationFailed(format!(
441            "dataset {} is too small to reserve any held-out batches",
442            dataset.display()
443        )));
444    }
445
446    if device.is_cuda() {
447        // §50.4 step 5f.4 (CPU-only this PR): CUDA path with --init is not
448        // yet wired. The 5f.5 follow-up will add `build_shared_cuda_trainer_with_init`
449        // symmetric to the CPU path. Until then, fail-fast rather than silently
450        // ignore --init on CUDA.
451        //
452        // Per `apr-pretrain-arch-polymorphic-v1` v1.4.0 §FALSIFY-APR-PRETRAIN-INIT-CUDA-001,
453        // the error message is extracted into a `pub(crate) const` so that
454        // a drift-prevention test can pin the citation, the "not yet wired
455        // for --device cuda" phrase, and the 5f.5 follow-up reference
456        // without needing a CUDA-feature build to fire the runtime path.
457        if init_arch.is_some() {
458            return Err(CliError::ValidationFailed(
459                FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG.to_string(),
460            ));
461        }
462        drive_real_cuda(config, iter, held_out, lr, seq_length, seed, json_output)
463    } else {
464        drive_real_cpu(
465            config,
466            iter,
467            held_out,
468            lr,
469            seq_length,
470            seed,
471            json_output,
472            init_arch,
473            init_path,
474        )
475    }
476}
477
478/// CPU backend for `drive_real` — builds a `TransformerTrainer`
479/// (`aprender::Tensor` + trueno SIMD) and wires `RealStepFn` /
480/// `RealValFn` / `AprCheckpointFn`.
481#[allow(clippy::too_many_arguments)]
482fn drive_real_cpu(
483    config: PretrainConfig,
484    iter: entrenar::train::shard_reader::ShardBatchIter,
485    held_out: Vec<LMBatch>,
486    lr: f32,
487    seq_length: usize,
488    seed: u64,
489    json_output: bool,
490    init_arch: Option<&TransformerConfig>,
491    init_path: Option<&Path>,
492) -> Result<RunStatus> {
493    // §50.4 step 5f.4: when --init is set, build the trainer via the
494    // polymorphic builder (extracts arch + loads + populates init tensors).
495    // When --init is absent, use the existing from-scratch baseline builder
496    // so the §24/§25 evidence remains regression-free.
497    let trainer = if init_arch.is_some() || init_path.is_some() {
498        build_shared_trainer_with_init(lr, seq_length, seed, init_arch, init_path)
499            .map_err(CliError::ValidationFailed)?
500    } else {
501        build_shared_trainer(lr, seq_length, seed)
502    };
503    let step_fn = RealStepFn::new(trainer.clone(), Box::new(iter));
504    let val_fn = RealValFn::new(trainer.clone(), held_out);
505    let ckpt: Box<dyn CheckpointFn> = Box::new(AprCheckpointFn::new(
506        trainer,
507        "llama-370m-pretrain",
508        "LlamaForCausalLM",
509    ));
510    run_and_report(config, step_fn, val_fn, Some(ckpt), json_output)
511}
512
513/// CUDA backend for `drive_real` — builds a `CudaTransformerTrainer`
514/// and wires `CudaRealStepFn` / `CudaRealValFn` / `CudaAprCheckpointFn`
515/// (task #132 Phase 2, contract gpu-training-backend-v1).
516///
517/// When the `cuda` feature is NOT compiled in, this returns a clear
518/// build-time error so operators who asked for `--device cuda` do not
519/// silently get the CPU path (GATE-GPUTRAIN-002 / FM-GPUTRAIN-SILENT-CPU).
520#[cfg(feature = "cuda")]
521#[allow(clippy::too_many_arguments)]
522fn drive_real_cuda(
523    config: PretrainConfig,
524    iter: entrenar::train::shard_reader::ShardBatchIter,
525    held_out: Vec<LMBatch>,
526    lr: f32,
527    seq_length: usize,
528    seed: u64,
529    json_output: bool,
530) -> Result<RunStatus> {
531    use entrenar::train::pretrain_real_cuda::{
532        build_shared_cuda_trainer, CudaAprCheckpointFn, CudaRealStepFn, CudaRealValFn,
533    };
534    let trainer = build_shared_cuda_trainer(lr, seq_length, seed).map_err(|e| {
535        CliError::ValidationFailed(format!(
536            "GATE-GPUTRAIN-002: CUDA trainer allocation failed: {e}. \
537             See contracts/entrenar/gpu-training-backend-v1.yaml and \
538             memory/feedback_cuda_feature_footgun.md — this path is \
539             only reachable when the binary was built with `--features cuda`.",
540        ))
541    })?;
542    let step_fn = CudaRealStepFn::new(trainer.clone(), Box::new(iter));
543    let val_fn = CudaRealValFn::new(trainer.clone(), held_out);
544    let ckpt: Box<dyn CheckpointFn> = Box::new(CudaAprCheckpointFn::new(
545        trainer,
546        "llama-370m-pretrain",
547        "LlamaForCausalLM",
548    ));
549    run_and_report(config, step_fn, val_fn, Some(ckpt), json_output)
550}
551
552/// CUDA backend stub when the `cuda` feature is NOT compiled in.
553///
554/// This is the load-bearing gate that prevents FM-GPUTRAIN-SILENT-CPU:
555/// if a user passes `--device cuda` on an apr binary built without
556/// CUDA support, they see a clear "rebuild with --features cuda" error
557/// rather than a 14-minute CPU run masquerading as GPU training
558/// (task #132 lambda-labs incident, 2026-04-21).
559#[cfg(not(feature = "cuda"))]
560#[allow(clippy::too_many_arguments)]
561fn drive_real_cuda(
562    _config: PretrainConfig,
563    _iter: entrenar::train::shard_reader::ShardBatchIter,
564    _held_out: Vec<LMBatch>,
565    _lr: f32,
566    _seq_length: usize,
567    _seed: u64,
568    _json_output: bool,
569) -> Result<RunStatus> {
570    Err(CliError::ValidationFailed(
571        "GATE-GPUTRAIN-002: --device cuda was requested but this `apr` \
572         binary was built WITHOUT the `cuda` feature. \
573         Rebuild with `cargo build --release --features cuda` or use \
574         `--device cpu`. See memory/feedback_cuda_feature_footgun.md \
575         (contract gpu-training-backend-v1 / task #132 Phase 2)."
576            .into(),
577    ))
578}
579
580/// Shared helper: construct the `PretrainLoop`, run it, print the
581/// terminal report, and bubble the `RunStatus` back for exit-code
582/// mapping. `checkpoint_fn` — when `Some` — writes an APR file per
583/// epoch that passes GATE-TRAIN-005.
584fn run_and_report<S: StepFn, V: ValFn>(
585    config: PretrainConfig,
586    step_fn: S,
587    val_fn: V,
588    checkpoint_fn: Option<Box<dyn CheckpointFn>>,
589    json_output: bool,
590) -> Result<RunStatus> {
591    let mut loop_ = PretrainLoop::new(config, step_fn, val_fn);
592    if let Some(ckpt) = checkpoint_fn {
593        loop_ = loop_.with_checkpoint_fn(ckpt);
594    }
595    let status = loop_.run();
596    report(&status, &loop_, json_output)?;
597    Ok(status)
598}
599
600fn abort_to_err(abort: &PretrainAbort) -> CliError {
601    match abort {
602        PretrainAbort::Divergence { .. } | PretrainAbort::DivergenceAtEpochZero { .. } => {
603            CliError::ValidationFailed(format!(
604                "GATE-TRAIN-005 ship-blocker fired: {abort}. See \
605                 contracts/training-loop-pretrain-v1.yaml and \
606                 memory/project_ship_two_001_model1_qlora_divergence.md"
607            ))
608        }
609        PretrainAbort::NumericalInstability { .. } => {
610            CliError::ValidationFailed(format!("GATE-TRAIN-007 NaN/Inf guard fired: {abort}"))
611        }
612        PretrainAbort::ThroughputOutOfRange { .. } => CliError::ValidationFailed(format!(
613            "GATE-TRAIN-008 throughput-range guard fired: {abort}"
614        )),
615    }
616}
617
618fn print_header(cfg: &PretrainConfig) {
619    output::header("apr pretrain — SHIP-TWO-001 MODEL-2 training loop");
620    println!();
621    output::section("Configuration");
622    output::kv("  Dataset", cfg.dataset_path.display().to_string());
623    output::kv("  Tokenizer", cfg.tokenizer_dir.display().to_string());
624    output::kv("  Run dir", cfg.run_dir.display().to_string());
625    output::kv("  LR max", format!("{:.2e}", cfg.lr_max));
626    output::kv("  Total steps", cfg.total_steps.to_string());
627    output::kv("  Warmup steps", cfg.warmup_steps.to_string());
628    output::kv(
629        "  Batch × seq",
630        format!("{} × {}", cfg.batch_size, cfg.seq_length),
631    );
632    output::kv("  Steps / epoch", cfg.steps_per_epoch.to_string());
633    output::kv("  Seed", cfg.seed.to_string());
634    output::kv("  Target val_loss", format!("{:.2}", cfg.target_val_loss));
635    println!();
636}
637
638fn report<S: entrenar::train::pretrain::StepFn, V: entrenar::train::pretrain::ValFn>(
639    status: &RunStatus,
640    loop_: &PretrainLoop<S, V>,
641    json_output: bool,
642) -> Result<()> {
643    if json_output {
644        let report = PretrainReport::from(status, loop_);
645        let json = serde_json::to_string_pretty(&report)
646            .map_err(|e| CliError::InvalidFormat(e.to_string()))?;
647        println!("{json}");
648        return Ok(());
649    }
650
651    output::section("Run Result");
652    match status {
653        RunStatus::Ok {
654            final_val_loss,
655            epochs_completed,
656        } => {
657            println!(
658                "  {} CONVERGED  final val_loss={:.4} after {} epoch(s)",
659                "OK".green().bold(),
660                final_val_loss,
661                epochs_completed
662            );
663        }
664        RunStatus::EarlyStop {
665            best_val_loss,
666            epochs_completed,
667        } => {
668            println!(
669                "  {} EARLY_STOP  best val_loss={:.4} after {} epoch(s)",
670                "OK".yellow().bold(),
671                best_val_loss,
672                epochs_completed
673            );
674        }
675        RunStatus::Aborted(abort) => {
676            println!("  {} ABORTED  {}", "FAIL".red().bold(), abort);
677        }
678    }
679    output::kv("  Steps recorded", loop_.step_metrics().len().to_string());
680    output::kv(
681        "  Epochs recorded",
682        loop_.epoch_artifacts().len().to_string(),
683    );
684    println!();
685    Ok(())
686}
687
688#[derive(serde::Serialize)]
689struct PretrainReport {
690    status: String,
691    detail: Option<String>,
692    final_val_loss: Option<f32>,
693    epochs_completed: usize,
694    steps_recorded: usize,
695    val_loss_history: Vec<f32>,
696    /// Per-step `StepMetrics` captured by `PretrainLoop` (GATE-TRAIN-001
697    /// contract `training-loop-pretrain-v1.yaml::per_step_metrics.required`).
698    ///
699    /// Emitted so downstream consumers can discharge FALSIFY-GPUTRAIN-005
700    /// (step-time < 500 ms on RTX 4090 for 370M) and FALSIFY-GPUTRAIN-006
701    /// (same-seed reproducibility — two cuda:0 runs at seed=0 must match
702    /// on every step's train_loss within `AC_GPUTRAIN_006_MAX_SEED_LOSS_DELTA`
703    /// = 1e-5) directly from the `--json` output, rather than having to
704    /// parse run-dir checkpoint metadata.
705    per_step_metrics: Vec<entrenar::train::pretrain::StepMetrics>,
706}
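// Illustrative `--json` shape (field values are examples only;
// `per_step_metrics` entries follow the serde layout of
// `entrenar::train::pretrain::StepMetrics`, not reproduced here):
//
//   {
//     "status": "EARLY_STOP",
//     "detail": null,
//     "final_val_loss": 3.41,
//     "epochs_completed": 5,
//     "steps_recorded": 5000,
//     "val_loss_history": [3.96, 3.71],
//     "per_step_metrics": []
//   }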
707
708impl PretrainReport {
709    fn from<S: entrenar::train::pretrain::StepFn, V: entrenar::train::pretrain::ValFn>(
710        status: &RunStatus,
711        loop_: &PretrainLoop<S, V>,
712    ) -> Self {
713        let (status_name, detail, final_val_loss, epochs_completed) = match status {
714            RunStatus::Ok {
715                final_val_loss,
716                epochs_completed,
717            } => (
718                "OK".to_string(),
719                None,
720                Some(*final_val_loss),
721                *epochs_completed,
722            ),
723            RunStatus::EarlyStop {
724                best_val_loss,
725                epochs_completed,
726            } => (
727                "EARLY_STOP".to_string(),
728                None,
729                Some(*best_val_loss),
730                *epochs_completed,
731            ),
732            RunStatus::Aborted(abort) => (
733                "ABORTED".to_string(),
734                Some(abort.to_string()),
735                None,
736                loop_.epoch_artifacts().len(),
737            ),
738        };
739        PretrainReport {
740            status: status_name,
741            detail,
742            final_val_loss,
743            epochs_completed,
744            steps_recorded: loop_.step_metrics().len(),
745            val_loss_history: loop_.val_loss_history().to_vec(),
746            per_step_metrics: loop_.step_metrics().to_vec(),
747        }
748    }
749}
750
751#[cfg(test)]
752mod tests {
753    use super::*;
754    use tempfile::TempDir;
755
756    /// Stage a `vocab.json` with exactly `n` distinct integer-string tokens at
757    /// `<dir>/vocab.json`. Used by pre-flight gate tests + by other tests that
758    /// need to get PAST the GATE-ARCH-370M-011 pre-flight to exercise a later
759    /// failure mode (e.g. empty dataset shards).
760    fn stage_vocab_json(dir: &std::path::Path, n: usize) {
761        std::fs::create_dir_all(dir).expect("mkdir tokenizer dir");
762        let mut obj = serde_json::Map::with_capacity(n);
763        for i in 0..n {
764            obj.insert(format!("t{i}"), serde_json::Value::from(i as u64));
765        }
766        let json = serde_json::to_string(&obj).expect("serialize");
767        std::fs::write(dir.join("vocab.json"), json).expect("write vocab.json");
768    }
769
770    #[test]
771    fn preflight_accepts_matching_vocab() {
772        // GATE-ARCH-370M-011 acceptance case: tokenizer vocab.json with
773        // exactly Llama370MConfig::VOCAB_SIZE entries must pass pre-flight.
774        let tmp = TempDir::new().expect("tempdir");
775        stage_vocab_json(tmp.path(), Llama370MConfig::VOCAB_SIZE);
776        preflight_tokenizer_vocab_matches_target(tmp.path(), Llama370MConfig::VOCAB_SIZE, false)
777            .expect("matching vocab must pass GATE-ARCH-370M-011");
778    }
779
780    #[test]
781    fn preflight_rejects_tokenizer_vocab_mismatch() {
782        // FALSIFY-ARCH-370M-011: a tokenizer whose vocab size drifts from
783        // the model's pinned VOCAB_SIZE MUST abort dispatch with an error
784        // message that names both values and the gate id, so the operator
785        // can see the mismatch without stepping through code. Task #131
786        // bumped VOCAB_SIZE to 50_257 (Option A) — the counter-example
787        // below now exercises a tokenizer one token short of contract.
788        let tmp = TempDir::new().expect("tempdir");
789        let mismatch = Llama370MConfig::VOCAB_SIZE - 1;
790        stage_vocab_json(tmp.path(), mismatch);
791        let err =
792            preflight_tokenizer_vocab_matches_target(tmp.path(), Llama370MConfig::VOCAB_SIZE, false)
793                .expect_err("tokenizer/model vocab mismatch must be rejected");
794        match err {
795            CliError::ValidationFailed(msg) => {
796                assert!(
797                    msg.contains("GATE-ARCH-370M-011"),
798                    "msg must cite gate: {msg}"
799                );
800                assert!(
801                    msg.contains(&mismatch.to_string()),
802                    "msg must name tokenizer vocab: {msg}"
803                );
804                assert!(
805                    msg.contains(&Llama370MConfig::VOCAB_SIZE.to_string()),
806                    "msg must name model vocab: {msg}"
807                );
808            }
809            other => panic!("unexpected error: {other:?}"),
810        }
811    }
812
813    #[test]
814    fn preflight_rejects_missing_vocab_json() {
815        // Missing vocab.json is a pre-flight failure (not a later shard
816        // error) — the operator should know the tokenizer layout is
817        // wrong, not that the dataset is empty.
818        let tmp = TempDir::new().expect("tempdir");
819        let err =
820            preflight_tokenizer_vocab_matches_target(tmp.path(), Llama370MConfig::VOCAB_SIZE, false)
821                .expect_err("missing vocab.json must be rejected");
822        match err {
823            CliError::ValidationFailed(msg) => {
824                assert!(
825                    msg.contains("GATE-ARCH-370M-011"),
826                    "msg must cite gate: {msg}"
827                );
828                assert!(
829                    msg.contains("cannot read"),
830                    "msg must name I/O failure: {msg}"
831                );
832            }
833            other => panic!("unexpected error: {other:?}"),
834        }
835    }
836
837    /// FALSIFY-APR-PRETRAIN-ARCH-005 — a Qwen tokenizer (vocab=151_936) MUST
838    /// pass preflight when the target_vocab_size is the Qwen extracted-arch
839    /// (151_936). Falsifies a regression where preflight would still gate
840    /// against the hardcoded Llama370M vocab.
841    ///
842    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5d.
843    #[test]
844    fn preflight_qwen_vocab_passes_with_qwen_target() {
845        const QWEN2_VOCAB_SIZE: usize = 151_936;
846        let tmp = TempDir::new().expect("tempdir");
847        stage_vocab_json(tmp.path(), QWEN2_VOCAB_SIZE);
848        // §50.4 step 5d exercises this on the --init-present (polymorphic) path. Use
849        // init_is_some=true here per §55 relaxed-bound semantics; vocab.len() == target
850        // is still acceptable under <=.
851        preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN2_VOCAB_SIZE, true).expect(
852            "Qwen tokenizer (151_936) MUST pass preflight when target is Qwen-shaped — \
853             this is the load-bearing claim of §49 fine-tune from a Qwen2.5 init checkpoint",
854        );
855    }
856
857    /// FALSIFY-APR-PRETRAIN-ARCH-006 — a Qwen tokenizer (vocab=151_936) MUST
858    /// FAIL preflight when target_vocab_size is the Llama370M baseline
859    /// (50_257). Falsifies the silent-pass class where an operator would
860    /// accidentally pair a Qwen tokenizer with the from-scratch trainer.
861    ///
862    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5d.
863    #[test]
864    fn preflight_qwen_vocab_fails_with_llama_target() {
865        const QWEN2_VOCAB_SIZE: usize = 151_936;
866        let tmp = TempDir::new().expect("tempdir");
867        stage_vocab_json(tmp.path(), QWEN2_VOCAB_SIZE);
868        // §55: this is the from-scratch path (init absent), so init_is_some=false.
869        // Strict equality applies; tokenizer (151_936) ≠ target (50_257) MUST fail.
870        let err = preflight_tokenizer_vocab_matches_target(
871            tmp.path(),
872            Llama370MConfig::VOCAB_SIZE,
873            false,
874        )
875        .expect_err(
876            "Qwen tokenizer (151_936) MUST FAIL preflight when target is Llama370M (50_257) — \
877             silent-pass would corrupt training",
878        );
879        match err {
880            CliError::ValidationFailed(msg) => {
881                assert!(
882                    msg.contains(&QWEN2_VOCAB_SIZE.to_string()),
883                    "msg must name Qwen vocab size 151_936: {msg}"
884                );
885                assert!(
886                    msg.contains(&Llama370MConfig::VOCAB_SIZE.to_string()),
887                    "msg must name target Llama vocab size 50_257: {msg}"
888                );
889            }
890            other => panic!("unexpected error: {other:?}"),
891        }
892    }
893
894    /// FALSIFY-APR-PRETRAIN-ARCH-009 (§55) — at preflight level, an HF
895    /// tokenizer with vocab.json count = 151665 (BPE+added, the §54 LIVE
896    /// smoke shape) MUST PASS preflight when target is Qwen 151936 AND
897    /// init_is_some=true (the polymorphic path).
898    #[test]
899    fn preflight_qwen_reserved_slots_pass_under_polymorphic_init() {
900        const QWEN_TOKENIZER_EFFECTIVE: usize = 151_665;
901        const QWEN_DECLARED_VOCAB: usize = 151_936;
902        let tmp = TempDir::new().expect("tempdir");
903        stage_vocab_json(tmp.path(), QWEN_TOKENIZER_EFFECTIVE);
904
905        // init_is_some=true: relaxed bound applies; 151665 ≤ 151936 PASSES.
906        preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN_DECLARED_VOCAB, true).expect(
907            "FALSIFY-APR-PRETRAIN-ARCH-009: HF reserved-slot tokenizer (151_665 ≤ 151_936) \
908             MUST pass preflight under polymorphic init path (§55 relaxed bound)",
909        );
910
911        // init_is_some=false: strict equality applies; 151665 ≠ 151936 FAILS.
912        let err =
913            preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN_DECLARED_VOCAB, false)
914                .expect_err(
915                    "FALSIFY-APR-PRETRAIN-ARCH-009 dual: from-scratch path MUST keep strict ==",
916                );
917        match err {
918            CliError::ValidationFailed(msg) => {
919                assert!(
920                    msg.contains("GATE-ARCH-370M-011")
921                        && msg.contains(&QWEN_TOKENIZER_EFFECTIVE.to_string())
922                        && msg.contains(&QWEN_DECLARED_VOCAB.to_string()),
923                    "strict-mode error must name gate + both sizes: {msg}"
924                );
925            }
926            other => panic!("unexpected error: {other:?}"),
927        }
928    }
929
930    /// FALSIFY-APR-PRETRAIN-ARCH-010 (§55) — at preflight level, a tokenizer
931    /// with MORE entries than the model declares MUST FAIL even under the
932    /// polymorphic init path. This is the OOB-safety guard: such a tokenizer
933    /// could emit ids ≥ model_vocab → silent embedding-lookup garbage.
934    #[test]
935    fn preflight_oversized_tokenizer_rejected_even_under_polymorphic_init() {
936        const QWEN_DECLARED_VOCAB: usize = 151_936;
937        let oversized = QWEN_DECLARED_VOCAB + 100;
938        let tmp = TempDir::new().expect("tempdir");
939        stage_vocab_json(tmp.path(), oversized);
940
941        let err = preflight_tokenizer_vocab_matches_target(
942            tmp.path(),
943            QWEN_DECLARED_VOCAB,
944            true, // polymorphic path
945        )
946        .expect_err(
947            "FALSIFY-APR-PRETRAIN-ARCH-010: oversized tokenizer MUST fail-fast even under \
948             polymorphic init (OOB safety; relaxed bound is ≤ not <)",
949        );
950        match err {
951            CliError::ValidationFailed(msg) => {
952                assert!(
953                    msg.contains("RELAXED") && msg.contains("OOB"),
954                    "polymorphic-mode error must cite RELAXED + OOB: {msg}"
955                );
956            }
957            other => panic!("unexpected error: {other:?}"),
958        }
959    }
960
961    /// FALSIFY-APR-PRETRAIN-INIT-CUDA-001 (drift-prevention): the
962    /// fail-fast error message returned when `--init` is paired with
963    /// `--device cuda` (before the §50.4 step 5f.5 wireup lands) MUST
964    /// contain (a) the falsifier id, (b) the "not yet wired for --device
965    /// cuda" phrase, and (c) the 5f.5 follow-up reference.
966    ///
967    /// Pinned via `pub(crate) const FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG`
968    /// so this test fires on a CPU-only build (no `--features cuda` needed).
969    /// If a future refactor renames or rephrases the error, this test
970    /// catches the drift before the contract reference goes stale.
971    ///
972    /// Promotion to LIVE-INTEGRATION requires §50.4 step 5f.5 LIVE
973    /// (CUDA wireup landed + GPU smoke confirms `apr pretrain --init
974    /// <PATH> --device cuda` actually trains). Until then, this test
975    /// pins the safety guard.
976    #[test]
977    fn drive_real_cuda_init_path_fail_fasts_with_falsifier_citation() {
978        let msg = FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG;
979        assert!(
980            msg.contains("FALSIFY-APR-PRETRAIN-INIT-CUDA-001"),
981            "error message MUST cite the falsifier id (auditability): {msg}"
982        );
983        assert!(
984            msg.contains("not yet wired for --device cuda"),
985            "error message MUST contain the canonical 'not yet wired' \
986             phrase so operators recognize the §50.4 step 5f.5 gap: {msg}"
987        );
988        assert!(
989            msg.contains("step 5f.5 follow-up"),
990            "error message MUST reference the 5f.5 follow-up so future \
991             agents know which step retires this guard: {msg}"
992        );
993        assert!(
994            msg.contains("--device cpu") && msg.contains("OR omit --init"),
995            "error message MUST suggest both workarounds (CPU device OR \
996             omit --init for from-scratch CUDA): {msg}"
997        );
998    }
999
1000    #[test]
1001    fn synthetic_pretrain_end_to_end_happy_path() {
1002        let tmp = TempDir::new().expect("tempdir");
1003        let dataset = tmp.path().join("data.jsonl");
1004        let tokenizer = tmp.path().join("tok");
1005        let run_dir = tmp.path().join("run");
1006
1007        let result = run(
1008            &dataset,
1009            &tokenizer,
1010            &run_dir,
1011            PretrainMode::Finetune,
1012            Some(5.0e-5),
1013            25,
1014            Some(5),
1015            2,
1016            4,
1017            5,
1018            42,
1019            Some(2.2),
1020            50257,
1021            true,
1022            "cpu",
1023            None,
1024            true,
1025        );
1026        assert!(
1027            result.is_ok(),
1028            "synthetic pretrain end-to-end must succeed: got {result:?}"
1029        );
1030    }
1031
1032    #[test]
1033    fn real_mode_empty_dataset_dir_errors() {
1034        // When --synthetic is off, the real-corpus branch must surface a
1035        // clear error if the dataset directory has no .bin shards. This
1036        // supersedes the old "non-synthetic is not implemented" guard.
1037        // Stage a valid vocab.json first so GATE-ARCH-370M-011 pre-flight
1038        // passes — otherwise the shard-iterator error below is never reached.
1039        let tmp = TempDir::new().expect("tempdir");
1040        let tok_dir = tmp.path().join("tok");
1041        stage_vocab_json(&tok_dir, Llama370MConfig::VOCAB_SIZE);
1042        let err = run(
1043            tmp.path(),
1044            &tok_dir,
1045            tmp.path(),
1046            PretrainMode::Finetune,
1047            Some(5.0e-5),
1048            10,
1049            Some(2),
1050            2,
1051            4,
1052            5,
1053            42,
1054            Some(2.2),
1055            50257,
1056            false,
1057            "cpu",
1058            None,
1059            true,
1060        )
1061        .expect_err("empty dataset dir must fail to initialise the shard iterator");
1062        match err {
1063            CliError::ValidationFailed(msg) => {
1064                assert!(
1065                    msg.contains("shard iterator init failed"),
1066                    "unexpected message: {msg}"
1067                );
1068            }
1069            other => panic!("unexpected error: {other:?}"),
1070        }
1071    }
1072
1073    #[test]
1074    fn invalid_target_val_loss_rejected() {
1075        let tmp = TempDir::new().expect("tempdir");
1076        let err = run(
1077            tmp.path(),
1078            tmp.path(),
1079            tmp.path(),
1080            PretrainMode::Finetune,
1081            Some(5.0e-5),
1082            10,
1083            Some(2),
1084            2,
1085            4,
1086            5,
1087            42,
1088            Some(-1.0),
1089            50257,
1090            true,
1091            "cpu",
1092            None,
1093            true,
1094        )
1095        .expect_err("negative target_val_loss must be rejected");
1096        assert!(matches!(err, CliError::ValidationFailed(_)));
1097    }
1098
1099    // ── GATE-TRAIN-009 / INV-TRAIN-009 falsifiers ──────────────────────
1100    // Contract: training-loop-pretrain-v1 v1.3.0 §hyperparameter_defaults
1101    //
1102    // These tests bind the CLI's `mode_defaults` resolver to the
1103    // hyperparameter_defaults YAML table. If the table is ever edited
1104    // without also updating this resolver (or vice versa), the tests
1105    // fail. That is exactly the drift INV-TRAIN-009 forbids.
1106
1107    #[test]
1108    fn mode_finetune_is_default_and_matches_contract() {
1109        // No overrides → resolved HP matches the `finetune` YAML row
1110        // (lr_max=5e-5, warmup_steps=100, target_val_loss=2.2) AND the
1111        // regime is Finetune so INV-TRAIN-005 epoch-zero cap = 10.0.
1112        let hp = mode_defaults(PretrainMode::Finetune, 50257, None, None, None);
1113        assert_eq!(hp.regime, TrainingRegime::Finetune);
1114        assert!(
1115            (hp.lr_max - 5.0e-5).abs() < 1.0e-12,
1116            "lr_max={} must equal finetune default 5e-5",
1117            hp.lr_max
1118        );
1119        assert_eq!(hp.warmup_steps, 100);
1120        assert!(
1121            (hp.target_val_loss - 2.2).abs() < 1.0e-6,
1122            "target_val_loss={} must equal finetune default 2.2",
1123            hp.target_val_loss
1124        );
1125    }
1126
1127    #[test]
1128    fn mode_from_scratch_applies_all_four_defaults() {
1129        // `--mode from-scratch` with no HP overrides MUST yield the full
1130        // cold-start 4-tuple atomically — regime=FromScratch, lr=3e-4,
1131        // warmup=1000, target=3.0. INV-TRAIN-009 falsifier (a).
1132        let hp = mode_defaults(PretrainMode::FromScratch, 50257, None, None, None);
1133        assert_eq!(hp.regime, TrainingRegime::FromScratch { vocab_size: 50257 });
1134        assert!(
1135            (hp.lr_max - 3.0e-4).abs() < 1.0e-12,
1136            "lr_max={} must equal from_scratch default 3e-4",
1137            hp.lr_max
1138        );
1139        assert_eq!(hp.warmup_steps, 1000);
1140        assert!(
1141            (hp.target_val_loss - 3.0).abs() < 1.0e-6,
1142            "target_val_loss={} must equal from_scratch default 3.0",
1143            hp.target_val_loss
1144        );
1145    }
1146
1147    #[test]
1148    fn mode_from_scratch_honors_explicit_lr_override() {
1149        // `--mode from-scratch --lr 1e-4` → regime still flips to
1150        // FromScratch AND warmup/target keep the from_scratch defaults,
1151        // but lr_max is the operator-supplied 1e-4. INV-TRAIN-009
1152        // falsifier (b): overrides win, regime still moves.
1153        let hp = mode_defaults(PretrainMode::FromScratch, 50257, Some(1.0e-4), None, None);
1154        assert_eq!(hp.regime, TrainingRegime::FromScratch { vocab_size: 50257 });
1155        assert!(
1156            (hp.lr_max - 1.0e-4).abs() < 1.0e-12,
1157            "lr_max={} must equal explicit override 1e-4",
1158            hp.lr_max
1159        );
1160        // Remaining two fields retained their mode defaults.
1161        assert_eq!(hp.warmup_steps, 1000);
1162        assert!((hp.target_val_loss - 3.0).abs() < 1.0e-6);
1163    }
1164
1165    // ── GATE-TRAIN-010 / INV-TRAIN-010 falsifiers ──────────────────────
1166    // Contract: training-loop-pretrain-v1 v1.4.0 §INV-TRAIN-010
1167    //
1168    // Task #105's original wiring shipped `synthetic: bool` with
1169    // `default_value = "true"`. The `--synthetic` flag had no
1170    // companion to turn it off, so every invocation of `apr pretrain`
1171    // silently routed to drive_synthetic. Tasks #119 / #124 / #125
1172    // all captured scripted-loss output and mis-labeled it real
1173    // compute. These two tests parse actual argv through clap and
1174    // assert the routing discriminator byte-for-byte.
1175
1176    fn parse_pretrain_synthetic(extra: &[&str]) -> bool {
1177        // The `Commands` enum is large enough in debug builds to overflow
1178        // the default 2 MiB test-thread stack during clap's recursive
1179        // destructuring. Run the parse on a worker thread with a 16 MiB
1180        // stack so this falsifier passes in both debug and release.
1181        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1182        std::thread::Builder::new()
1183            .stack_size(16 * 1024 * 1024)
1184            .spawn(move || {
1185                use clap::Parser;
1186                let mut argv: Vec<String> = vec![
1187                    "apr".to_string(),
1188                    "pretrain".to_string(),
1189                    "--dataset".to_string(),
1190                    "/tmp/_gate_train_010/ds".to_string(),
1191                    "--tokenizer".to_string(),
1192                    "/tmp/_gate_train_010/tok".to_string(),
1193                    "--run-dir".to_string(),
1194                    "/tmp/_gate_train_010/run".to_string(),
1195                ];
1196                argv.extend(extra);
1197                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1198                match *cli.command {
1199                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1200                        synthetic,
1201                        ..
1202                    }) => synthetic,
1203                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1204                }
1205            })
1206            .expect("spawn parse thread")
1207            .join()
1208            .expect("parse thread must not panic")
1209    }
1210
1211    #[test]
1212    fn cli_pretrain_defaults_to_real_compute() {
1213        // Absent `--synthetic` MUST parse to synthetic=false so the
1214        // dispatcher routes through drive_real.
1215        assert!(
1216            !parse_pretrain_synthetic(&[]),
1217            "INV-TRAIN-010: `apr pretrain` (no --synthetic) must parse to synthetic=false"
1218        );
1219    }
1220
1221    #[test]
1222    fn cli_pretrain_synthetic_flag_routes_to_synthetic() {
1223        // `--synthetic` present MUST parse to synthetic=true.
1224        assert!(
1225            parse_pretrain_synthetic(&["--synthetic"]),
1226            "INV-TRAIN-010: `apr pretrain --synthetic` must parse to synthetic=true"
1227        );
1228    }
1229
1230    // ── FALSIFY-GPUTRAIN-001 / 002 CLI surface (contract phase 1) ────
1231    // Contract: gpu-training-backend-v1 §device_dispatch
1232    //
1233    // These tests parse actual `apr pretrain --device …` argv through
1234    // clap and assert the string is surfaced byte-for-byte to the
1235    // dispatcher. `resolve_device()` itself is exercised by
1236    // `aprender-train::train::device::tests` — these tests verify that
1237    // the CLI flag exists and that its default is `auto` (the only
1238    // spec allowed to fall back).
1239
1240    fn parse_pretrain_device(extra: &[&str]) -> String {
1241        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1242        std::thread::Builder::new()
1243            .stack_size(16 * 1024 * 1024)
1244            .spawn(move || {
1245                use clap::Parser;
1246                let mut argv: Vec<String> = vec![
1247                    "apr".to_string(),
1248                    "pretrain".to_string(),
1249                    "--dataset".to_string(),
1250                    "/tmp/_gputrain_device/ds".to_string(),
1251                    "--tokenizer".to_string(),
1252                    "/tmp/_gputrain_device/tok".to_string(),
1253                    "--run-dir".to_string(),
1254                    "/tmp/_gputrain_device/run".to_string(),
1255                ];
1256                argv.extend(extra);
1257                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1258                match *cli.command {
1259                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1260                        device, ..
1261                    }) => device,
1262                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1263                }
1264            })
1265            .expect("spawn parse thread")
1266            .join()
1267            .expect("parse thread must not panic")
1268    }
1269
1270    #[test]
1271    fn cli_pretrain_device_defaults_to_auto() {
1272        // Absent `--device`, the flag MUST parse to `"auto"` — the only
1273        // spec allowed to silently fall back to CPU when CUDA is not
1274        // available. Any other default would violate the contract's
1275        // "explicit request → hard-fail" invariant.
1276        assert_eq!(
1277            parse_pretrain_device(&[]),
1278            "auto",
1279            "gpu-training-backend-v1 INV-GPUTRAIN-002: default --device must be `auto`",
1280        );
1281    }
1282
1283    #[test]
1284    fn cli_pretrain_device_accepts_cpu() {
1285        // `--device cpu` MUST round-trip through clap unchanged.
1286        assert_eq!(parse_pretrain_device(&["--device", "cpu"]), "cpu");
1287    }
1288
1289    #[test]
1290    fn cli_pretrain_device_accepts_cuda_index() {
1291        // `--device cuda:7` MUST round-trip unchanged; grammar
1292        // enforcement happens in `resolve_device`, not at clap.
1293        assert_eq!(parse_pretrain_device(&["--device", "cuda:7"]), "cuda:7");
1294    }
1295
1296    // ── apr-pretrain-from-init-v1 falsifiers ────────────────────────────
1297    // Contract: contracts/apr-pretrain-from-init-v1.yaml v1.0.0 PROPOSED
1298    // Spec: SPEC-SHIP-TWO-001 §49 step 4 — wire `apr pretrain --init`
1299    //
1300    // PARTIAL_ALGORITHM_LEVEL: file-existence + magic-byte checks bind
1301    // FALSIFY-APR-PRETRAIN-INIT-003 / -004; the clap surface binds
1302    // FALSIFY-001 / -007. FALSIFY-005 (arch mismatch), -006 (init_loss
1303    // signal), -009 (optimizer state), -010 (idempotent load) are gated
1304    // on the §49 step 5 weight-load impl. The "valid magic but bogus
1305    // metadata" test pins the no-silent-fallback contract: a recognised
1306    // APR can be neither silently ignored nor silently random-inited.
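
    // Illustrative sketch (added for exposition, not the real
    // `validate_init_apr_path`): the minimal check that binds FALSIFY-003
    // (missing file) and FALSIFY-004 (short/wrong magic) is "file opens and
    // its first four bytes are APR\0 or APRN". Error strings here are
    // placeholders; only the shape is meant.
    #[allow(dead_code)]
    fn sketch_check_apr_magic(path: &std::path::Path) -> std::result::Result<(), String> {
        use std::io::Read;
        let mut f = std::fs::File::open(path)
            .map_err(|e| format!("sketch FALSIFY-003: {}: {e}", path.display()))?;
        let mut magic = [0u8; 4];
        // An empty or truncated file cannot contain the 4-byte magic.
        f.read_exact(&mut magic)
            .map_err(|_| format!("sketch FALSIFY-004: {}: too short for APR magic", path.display()))?;
        if &magic == b"APR\0" || &magic == b"APRN" {
            Ok(())
        } else {
            Err(format!("sketch FALSIFY-004: {}: not a valid APR file", path.display()))
        }
    }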
1307
1308    fn parse_pretrain_init(extra: &[&str]) -> Option<std::path::PathBuf> {
1309        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1310        std::thread::Builder::new()
1311            .stack_size(16 * 1024 * 1024)
1312            .spawn(move || {
1313                use clap::Parser;
1314                let mut argv: Vec<String> = vec![
1315                    "apr".to_string(),
1316                    "pretrain".to_string(),
1317                    "--dataset".to_string(),
1318                    "/tmp/_init_flag/ds".to_string(),
1319                    "--tokenizer".to_string(),
1320                    "/tmp/_init_flag/tok".to_string(),
1321                    "--run-dir".to_string(),
1322                    "/tmp/_init_flag/run".to_string(),
1323                ];
1324                argv.extend(extra);
1325                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1326                match *cli.command {
1327                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1328                        init, ..
1329                    }) => init,
1330                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1331                }
1332            })
1333            .expect("spawn parse thread")
1334            .join()
1335            .expect("parse thread must not panic")
1336    }
1337
1338    /// FALSIFY-APR-PRETRAIN-INIT-001: --init flag exists in clap surface.
1339    #[test]
1340    fn pretrain_init_flag_absent_parses_to_none() {
1341        // Absent --init MUST parse to None. Falsifies a regression where a
1342        // default value silently injects a path the operator never typed.
1343        assert_eq!(
1344            parse_pretrain_init(&[]),
1345            None,
1346            "FALSIFY-APR-PRETRAIN-INIT-001/002: default --init must be None (no silent default)"
1347        );
1348    }
1349
1350    /// FALSIFY-APR-PRETRAIN-INIT-001: --init <PATH> parses to Some(PathBuf).
1351    #[test]
1352    fn pretrain_init_flag_parses_path() {
1353        let parsed = parse_pretrain_init(&["--init", "/tmp/foo.apr"]);
1354        assert_eq!(
1355            parsed.as_deref().and_then(|p| p.to_str()),
1356            Some("/tmp/foo.apr"),
1357            "FALSIFY-APR-PRETRAIN-INIT-001: --init <PATH> must round-trip through clap"
1358        );
1359    }
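
    // Illustrative stand-in (added for exposition) for the `--init` slice of
    // the real `ExtendedCommands::Pretrain` clap surface. The field name and
    // type are inferred from the two tests above, and the clap-4-style
    // `#[arg]` derive syntax is an assumption; the real variant carries many
    // more flags and is not reproduced here.
    #[derive(clap::Parser, Debug)]
    #[allow(dead_code)]
    struct SketchInitArgs {
        /// Optional APR checkpoint to warm-start from; absent parses to `None`.
        #[arg(long, value_name = "PATH")]
        init: Option<std::path::PathBuf>,
    }
    // e.g. `SketchInitArgs::try_parse_from(["sketch", "--init", "/tmp/foo.apr"])`
    // yields `init == Some("/tmp/foo.apr".into())`, mirroring the round-trip
    // asserted above.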
1360
1361    /// FALSIFY-APR-PRETRAIN-INIT-003: --init <missing-file> fails fast
1362    /// before any trainer allocation; stderr names the path.
1363    #[test]
1364    fn pretrain_init_missing_file_errors() {
1365        let tmp = TempDir::new().expect("tempdir");
1366        let missing = tmp.path().join("does-not-exist.apr");
1367        let err = run(
1368            tmp.path(),
1369            tmp.path(),
1370            tmp.path(),
1371            PretrainMode::Finetune,
1372            Some(5.0e-5),
1373            10,
1374            Some(2),
1375            2,
1376            4,
1377            5,
1378            42,
1379            Some(2.2),
1380            50257,
1381            true,
1382            "cpu",
1383            Some(&missing),
1384            true,
1385        )
1386        .expect_err("missing --init file must be rejected");
1387        match err {
1388            CliError::ValidationFailed(msg) => {
1389                assert!(
1390                    msg.contains("FALSIFY-APR-PRETRAIN-INIT-003"),
1391                    "msg must cite falsifier id: {msg}"
1392                );
1393                assert!(
1394                    msg.contains("does-not-exist.apr"),
1395                    "msg must name the missing path: {msg}"
1396                );
1397            }
1398            other => panic!("unexpected error: {other:?}"),
1399        }
1400    }
1401
1402    /// FALSIFY-APR-PRETRAIN-INIT-004: --init with wrong magic bytes fails fast.
1403    #[test]
1404    fn pretrain_init_bad_magic_errors() {
1405        let tmp = TempDir::new().expect("tempdir");
1406        let bad = tmp.path().join("not-an-apr.bin");
1407        std::fs::write(&bad, b"GGUF\x00\x00\x00\x00\x00\x00\x00\x00")
1408            .expect("write fixture file");
1409        let err = run(
1410            tmp.path(),
1411            tmp.path(),
1412            tmp.path(),
1413            PretrainMode::Finetune,
1414            Some(5.0e-5),
1415            10,
1416            Some(2),
1417            2,
1418            4,
1419            5,
1420            42,
1421            Some(2.2),
1422            50257,
1423            true,
1424            "cpu",
1425            Some(&bad),
1426            true,
1427        )
1428        .expect_err("invalid magic bytes must be rejected");
1429        match err {
1430            CliError::ValidationFailed(msg) => {
1431                assert!(
1432                    msg.contains("FALSIFY-APR-PRETRAIN-INIT-004"),
1433                    "msg must cite falsifier id: {msg}"
1434                );
1435                assert!(
1436                    msg.contains("not a valid APR file"),
1437                    "msg must describe magic mismatch: {msg}"
1438                );
1439            }
1440            other => panic!("unexpected error: {other:?}"),
1441        }
1442    }
1443
1444    /// FALSIFY-APR-PRETRAIN-INIT-004: empty file (read_exact fails on 4 bytes).
1445    #[test]
1446    fn pretrain_init_empty_file_errors() {
1447        let tmp = TempDir::new().expect("tempdir");
1448        let empty = tmp.path().join("empty.apr");
1449        std::fs::write(&empty, b"").expect("write empty fixture");
1450        let err = run(
1451            tmp.path(),
1452            tmp.path(),
1453            tmp.path(),
1454            PretrainMode::Finetune,
1455            Some(5.0e-5),
1456            10,
1457            Some(2),
1458            2,
1459            4,
1460            5,
1461            42,
1462            Some(2.2),
1463            50257,
1464            true,
1465            "cpu",
1466            Some(&empty),
1467            true,
1468        )
1469        .expect_err("empty file must be rejected (cannot contain magic bytes)");
1470        assert!(matches!(err, CliError::ValidationFailed(_)));
1471    }
1472
1473    /// §50.4 step 5f.4: a magic-byte-valid but metadata-bogus APR file
1474    /// MUST be rejected at the architecture-extraction step, not silently
1475    /// fall back to random init. The error must clearly cite the
1476    /// architecture-extraction failure (not the legacy "not yet wired"
1477    /// guard, which was retired when the wireup landed). This drift-prevention
1478    /// test pins the new fail-closed semantics (see the sketch after this test).
1479    #[test]
1480    fn pretrain_init_valid_magic_but_bogus_metadata_fails_at_arch_extraction() {
1481        let tmp = TempDir::new().expect("tempdir");
1482        let valid = tmp.path().join("v2-valid-magic-bogus-metadata.apr");
1483        // APR\0 magic + padding; passes validate_init_apr_path but
1484        // read_apr_architecture (which reads the v2 header) will return None.
1485        std::fs::write(&valid, b"APR\x00\x00\x00\x00\x00\x00\x00\x00\x00")
1486            .expect("write fixture file");
1487        let err = run(
1488            tmp.path(),
1489            tmp.path(),
1490            tmp.path(),
1491            PretrainMode::Finetune,
1492            Some(5.0e-5),
1493            10,
1494            Some(2),
1495            2,
1496            4,
1497            5,
1498            42,
1499            Some(2.2),
1500            50257,
1501            true,
1502            "cpu",
1503            Some(&valid),
1504            true,
1505        )
1506        .expect_err("bogus metadata must NOT silently random-init");
1507        match err {
1508            CliError::ValidationFailed(msg) => {
1509                assert!(
1510                    !msg.contains("not yet wired"),
1511                    "the legacy step-5-partial guard must be retired: {msg}"
1512                );
1513                // The concrete error may come from the read_apr_architecture
1514                // failure or from a downstream layer; either is acceptable as
1515                // long as we DON'T silently load random init.
1516            }
1517            other => panic!("unexpected error: {other:?}"),
1518        }
1519    }
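
    // Illustrative sketch (added for exposition) of the fail-closed branch the
    // test above pins. `read_apr_architecture` returning an Option is taken
    // from the test's comments; the stand-in type and error text here are
    // assumptions, not the real implementation.
    #[allow(dead_code)]
    struct SketchArch;

    #[allow(dead_code)]
    fn sketch_fail_closed(
        arch: Option<SketchArch>,
        init_path: &std::path::Path,
    ) -> std::result::Result<SketchArch, CliError> {
        // Fail closed: a recognised APR header whose architecture metadata
        // cannot be read is an error, never a silent fall-through to random
        // initialisation.
        arch.ok_or_else(|| {
            CliError::ValidationFailed(format!(
                "--init {}: APR magic found but architecture metadata could not be read (sketch)",
                init_path.display()
            ))
        })
    }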
1520
1521    /// Pin v1 magic (APRN) acceptance — `validate_init_apr_path` alone
1522    /// (decoupled from architecture extraction) returns Ok for both APR\0
1523    /// and APRN magic bytes. Architecture extraction is a separate step.
1524    #[test]
1525    fn pretrain_init_v1_magic_aprn_passes_validate_init_apr_path() {
1526        let tmp = TempDir::new().expect("tempdir");
1527        let v1 = tmp.path().join("v1-aprn.apr");
1528        std::fs::write(&v1, b"APRN\x00\x00\x00\x00").expect("write fixture file");
1529        let result = validate_init_apr_path(&v1);
1530        assert!(
1531            result.is_ok(),
1532            "APRN magic must pass validate_init_apr_path; got {result:?}"
1533        );
1534    }
1535}