use crate::error::{CliError, Result};
use crate::output;
use clap::ValueEnum;
use colored::Colorize;
use entrenar::models::llama_370m::{
    assert_tokenizer_vocab_matches_model, assert_tokenizer_vocab_within_model_bound,
    Llama370MConfig,
};
use entrenar::train::device::{resolve_device, Device};
use entrenar::train::pretrain::{
    CheckpointFn, LinearDecaySynthetic, PretrainAbort, PretrainConfig, PretrainLoop, RunStatus,
    ScriptedVal, StepFn, StepMetrics, TrainingRegime, ValFn,
};
use entrenar::train::pretrain_real::{
    build_shared_trainer, build_shared_trainer_with_init, AprCheckpointFn, RealStepFn, RealValFn,
};
use entrenar::train::shard_reader::ShardBatchIter;
use entrenar::train::transformer_trainer::LMBatch;
use entrenar::transformer::TransformerConfig;
use std::path::Path;

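/// Number of leading batches reserved from the shard stream as the held-out
/// validation set before the iterator is handed to the training loop.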
const HELD_OUT_BATCHES: usize = 16;

pub(crate) const FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG: &str =
    "FALSIFY-APR-PRETRAIN-INIT-CUDA-001: --init is not yet wired for --device cuda \
     (step 5f.5 follow-up); use --device cpu OR omit --init for from-scratch CUDA training.";

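/// Training mode selected on the CLI; each variant carries its own
/// hyperparameter defaults (see `mode_defaults`).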
#[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
pub enum PretrainMode {
    /// Continue training existing weights (low LR, short warmup).
    Finetune,
    /// Random-init training with an explicit vocab size (higher LR, long warmup).
    FromScratch,
}

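/// Effective hyperparameters after merging mode defaults with CLI overrides.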
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ResolvedHp {
    pub regime: TrainingRegime,
    pub lr_max: f32,
    pub warmup_steps: usize,
    pub target_val_loss: f32,
}

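/// Resolve the effective hyperparameters for `mode`: per-mode defaults come
/// from the match arms below, and any explicit CLI override wins.
///
/// Defaults, read straight from the code:
/// - `Finetune`: lr_max 5e-5, warmup 100 steps, target_val_loss 2.2
/// - `FromScratch`: lr_max 3e-4, warmup 1000 steps, target_val_loss 3.0
///
/// A minimal sketch (mirrors the `mode_finetune_is_default_and_matches_contract`
/// test below):
///
/// ```ignore
/// let hp = mode_defaults(PretrainMode::Finetune, 50_257, None, None, None);
/// assert_eq!(hp.warmup_steps, 100); // finetune default, no override given
/// ```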
pub(crate) fn mode_defaults(
    mode: PretrainMode,
    vocab_size: u32,
    lr_override: Option<f32>,
    warmup_override: Option<usize>,
    target_override: Option<f32>,
) -> ResolvedHp {
    let (regime, lr_def, warmup_def, target_def) = match mode {
        PretrainMode::Finetune => (TrainingRegime::Finetune, 5.0e-5, 100, 2.2),
        PretrainMode::FromScratch => (
            TrainingRegime::FromScratch { vocab_size },
            3.0e-4,
            1000,
            3.0,
        ),
    };
    ResolvedHp {
        regime,
        lr_max: lr_override.unwrap_or(lr_def),
        warmup_steps: warmup_override.unwrap_or(warmup_def),
        target_val_loss: target_override.unwrap_or(target_def),
    }
}

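/// Entry point for `apr pretrain`: resolves the device, validates any
/// `--init` checkpoint and extracts its architecture, merges hyperparameter
/// defaults, then dispatches to the synthetic or real driver and maps
/// training aborts onto CLI errors.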
#[allow(clippy::too_many_arguments)]
pub(crate) fn run(
    dataset: &Path,
    tokenizer: &Path,
    run_dir: &Path,
    mode: PretrainMode,
    lr: Option<f32>,
    num_steps: usize,
    warmup_steps: Option<usize>,
    batch_size: usize,
    seq_length: usize,
    steps_per_epoch: usize,
    seed: u64,
    target_val_loss: Option<f32>,
    vocab_size: u32,
    synthetic: bool,
    device: &str,
    init: Option<&Path>,
    json_output: bool,
) -> Result<()> {
    let resolved_device =
        resolve_device(device).map_err(|e| CliError::ValidationFailed(e.to_string()))?;

    let init_arch: Option<TransformerConfig> = if let Some(init_path) = init {
        validate_init_apr_path(init_path)?;
        Some(
            crate::commands::model_config::read_apr_architecture(init_path).ok_or_else(|| {
                CliError::ValidationFailed(format!(
                    "FALSIFY-APR-PRETRAIN-INIT-005: --init APR file at {} has missing or invalid \
                     architecture metadata (hidden_size, num_heads, num_layers, vocab_size, etc). \
                     Cannot extract TransformerConfig per apr-pretrain-arch-polymorphic-v1 \
                     §arch_extraction_signature.",
                    init_path.display()
                ))
            })?,
        )
    } else {
        None
    };

    let hp = mode_defaults(mode, vocab_size, lr, warmup_steps, target_val_loss);

    if hp.target_val_loss <= 0.0 {
        return Err(CliError::ValidationFailed(format!(
            "target_val_loss must be positive, got {}",
            hp.target_val_loss
        )));
    }
    if num_steps == 0 {
        return Err(CliError::ValidationFailed(
            "num_steps must be > 0".to_string(),
        ));
    }
    if steps_per_epoch == 0 {
        return Err(CliError::ValidationFailed(
            "steps_per_epoch must be > 0".to_string(),
        ));
    }

    let config = PretrainConfig {
        dataset_path: dataset.to_path_buf(),
        tokenizer_dir: tokenizer.to_path_buf(),
        run_dir: run_dir.to_path_buf(),
        lr_max: hp.lr_max,
        lr_min: (hp.lr_max * 1.0e-2).max(1.0e-7),
        warmup_steps: hp.warmup_steps,
        total_steps: num_steps,
        batch_size,
        seq_length,
        steps_per_epoch,
        seed,
        grad_clip: 1.0,
        weight_decay: 0.01,
        target_val_loss: hp.target_val_loss,
        patience_epochs: 5,
        min_epochs_before_early_stop: 3,
        regime: hp.regime,
    };

    if !json_output {
        print_header(&config);
        output::kv(" Device", resolved_device.to_string());
        println!();
    }

    let status = if synthetic {
        drive_synthetic(
            config.clone(),
            num_steps,
            steps_per_epoch,
            hp.target_val_loss,
            json_output,
        )?
    } else {
        drive_real(
            config.clone(),
            dataset,
            hp.lr_max,
            seq_length,
            batch_size,
            seed,
            resolved_device,
            json_output,
            init_arch.as_ref(),
            init,
        )?
    };

    match status {
        RunStatus::Aborted(abort) => Err(abort_to_err(&abort)),
        RunStatus::Ok { .. } | RunStatus::EarlyStop { .. } => Ok(()),
    }
}

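/// Synthetic driver: a linearly decaying step loss plus a scripted
/// validation curve that converges toward `target_val_loss`, exercising the
/// full loop (epochs, early stop, abort guards) without real compute.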
fn drive_synthetic(
    config: PretrainConfig,
    num_steps: usize,
    steps_per_epoch: usize,
    target_val_loss: f32,
    json_output: bool,
) -> Result<RunStatus> {
    let step_fn = LinearDecaySynthetic {
        start_loss: (target_val_loss * 2.0).max(1.5),
        decay_per_step: (target_val_loss * 0.01).max(1.0e-4),
        grad_norm: 0.8,
    };
    let num_epochs = num_steps.div_ceil(steps_per_epoch);
    let mut sequence = Vec::with_capacity(num_epochs + 2);
    let start_val = (target_val_loss * 1.8).max(3.0);
    for i in 0..(num_epochs + 2) {
        let t = i as f32 / (num_epochs.max(1) as f32);
        sequence.push(target_val_loss + (start_val - target_val_loss) * (1.0 - t).max(0.0));
    }
    let val_fn = ScriptedVal { sequence };
    run_and_report(config, step_fn, val_fn, None, json_output)
}

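/// Cheap structural validation of an `--init` checkpoint before any heavy
/// work: the file must exist, be readable, and start with a known APR magic
/// (v2 `APR\0` = `41 50 52 00`, legacy v1 `APRN` = `41 50 52 4E`). Deeper
/// metadata checks happen later at arch extraction (INIT-005).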
fn validate_init_apr_path(path: &Path) -> Result<()> {
    use std::io::Read;

    let mut file = std::fs::File::open(path).map_err(|e| {
        CliError::ValidationFailed(format!(
            "FALSIFY-APR-PRETRAIN-INIT-003: --init path does not exist or is unreadable: {} ({e})",
            path.display()
        ))
    })?;
    let mut magic = [0u8; 4];
    file.read_exact(&mut magic).map_err(|e| {
        CliError::ValidationFailed(format!(
            "FALSIFY-APR-PRETRAIN-INIT-004: --init file too short to contain APR magic bytes: {} ({e})",
            path.display()
        ))
    })?;
    const APR_MAGIC_V2: [u8; 4] = [0x41, 0x50, 0x52, 0x00]; // "APR\0"
    const APR_MAGIC_V1: [u8; 4] = [0x41, 0x50, 0x52, 0x4E]; // "APRN"
    if magic != APR_MAGIC_V2 && magic != APR_MAGIC_V1 {
        return Err(CliError::ValidationFailed(format!(
            "FALSIFY-APR-PRETRAIN-INIT-004: --init file is not a valid APR file (magic={:02X?}, expected {:02X?} or {:02X?}): {}",
            magic, APR_MAGIC_V2, APR_MAGIC_V1, path.display()
        )));
    }
    Ok(())
}

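/// GATE-ARCH-370M-011 pre-flight: compare the size of the tokenizer's
/// `vocab.json` against the target model vocab. From scratch the sizes must
/// match exactly; under `--init` (polymorphic arch) the tokenizer only has
/// to fit within the declared embedding rows, since HF checkpoints may
/// reserve unused vocab slots.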
fn preflight_tokenizer_vocab_matches_target(
    tokenizer_dir: &Path,
    target_vocab_size: usize,
    init_is_some: bool,
) -> Result<()> {
    let vocab_path = tokenizer_dir.join("vocab.json");
    let vocab_json = std::fs::read_to_string(&vocab_path).map_err(|e| {
        CliError::ValidationFailed(format!(
            "GATE-ARCH-370M-011 pre-flight: cannot read {} ({e})",
            vocab_path.display()
        ))
    })?;
    let vocab: serde_json::Map<String, serde_json::Value> = serde_json::from_str(&vocab_json)
        .map_err(|e| {
            CliError::ValidationFailed(format!(
                "GATE-ARCH-370M-011 pre-flight: {} is not a valid vocab.json: {e}",
                vocab_path.display()
            ))
        })?;
    if init_is_some {
        assert_tokenizer_vocab_within_model_bound(vocab.len(), target_vocab_size)
            .map_err(CliError::ValidationFailed)
    } else {
        assert_tokenizer_vocab_matches_model(vocab.len(), target_vocab_size)
            .map_err(CliError::ValidationFailed)
    }
}

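/// Real-compute driver: runs the tokenizer pre-flight against the `--init`
/// architecture (or the Llama-370M default), reserves `HELD_OUT_BATCHES`
/// validation batches from the front of the shard stream, then dispatches to
/// the CPU or CUDA backend. `--init` plus CUDA is a fail-fast error until
/// step 5f.5 wires it up.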
#[allow(clippy::too_many_arguments)]
fn drive_real(
    config: PretrainConfig,
    dataset: &Path,
    lr: f32,
    seq_length: usize,
    batch_size: usize,
    seed: u64,
    device: Device,
    json_output: bool,
    init_arch: Option<&TransformerConfig>,
    init_path: Option<&Path>,
) -> Result<RunStatus> {
    let target_vocab = init_arch
        .map(|cfg| cfg.vocab_size)
        .unwrap_or(Llama370MConfig::VOCAB_SIZE);
    preflight_tokenizer_vocab_matches_target(
        &config.tokenizer_dir,
        target_vocab,
        init_arch.is_some(),
    )?;

    let mut iter = ShardBatchIter::new(dataset, batch_size, seq_length, 0, 0)
        .map_err(|e| {
            CliError::ValidationFailed(format!(
                "dataset shard iterator init failed: {e} (path={})",
                dataset.display()
            ))
        })?
        .with_wrap_around(true);

    let mut held_out: Vec<LMBatch> = Vec::with_capacity(HELD_OUT_BATCHES);
    for _ in 0..HELD_OUT_BATCHES {
        match iter.next() {
            Some(b) => held_out.push(b),
            None => break,
        }
    }
    if held_out.is_empty() {
        return Err(CliError::ValidationFailed(format!(
            "dataset {} is too small to reserve any held-out batches",
            dataset.display()
        )));
    }

    if device.is_cuda() {
        if init_arch.is_some() {
            return Err(CliError::ValidationFailed(
                FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG.to_string(),
            ));
        }
        drive_real_cuda(config, iter, held_out, lr, seq_length, seed, json_output)
    } else {
        drive_real_cpu(
            config,
            iter,
            held_out,
            lr,
            seq_length,
            seed,
            json_output,
            init_arch,
            init_path,
        )
    }
}

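/// CPU backend: builds the shared trainer (optionally initialised from the
/// `--init` APR checkpoint) and wires the real step/val/checkpoint callbacks
/// into the shared loop.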
#[allow(clippy::too_many_arguments)]
fn drive_real_cpu(
    config: PretrainConfig,
    iter: ShardBatchIter,
    held_out: Vec<LMBatch>,
    lr: f32,
    seq_length: usize,
    seed: u64,
    json_output: bool,
    init_arch: Option<&TransformerConfig>,
    init_path: Option<&Path>,
) -> Result<RunStatus> {
    let trainer = if init_arch.is_some() || init_path.is_some() {
        build_shared_trainer_with_init(lr, seq_length, seed, init_arch, init_path)
            .map_err(CliError::ValidationFailed)?
    } else {
        build_shared_trainer(lr, seq_length, seed)
    };
    let step_fn = RealStepFn::new(trainer.clone(), Box::new(iter));
    let val_fn = RealValFn::new(trainer.clone(), held_out);
    let ckpt: Box<dyn CheckpointFn> = Box::new(AprCheckpointFn::new(
        trainer,
        "llama-370m-pretrain",
        "LlamaForCausalLM",
    ));
    run_and_report(config, step_fn, val_fn, Some(ckpt), json_output)
}

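/// CUDA backend, compiled only with `--features cuda`: same shape as the CPU
/// driver but with the CUDA trainer and CUDA step/val/checkpoint callbacks.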
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn drive_real_cuda(
    config: PretrainConfig,
    iter: ShardBatchIter,
    held_out: Vec<LMBatch>,
    lr: f32,
    seq_length: usize,
    seed: u64,
    json_output: bool,
) -> Result<RunStatus> {
    use entrenar::train::pretrain_real_cuda::{
        build_shared_cuda_trainer, CudaAprCheckpointFn, CudaRealStepFn, CudaRealValFn,
    };
    let trainer = build_shared_cuda_trainer(lr, seq_length, seed).map_err(|e| {
        CliError::ValidationFailed(format!(
            "GATE-GPUTRAIN-002: CUDA trainer allocation failed: {e}. \
             See contracts/entrenar/gpu-training-backend-v1.yaml and \
             memory/feedback_cuda_feature_footgun.md — this path is \
             only reachable when the binary was built with `--features cuda`.",
        ))
    })?;
    let step_fn = CudaRealStepFn::new(trainer.clone(), Box::new(iter));
    let val_fn = CudaRealValFn::new(trainer.clone(), held_out);
    let ckpt: Box<dyn CheckpointFn> = Box::new(CudaAprCheckpointFn::new(
        trainer,
        "llama-370m-pretrain",
        "LlamaForCausalLM",
    ));
    run_and_report(config, step_fn, val_fn, Some(ckpt), json_output)
}

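/// Stub for builds without the `cuda` feature: requesting `--device cuda`
/// is a hard error with rebuild instructions, never a silent CPU fallback.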
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
fn drive_real_cuda(
    _config: PretrainConfig,
    _iter: ShardBatchIter,
    _held_out: Vec<LMBatch>,
    _lr: f32,
    _seq_length: usize,
    _seed: u64,
    _json_output: bool,
) -> Result<RunStatus> {
    Err(CliError::ValidationFailed(
        "GATE-GPUTRAIN-002: --device cuda was requested but this `apr` \
         binary was built WITHOUT the `cuda` feature. \
         Rebuild with `cargo build --release --features cuda` or use \
         `--device cpu`. See memory/feedback_cuda_feature_footgun.md \
         (contract gpu-training-backend-v1 / task #132 Phase 2)."
            .into(),
    ))
}

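/// Shared tail for every driver: assemble the `PretrainLoop`, attach the
/// optional checkpoint callback, run to completion, and report the result.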
fn run_and_report<S: StepFn, V: ValFn>(
    config: PretrainConfig,
    step_fn: S,
    val_fn: V,
    checkpoint_fn: Option<Box<dyn CheckpointFn>>,
    json_output: bool,
) -> Result<RunStatus> {
    let mut loop_ = PretrainLoop::new(config, step_fn, val_fn);
    if let Some(ckpt) = checkpoint_fn {
        loop_ = loop_.with_checkpoint_fn(ckpt);
    }
    let status = loop_.run();
    report(&status, &loop_, json_output)?;
    Ok(status)
}

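/// Map a `PretrainAbort` to a CLI error that cites the specific gate that
/// fired (GATE-TRAIN-005/007/008) so the failure is traceable to its
/// contract.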
fn abort_to_err(abort: &PretrainAbort) -> CliError {
    match abort {
        PretrainAbort::Divergence { .. } | PretrainAbort::DivergenceAtEpochZero { .. } => {
            CliError::ValidationFailed(format!(
                "GATE-TRAIN-005 ship-blocker fired: {abort}. See \
                 contracts/training-loop-pretrain-v1.yaml and \
                 memory/project_ship_two_001_model1_qlora_divergence.md"
            ))
        }
        PretrainAbort::NumericalInstability { .. } => {
            CliError::ValidationFailed(format!("GATE-TRAIN-007 NaN/Inf guard fired: {abort}"))
        }
        PretrainAbort::ThroughputOutOfRange { .. } => CliError::ValidationFailed(format!(
            "GATE-TRAIN-008 throughput-range guard fired: {abort}"
        )),
    }
}

fn print_header(cfg: &PretrainConfig) {
    output::header("apr pretrain — SHIP-TWO-001 MODEL-2 training loop");
    println!();
    output::section("Configuration");
    output::kv(" Dataset", cfg.dataset_path.display().to_string());
    output::kv(" Tokenizer", cfg.tokenizer_dir.display().to_string());
    output::kv(" Run dir", cfg.run_dir.display().to_string());
    output::kv(" LR max", format!("{:.2e}", cfg.lr_max));
    output::kv(" Total steps", cfg.total_steps.to_string());
    output::kv(" Warmup steps", cfg.warmup_steps.to_string());
    output::kv(
        " Batch × seq",
        format!("{} × {}", cfg.batch_size, cfg.seq_length),
    );
    output::kv(" Steps / epoch", cfg.steps_per_epoch.to_string());
    output::kv(" Seed", cfg.seed.to_string());
    output::kv(" Target val_loss", format!("{:.2}", cfg.target_val_loss));
    println!();
}

fn report<S: StepFn, V: ValFn>(
    status: &RunStatus,
    loop_: &PretrainLoop<S, V>,
    json_output: bool,
) -> Result<()> {
    if json_output {
        let report = PretrainReport::from(status, loop_);
        let json = serde_json::to_string_pretty(&report)
            .map_err(|e| CliError::InvalidFormat(e.to_string()))?;
        println!("{json}");
        return Ok(());
    }

    output::section("Run Result");
    match status {
        RunStatus::Ok {
            final_val_loss,
            epochs_completed,
        } => {
            println!(
                " {} CONVERGED final val_loss={:.4} after {} epoch(s)",
                "OK".green().bold(),
                final_val_loss,
                epochs_completed
            );
        }
        RunStatus::EarlyStop {
            best_val_loss,
            epochs_completed,
        } => {
            println!(
                " {} EARLY_STOP best val_loss={:.4} after {} epoch(s)",
                "OK".yellow().bold(),
                best_val_loss,
                epochs_completed
            );
        }
        RunStatus::Aborted(abort) => {
            println!(" {} ABORTED {}", "FAIL".red().bold(), abort);
        }
    }
    output::kv(" Steps recorded", loop_.step_metrics().len().to_string());
    output::kv(
        " Epochs recorded",
        loop_.epoch_artifacts().len().to_string(),
    );
    println!();
    Ok(())
}

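/// JSON payload for `--json` runs: final status plus the full per-step
/// metrics and validation-loss history recorded by the loop.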
#[derive(serde::Serialize)]
struct PretrainReport {
    status: String,
    detail: Option<String>,
    final_val_loss: Option<f32>,
    epochs_completed: usize,
    steps_recorded: usize,
    val_loss_history: Vec<f32>,
    per_step_metrics: Vec<StepMetrics>,
}

impl PretrainReport {
    fn from<S: StepFn, V: ValFn>(status: &RunStatus, loop_: &PretrainLoop<S, V>) -> Self {
        let (status_name, detail, final_val_loss, epochs_completed) = match status {
            RunStatus::Ok {
                final_val_loss,
                epochs_completed,
            } => (
                "OK".to_string(),
                None,
                Some(*final_val_loss),
                *epochs_completed,
            ),
            RunStatus::EarlyStop {
                best_val_loss,
                epochs_completed,
            } => (
                "EARLY_STOP".to_string(),
                None,
                Some(*best_val_loss),
                *epochs_completed,
            ),
            RunStatus::Aborted(abort) => (
                "ABORTED".to_string(),
                Some(abort.to_string()),
                None,
                loop_.epoch_artifacts().len(),
            ),
        };
        PretrainReport {
            status: status_name,
            detail,
            final_val_loss,
            epochs_completed,
            steps_recorded: loop_.step_metrics().len(),
            val_loss_history: loop_.val_loss_history().to_vec(),
            per_step_metrics: loop_.step_metrics().to_vec(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn stage_vocab_json(dir: &std::path::Path, n: usize) {
        std::fs::create_dir_all(dir).expect("mkdir tokenizer dir");
        let mut obj = serde_json::Map::with_capacity(n);
        for i in 0..n {
            obj.insert(format!("t{i}"), serde_json::Value::from(i as u64));
        }
        let json = serde_json::to_string(&obj).expect("serialize");
        std::fs::write(dir.join("vocab.json"), json).expect("write vocab.json");
    }

    #[test]
    fn preflight_accepts_matching_vocab() {
        let tmp = TempDir::new().expect("tempdir");
        stage_vocab_json(tmp.path(), Llama370MConfig::VOCAB_SIZE);
        preflight_tokenizer_vocab_matches_target(tmp.path(), Llama370MConfig::VOCAB_SIZE, false)
            .expect("matching vocab must pass GATE-ARCH-370M-011");
    }

    #[test]
    fn preflight_rejects_tokenizer_vocab_mismatch() {
        let tmp = TempDir::new().expect("tempdir");
        let mismatch = Llama370MConfig::VOCAB_SIZE - 1;
        stage_vocab_json(tmp.path(), mismatch);
        let err =
            preflight_tokenizer_vocab_matches_target(tmp.path(), Llama370MConfig::VOCAB_SIZE, false)
                .expect_err("tokenizer/model vocab mismatch must be rejected");
        match err {
            CliError::ValidationFailed(msg) => {
                assert!(
                    msg.contains("GATE-ARCH-370M-011"),
                    "msg must cite gate: {msg}"
                );
                assert!(
                    msg.contains(&mismatch.to_string()),
                    "msg must name tokenizer vocab: {msg}"
                );
                assert!(
                    msg.contains(&Llama370MConfig::VOCAB_SIZE.to_string()),
                    "msg must name model vocab: {msg}"
                );
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn preflight_rejects_missing_vocab_json() {
        let tmp = TempDir::new().expect("tempdir");
        let err =
            preflight_tokenizer_vocab_matches_target(tmp.path(), Llama370MConfig::VOCAB_SIZE, false)
                .expect_err("missing vocab.json must be rejected");
        match err {
            CliError::ValidationFailed(msg) => {
                assert!(
                    msg.contains("GATE-ARCH-370M-011"),
                    "msg must cite gate: {msg}"
                );
                assert!(
                    msg.contains("cannot read"),
                    "msg must name I/O failure: {msg}"
                );
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn preflight_qwen_vocab_passes_with_qwen_target() {
        const QWEN2_VOCAB_SIZE: usize = 151_936;
        let tmp = TempDir::new().expect("tempdir");
        stage_vocab_json(tmp.path(), QWEN2_VOCAB_SIZE);
        preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN2_VOCAB_SIZE, true).expect(
            "Qwen tokenizer (151_936) MUST pass preflight when target is Qwen-shaped — \
             this is the load-bearing claim of §49 fine-tune from a Qwen2.5 init checkpoint",
        );
    }

    #[test]
    fn preflight_qwen_vocab_fails_with_llama_target() {
        const QWEN2_VOCAB_SIZE: usize = 151_936;
        let tmp = TempDir::new().expect("tempdir");
        stage_vocab_json(tmp.path(), QWEN2_VOCAB_SIZE);
        let err = preflight_tokenizer_vocab_matches_target(
            tmp.path(),
            Llama370MConfig::VOCAB_SIZE,
            false,
        )
        .expect_err(
            "Qwen tokenizer (151_936) MUST FAIL preflight when target is Llama370M (50_257) — \
             silent-pass would corrupt training",
        );
        match err {
            CliError::ValidationFailed(msg) => {
                assert!(
                    msg.contains(&QWEN2_VOCAB_SIZE.to_string()),
                    "msg must name Qwen vocab size 151_936: {msg}"
                );
                assert!(
                    msg.contains(&Llama370MConfig::VOCAB_SIZE.to_string()),
                    "msg must name target Llama vocab size 50_257: {msg}"
                );
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn preflight_qwen_reserved_slots_pass_under_polymorphic_init() {
        const QWEN_TOKENIZER_EFFECTIVE: usize = 151_665;
        const QWEN_DECLARED_VOCAB: usize = 151_936;
        let tmp = TempDir::new().expect("tempdir");
        stage_vocab_json(tmp.path(), QWEN_TOKENIZER_EFFECTIVE);

        preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN_DECLARED_VOCAB, true).expect(
            "FALSIFY-APR-PRETRAIN-ARCH-009: HF reserved-slot tokenizer (151_665 ≤ 151_936) \
             MUST pass preflight under polymorphic init path (§55 relaxed bound)",
        );

        let err =
            preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN_DECLARED_VOCAB, false)
                .expect_err(
                    "FALSIFY-APR-PRETRAIN-ARCH-009 dual: from-scratch path MUST keep strict ==",
                );
        match err {
            CliError::ValidationFailed(msg) => {
                assert!(
                    msg.contains("GATE-ARCH-370M-011")
                        && msg.contains(&QWEN_TOKENIZER_EFFECTIVE.to_string())
                        && msg.contains(&QWEN_DECLARED_VOCAB.to_string()),
                    "strict-mode error must name gate + both sizes: {msg}"
                );
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn preflight_oversized_tokenizer_rejected_even_under_polymorphic_init() {
        const QWEN_DECLARED_VOCAB: usize = 151_936;
        let oversized = QWEN_DECLARED_VOCAB + 100;
        let tmp = TempDir::new().expect("tempdir");
        stage_vocab_json(tmp.path(), oversized);

        let err = preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN_DECLARED_VOCAB, true)
            .expect_err(
                "FALSIFY-APR-PRETRAIN-ARCH-010: oversized tokenizer MUST fail-fast even under \
                 polymorphic init (OOB safety; relaxed bound is ≤ not <)",
            );
        match err {
            CliError::ValidationFailed(msg) => {
                assert!(
                    msg.contains("RELAXED") && msg.contains("OOB"),
                    "polymorphic-mode error must cite RELAXED + OOB: {msg}"
                );
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn drive_real_cuda_init_path_fail_fasts_with_falsifier_citation() {
        let msg = FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG;
        assert!(
            msg.contains("FALSIFY-APR-PRETRAIN-INIT-CUDA-001"),
            "error message MUST cite the falsifier id (auditability): {msg}"
        );
        assert!(
            msg.contains("not yet wired for --device cuda"),
            "error message MUST contain the canonical 'not yet wired' \
             phrase so operators recognize the §50.4 step 5f.5 gap: {msg}"
        );
        assert!(
            msg.contains("step 5f.5 follow-up"),
            "error message MUST reference the 5f.5 follow-up so future \
             agents know which step retires this guard: {msg}"
        );
        assert!(
            msg.contains("--device cpu") && msg.contains("OR omit --init"),
            "error message MUST suggest both workarounds (CPU device OR \
             omit --init for from-scratch CUDA): {msg}"
        );
    }

    #[test]
    fn synthetic_pretrain_end_to_end_happy_path() {
        let tmp = TempDir::new().expect("tempdir");
        let dataset = tmp.path().join("data.jsonl");
        let tokenizer = tmp.path().join("tok");
        let run_dir = tmp.path().join("run");

        let result = run(
            &dataset,
            &tokenizer,
            &run_dir,
            PretrainMode::Finetune,
            Some(5.0e-5),
            25,
            Some(5),
            2,
            4,
            5,
            42,
            Some(2.2),
            50257,
            true,
            "cpu",
            None,
            true,
        );
        assert!(
            result.is_ok(),
            "synthetic pretrain end-to-end must succeed: got {result:?}"
        );
    }

    #[test]
    fn real_mode_empty_dataset_dir_errors() {
        let tmp = TempDir::new().expect("tempdir");
        let tok_dir = tmp.path().join("tok");
        stage_vocab_json(&tok_dir, Llama370MConfig::VOCAB_SIZE);
        let err = run(
            tmp.path(),
            &tok_dir,
            tmp.path(),
            PretrainMode::Finetune,
            Some(5.0e-5),
            10,
            Some(2),
            2,
            4,
            5,
            42,
            Some(2.2),
            50257,
            false,
            "cpu",
            None,
            true,
        )
        .expect_err("empty dataset dir must fail to initialise the shard iterator");
        match err {
            CliError::ValidationFailed(msg) => {
                assert!(
                    msg.contains("shard iterator init failed"),
                    "unexpected message: {msg}"
                );
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn invalid_target_val_loss_rejected() {
        let tmp = TempDir::new().expect("tempdir");
        let err = run(
            tmp.path(),
            tmp.path(),
            tmp.path(),
            PretrainMode::Finetune,
            Some(5.0e-5),
            10,
            Some(2),
            2,
            4,
            5,
            42,
            Some(-1.0),
            50257,
            true,
            "cpu",
            None,
            true,
        )
        .expect_err("negative target_val_loss must be rejected");
        assert!(matches!(err, CliError::ValidationFailed(_)));
    }

    #[test]
    fn mode_finetune_is_default_and_matches_contract() {
        let hp = mode_defaults(PretrainMode::Finetune, 50257, None, None, None);
        assert_eq!(hp.regime, TrainingRegime::Finetune);
        assert!(
            (hp.lr_max - 5.0e-5).abs() < 1.0e-12,
            "lr_max={} must equal finetune default 5e-5",
            hp.lr_max
        );
        assert_eq!(hp.warmup_steps, 100);
        assert!(
            (hp.target_val_loss - 2.2).abs() < 1.0e-6,
            "target_val_loss={} must equal finetune default 2.2",
            hp.target_val_loss
        );
    }

    #[test]
    fn mode_from_scratch_applies_all_four_defaults() {
        let hp = mode_defaults(PretrainMode::FromScratch, 50257, None, None, None);
        assert_eq!(hp.regime, TrainingRegime::FromScratch { vocab_size: 50257 });
        assert!(
            (hp.lr_max - 3.0e-4).abs() < 1.0e-12,
            "lr_max={} must equal from_scratch default 3e-4",
            hp.lr_max
        );
        assert_eq!(hp.warmup_steps, 1000);
        assert!(
            (hp.target_val_loss - 3.0).abs() < 1.0e-6,
            "target_val_loss={} must equal from_scratch default 3.0",
            hp.target_val_loss
        );
    }

    #[test]
    fn mode_from_scratch_honors_explicit_lr_override() {
        let hp = mode_defaults(PretrainMode::FromScratch, 50257, Some(1.0e-4), None, None);
        assert_eq!(hp.regime, TrainingRegime::FromScratch { vocab_size: 50257 });
        assert!(
            (hp.lr_max - 1.0e-4).abs() < 1.0e-12,
            "lr_max={} must equal explicit override 1e-4",
            hp.lr_max
        );
        assert_eq!(hp.warmup_steps, 1000);
        assert!((hp.target_val_loss - 3.0).abs() < 1.0e-6);
    }

    fn parse_pretrain_synthetic(extra: &[&str]) -> bool {
        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
        std::thread::Builder::new()
            .stack_size(16 * 1024 * 1024)
            .spawn(move || {
                use clap::Parser;
                let mut argv: Vec<String> = vec![
                    "apr".to_string(),
                    "pretrain".to_string(),
                    "--dataset".to_string(),
                    "/tmp/_gate_train_010/ds".to_string(),
                    "--tokenizer".to_string(),
                    "/tmp/_gate_train_010/tok".to_string(),
                    "--run-dir".to_string(),
                    "/tmp/_gate_train_010/run".to_string(),
                ];
                argv.extend(extra);
                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
                match *cli.command {
                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
                        synthetic,
                        ..
                    }) => synthetic,
                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
                }
            })
            .expect("spawn parse thread")
            .join()
            .expect("parse thread must not panic")
    }

    #[test]
    fn cli_pretrain_defaults_to_real_compute() {
        assert!(
            !parse_pretrain_synthetic(&[]),
            "INV-TRAIN-010: `apr pretrain` (no --synthetic) must parse to synthetic=false"
        );
    }

    #[test]
    fn cli_pretrain_synthetic_flag_routes_to_synthetic() {
        assert!(
            parse_pretrain_synthetic(&["--synthetic"]),
            "INV-TRAIN-010: `apr pretrain --synthetic` must parse to synthetic=true"
        );
    }

    fn parse_pretrain_device(extra: &[&str]) -> String {
        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
        std::thread::Builder::new()
            .stack_size(16 * 1024 * 1024)
            .spawn(move || {
                use clap::Parser;
                let mut argv: Vec<String> = vec![
                    "apr".to_string(),
                    "pretrain".to_string(),
                    "--dataset".to_string(),
                    "/tmp/_gputrain_device/ds".to_string(),
                    "--tokenizer".to_string(),
                    "/tmp/_gputrain_device/tok".to_string(),
                    "--run-dir".to_string(),
                    "/tmp/_gputrain_device/run".to_string(),
                ];
                argv.extend(extra);
                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
                match *cli.command {
                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
                        device, ..
                    }) => device,
                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
                }
            })
            .expect("spawn parse thread")
            .join()
            .expect("parse thread must not panic")
    }

    #[test]
    fn cli_pretrain_device_defaults_to_auto() {
        assert_eq!(
            parse_pretrain_device(&[]),
            "auto",
            "gpu-training-backend-v1 INV-GPUTRAIN-002: default --device must be `auto`",
        );
    }

    #[test]
    fn cli_pretrain_device_accepts_cpu() {
        assert_eq!(parse_pretrain_device(&["--device", "cpu"]), "cpu");
    }

    #[test]
    fn cli_pretrain_device_accepts_cuda_index() {
        assert_eq!(parse_pretrain_device(&["--device", "cuda:7"]), "cuda:7");
    }

    fn parse_pretrain_init(extra: &[&str]) -> Option<std::path::PathBuf> {
        let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
        std::thread::Builder::new()
            .stack_size(16 * 1024 * 1024)
            .spawn(move || {
                use clap::Parser;
                let mut argv: Vec<String> = vec![
                    "apr".to_string(),
                    "pretrain".to_string(),
                    "--dataset".to_string(),
                    "/tmp/_init_flag/ds".to_string(),
                    "--tokenizer".to_string(),
                    "/tmp/_init_flag/tok".to_string(),
                    "--run-dir".to_string(),
                    "/tmp/_init_flag/run".to_string(),
                ];
                argv.extend(extra);
                let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
                match *cli.command {
                    crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
                        init, ..
                    }) => init,
                    other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
                }
            })
            .expect("spawn parse thread")
            .join()
            .expect("parse thread must not panic")
    }

    #[test]
    fn pretrain_init_flag_absent_parses_to_none() {
        assert_eq!(
            parse_pretrain_init(&[]),
            None,
            "FALSIFY-APR-PRETRAIN-INIT-001/002: default --init must be None (no silent default)"
        );
    }

    #[test]
    fn pretrain_init_flag_parses_path() {
        let parsed = parse_pretrain_init(&["--init", "/tmp/foo.apr"]);
        assert_eq!(
            parsed.as_deref().and_then(|p| p.to_str()),
            Some("/tmp/foo.apr"),
            "FALSIFY-APR-PRETRAIN-INIT-001: --init <PATH> must round-trip through clap"
        );
    }

    #[test]
    fn pretrain_init_missing_file_errors() {
        let tmp = TempDir::new().expect("tempdir");
        let missing = tmp.path().join("does-not-exist.apr");
        let err = run(
            tmp.path(),
            tmp.path(),
            tmp.path(),
            PretrainMode::Finetune,
            Some(5.0e-5),
            10,
            Some(2),
            2,
            4,
            5,
            42,
            Some(2.2),
            50257,
            true,
            "cpu",
            Some(&missing),
            true,
        )
        .expect_err("missing --init file must be rejected");
        match err {
            CliError::ValidationFailed(msg) => {
                assert!(
                    msg.contains("FALSIFY-APR-PRETRAIN-INIT-003"),
                    "msg must cite falsifier id: {msg}"
                );
                assert!(
                    msg.contains("does-not-exist.apr"),
                    "msg must name the missing path: {msg}"
                );
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn pretrain_init_bad_magic_errors() {
        let tmp = TempDir::new().expect("tempdir");
        let bad = tmp.path().join("not-an-apr.bin");
        std::fs::write(&bad, b"GGUF\x00\x00\x00\x00\x00\x00\x00\x00")
            .expect("write fixture file");
        let err = run(
            tmp.path(),
            tmp.path(),
            tmp.path(),
            PretrainMode::Finetune,
            Some(5.0e-5),
            10,
            Some(2),
            2,
            4,
            5,
            42,
            Some(2.2),
            50257,
            true,
            "cpu",
            Some(&bad),
            true,
        )
        .expect_err("invalid magic bytes must be rejected");
        match err {
            CliError::ValidationFailed(msg) => {
                assert!(
                    msg.contains("FALSIFY-APR-PRETRAIN-INIT-004"),
                    "msg must cite falsifier id: {msg}"
                );
                assert!(
                    msg.contains("not a valid APR file"),
                    "msg must describe magic mismatch: {msg}"
                );
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn pretrain_init_empty_file_errors() {
        let tmp = TempDir::new().expect("tempdir");
        let empty = tmp.path().join("empty.apr");
        std::fs::write(&empty, b"").expect("write empty fixture");
        let err = run(
            tmp.path(),
            tmp.path(),
            tmp.path(),
            PretrainMode::Finetune,
            Some(5.0e-5),
            10,
            Some(2),
            2,
            4,
            5,
            42,
            Some(2.2),
            50257,
            true,
            "cpu",
            Some(&empty),
            true,
        )
        .expect_err("empty file must be rejected (cannot contain magic bytes)");
        assert!(matches!(err, CliError::ValidationFailed(_)));
    }

    #[test]
    fn pretrain_init_valid_magic_but_bogus_metadata_fails_at_arch_extraction() {
        let tmp = TempDir::new().expect("tempdir");
        let valid = tmp.path().join("v2-valid-magic-bogus-metadata.apr");
        std::fs::write(&valid, b"APR\x00\x00\x00\x00\x00\x00\x00\x00\x00")
            .expect("write fixture file");
        let err = run(
            tmp.path(),
            tmp.path(),
            tmp.path(),
            PretrainMode::Finetune,
            Some(5.0e-5),
            10,
            Some(2),
            2,
            4,
            5,
            42,
            Some(2.2),
            50257,
            true,
            "cpu",
            Some(&valid),
            true,
        )
        .expect_err("bogus metadata must NOT silently random-init");
        match err {
            CliError::ValidationFailed(msg) => {
                assert!(
                    !msg.contains("not yet wired"),
                    "the legacy step-5-partial guard must be retired: {msg}"
                );
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn pretrain_init_v1_magic_aprn_passes_validate_init_apr_path() {
        let tmp = TempDir::new().expect("tempdir");
        let v1 = tmp.path().join("v1-aprn.apr");
        std::fs::write(&v1, b"APRN\x00\x00\x00\x00").expect("write fixture file");
        let result = validate_init_apr_path(&v1);
        assert!(
            result.is_ok(),
            "APRN magic must pass validate_init_apr_path; got {result:?}"
        );
    }
}