1use crate::error::{CliError, Result};
16use crate::output;
17use clap::ValueEnum;
18use colored::Colorize;
19use entrenar::models::llama_370m::{
20 assert_tokenizer_vocab_matches_model, assert_tokenizer_vocab_within_model_bound,
21 Llama370MConfig,
22};
23use entrenar::train::device::{resolve_device, Device};
24use entrenar::train::pretrain::{
25 CheckpointFn, LinearDecaySynthetic, PretrainAbort, PretrainConfig, PretrainLoop, RunStatus,
26 ScriptedVal, StepFn, TrainingRegime, ValFn,
27};
28use entrenar::train::pretrain_real::{
29 build_shared_trainer, build_shared_trainer_with_init, AprCheckpointFn, RealStepFn, RealValFn,
30};
31use entrenar::train::shard_reader::ShardBatchIter;
32use entrenar::train::transformer_trainer::LMBatch;
33use entrenar::transformer::TransformerConfig;
34use std::path::Path;
35
/// Upper bound on the number of batches `drive_real` reserves for held-out
/// validation: taken from `--val-shard` when given, otherwise from the head of
/// the training shard iterator. Fewer are used if the source runs dry first.
const HELD_OUT_BATCHES: usize = 16;
49
/// Operator-facing status sentinel for FALSIFY-APR-PRETRAIN-INIT-CUDA-001,
/// stating that `--init` is supported together with `--device cuda`.
///
/// The exact wording (falsifier id, "is wired for --device cuda", the builder
/// name, and "5f.5 SHIPPED") is pinned by the unit test
/// `drive_real_cuda_init_path_wireup_sentinel_pinned`; edit both together.
pub(crate) const FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG: &str =
    "FALSIFY-APR-PRETRAIN-INIT-CUDA-001: --init is wired for --device cuda \
     via build_shared_cuda_trainer_with_init (5f.5 SHIPPED); operator can pass \
     --init <PATH> --device cuda for end-to-end GPU fine-tune dispatch.";
72
/// Training regime selected on the command line; picks the default
/// hyper-parameters applied by `mode_defaults`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
pub enum PretrainMode {
    // Maps to `TrainingRegime::Finetune` (defaults: lr 5e-5, warmup 100,
    // target val_loss 2.2). Plain `//` comments are used on the variants on
    // purpose so the clap-generated `--help` text is left unchanged.
    Finetune,
    // Maps to `TrainingRegime::FromScratch { vocab_size }` (defaults: lr 3e-4,
    // warmup 1000, target val_loss 3.0).
    FromScratch,
}
84
/// Hyper-parameters after merging a `PretrainMode`'s defaults with any explicit
/// CLI overrides; produced by `mode_defaults` and consumed by `run`.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ResolvedHp {
    /// Regime forwarded to `PretrainConfig::regime`.
    pub regime: TrainingRegime,
    /// Peak learning rate; `run` derives `lr_min` as 1% of it, floored at 1e-7.
    pub lr_max: f32,
    /// Number of learning-rate warmup steps.
    pub warmup_steps: usize,
    /// Validation-loss target; `run` rejects values `<= 0.0`.
    pub target_val_loss: f32,
}
96
97fn checkpoint_name_and_arch(init_arch: Option<&TransformerConfig>) -> (String, String) {
110 match init_arch {
111 Some(arch) => {
112 let hf_arch = arch
113 .hf_architecture
114 .clone()
115 .unwrap_or_else(|| "LlamaForCausalLM".to_string());
116 let name = arch
120 .hf_model_type
121 .as_deref()
122 .map_or_else(|| "model-pretrain".to_string(), |t| format!("{t}-pretrain"));
123 (name, hf_arch)
124 }
125 None => (
126 "llama-370m-pretrain".to_string(),
127 "LlamaForCausalLM".to_string(),
128 ),
129 }
130}
131
132fn estimate_param_count(arch: &TransformerConfig) -> u64 {
144 let vocab = arch.vocab_size as u64;
145 let hidden = arch.hidden_size as u64;
146 let inter = arch.intermediate_size as u64;
147 let layers = arch.num_hidden_layers as u64;
148 let embed = vocab.saturating_mul(hidden);
149 let attn_per_layer = 4u64.saturating_mul(hidden).saturating_mul(hidden);
150 let ffn_per_layer = 3u64.saturating_mul(hidden).saturating_mul(inter);
151 let per_layer = attn_per_layer.saturating_add(ffn_per_layer);
152 let layer_total = layers.saturating_mul(per_layer);
153 embed.saturating_add(layer_total).saturating_add(hidden)
154}
155
156pub(crate) fn mode_defaults(
157 mode: PretrainMode,
158 vocab_size: u32,
159 lr_override: Option<f32>,
160 warmup_override: Option<usize>,
161 target_override: Option<f32>,
162) -> ResolvedHp {
163 let (regime, lr_def, warmup_def, target_def) = match mode {
164 PretrainMode::Finetune => (TrainingRegime::Finetune, 5.0e-5, 100, 2.2),
165 PretrainMode::FromScratch => (
166 TrainingRegime::FromScratch { vocab_size },
167 3.0e-4,
168 1000,
169 3.0,
170 ),
171 };
172 ResolvedHp {
173 regime,
174 lr_max: lr_override.unwrap_or(lr_def),
175 warmup_steps: warmup_override.unwrap_or(warmup_def),
176 target_val_loss: target_override.unwrap_or(target_def),
177 }
178}
179
180#[allow(clippy::too_many_arguments)]
182pub(crate) fn run(
183 dataset: &Path,
184 tokenizer: &Path,
185 run_dir: &Path,
186 mode: PretrainMode,
187 lr: Option<f32>,
188 num_steps: usize,
189 warmup_steps: Option<usize>,
190 batch_size: usize,
191 seq_length: usize,
192 steps_per_epoch: usize,
193 seed: u64,
194 target_val_loss: Option<f32>,
195 vocab_size: u32,
196 synthetic: bool,
197 device: &str,
198 init: Option<&Path>,
199 force_under_provisioned: bool,
200 val_shard: Option<&Path>,
201 json_output: bool,
202) -> Result<()> {
203 let resolved_device =
209 resolve_device(device).map_err(|e| CliError::ValidationFailed(e.to_string()))?;
210
211 let init_arch: Option<TransformerConfig> = if let Some(init_path) = init {
218 validate_init_apr_path(init_path)?;
219 Some(
220 crate::commands::model_config::read_apr_architecture(init_path).ok_or_else(|| {
221 CliError::ValidationFailed(format!(
222 "FALSIFY-APR-PRETRAIN-INIT-005: --init APR file at {} has missing or invalid \
223 architecture metadata (hidden_size, num_heads, num_layers, vocab_size, etc). \
224 Cannot extract TransformerConfig per apr-pretrain-arch-polymorphic-v1 \
225 §arch_extraction_signature.",
226 init_path.display()
227 ))
228 })?,
229 )
230 } else {
231 None
232 };
233
234 let hp = mode_defaults(mode, vocab_size, lr, warmup_steps, target_val_loss);
235
236 if let Some(arch) = init_arch.as_ref() {
251 let n_params = estimate_param_count(arch);
252 let d_tokens = (num_steps as u64)
253 .saturating_mul(batch_size as u64)
254 .saturating_mul(seq_length as u64);
255 let ratio = d_tokens as f64 / n_params as f64;
256 let suggested_steps = if batch_size > 0 && seq_length > 0 {
257 (20 * n_params) / (batch_size as u64 * seq_length as u64)
258 } else {
259 0
260 };
261
262 if ratio < 10.0 && !force_under_provisioned {
263 return Err(CliError::ValidationFailed(format!(
264 "[P0-J] Chinchilla hard gate (chinchilla-gate-v1): \
265 train tokens D = {} ({:.1}M) is {:.3}× param count N = {} ({:.1}M); \
266 Chinchilla compute-optimal target is D ≈ 20·N (Hoffmann et al. 2022, arXiv:2203.15556). \
267 Run REJECTED: D/N < 10× will produce mode collapse / repetitive degeneration \
268 (Holtzman et al. 2019, arXiv:1904.09751). \
269 Increase --num-steps to ~{} OR widen --dataset corpus OR reduce model size. \
270 To bypass anyway (e.g. ablation studies, resumed runs), pass --force-under-provisioned.",
271 d_tokens,
272 d_tokens as f64 / 1e6,
273 ratio,
274 n_params,
275 n_params as f64 / 1e6,
276 suggested_steps,
277 )));
278 }
279
280 if ratio < 10.0 {
281 eprintln!(
284 "[P0-J] Chinchilla gate BYPASSED via --force-under-provisioned: \
285 D = {} ({:.1}M) is {:.3}× N = {} ({:.1}M). \
286 Run will likely produce repetitive/degenerate output. \
287 You explicitly opted in.",
288 d_tokens,
289 d_tokens as f64 / 1e6,
290 ratio,
291 n_params,
292 n_params as f64 / 1e6,
293 );
294 } else if ratio < 20.0 {
295 eprintln!(
298 "[P1-A] Chinchilla gate WARNING: D = {} ({:.1}M) is {:.1}× N = {} ({:.1}M); \
299 below compute-optimal 20·N target — model has room for more training. \
300 Suggested --num-steps for 20·N: ~{}.",
301 d_tokens,
302 d_tokens as f64 / 1e6,
303 ratio,
304 n_params,
305 n_params as f64 / 1e6,
306 suggested_steps,
307 );
308 }
309 }
310
311 if hp.target_val_loss <= 0.0 {
313 return Err(CliError::ValidationFailed(format!(
314 "target_val_loss must be positive, got {}",
315 hp.target_val_loss
316 )));
317 }
318 if num_steps == 0 {
319 return Err(CliError::ValidationFailed(
320 "num_steps must be > 0".to_string(),
321 ));
322 }
323 if steps_per_epoch == 0 {
324 return Err(CliError::ValidationFailed(
325 "steps_per_epoch must be > 0".to_string(),
326 ));
327 }
328
329 let config = PretrainConfig {
330 dataset_path: dataset.to_path_buf(),
331 tokenizer_dir: tokenizer.to_path_buf(),
332 run_dir: run_dir.to_path_buf(),
333 lr_max: hp.lr_max,
334 lr_min: (hp.lr_max * 1.0e-2).max(1.0e-7),
335 warmup_steps: hp.warmup_steps,
336 total_steps: num_steps,
337 batch_size,
338 seq_length,
339 steps_per_epoch,
340 seed,
341 grad_clip: 1.0,
342 weight_decay: 0.01,
343 target_val_loss: hp.target_val_loss,
344 patience_epochs: 5,
352 min_epochs_before_early_stop: 3,
358 regime: hp.regime,
359 };
360
361 if !json_output {
362 print_header(&config);
363 output::kv(" Device", resolved_device.to_string());
368 println!();
369 }
370
371 let status = if synthetic {
372 drive_synthetic(
373 config.clone(),
374 num_steps,
375 steps_per_epoch,
376 hp.target_val_loss,
377 json_output,
378 )?
379 } else {
380 drive_real(
381 config.clone(),
382 dataset,
383 hp.lr_max,
384 seq_length,
385 batch_size,
386 seed,
387 resolved_device,
388 json_output,
389 init_arch.as_ref(),
390 init,
391 val_shard,
392 )?
393 };
394
395 match status {
398 RunStatus::Aborted(abort) => Err(abort_to_err(&abort)),
399 RunStatus::Ok { .. } | RunStatus::EarlyStop { .. } => Ok(()),
400 }
401}
402
403fn drive_synthetic(
407 config: PretrainConfig,
408 num_steps: usize,
409 steps_per_epoch: usize,
410 target_val_loss: f32,
411 json_output: bool,
412) -> Result<RunStatus> {
413 let step_fn = LinearDecaySynthetic {
414 start_loss: (target_val_loss * 2.0).max(1.5),
415 decay_per_step: (target_val_loss * 0.01).max(1.0e-4),
416 grad_norm: 0.8,
417 };
418 let num_epochs = num_steps.div_ceil(steps_per_epoch);
419 let mut sequence = Vec::with_capacity(num_epochs + 2);
420 let start_val = (target_val_loss * 1.8).max(3.0);
421 for i in 0..(num_epochs + 2) {
422 let t = i as f32 / (num_epochs.max(1) as f32);
423 sequence.push(target_val_loss + (start_val - target_val_loss) * (1.0 - t).max(0.0));
424 }
425 let val_fn = ScriptedVal { sequence };
426 run_and_report(config, step_fn, val_fn, None, json_output)
428}
429
430fn validate_init_apr_path(path: &Path) -> Result<()> {
440 let mut file = std::fs::File::open(path).map_err(|e| {
441 CliError::ValidationFailed(format!(
442 "FALSIFY-APR-PRETRAIN-INIT-003: --init path does not exist or is unreadable: {} ({e})",
443 path.display()
444 ))
445 })?;
446 let mut magic = [0u8; 4];
447 use std::io::Read;
448 file.read_exact(&mut magic).map_err(|e| {
449 CliError::ValidationFailed(format!(
450 "FALSIFY-APR-PRETRAIN-INIT-004: --init file too short to contain APR magic bytes: {} ({e})",
451 path.display()
452 ))
453 })?;
454 const APR_MAGIC_V2: [u8; 4] = [0x41, 0x50, 0x52, 0x00];
458 const APR_MAGIC_V1: [u8; 4] = [0x41, 0x50, 0x52, 0x4E];
459 if magic != APR_MAGIC_V2 && magic != APR_MAGIC_V1 {
460 return Err(CliError::ValidationFailed(format!(
461 "FALSIFY-APR-PRETRAIN-INIT-004: --init file is not a valid APR file (magic={:02X?}, expected {:02X?} or {:02X?}): {}",
462 magic, APR_MAGIC_V2, APR_MAGIC_V1, path.display()
463 )));
464 }
465 Ok(())
466}
467
468fn preflight_tokenizer_vocab_matches_target(
486 tokenizer_dir: &Path,
487 target_vocab_size: usize,
488 init_is_some: bool,
489) -> Result<()> {
490 let vocab_path = tokenizer_dir.join("vocab.json");
491 let vocab_json = std::fs::read_to_string(&vocab_path).map_err(|e| {
492 CliError::ValidationFailed(format!(
493 "GATE-ARCH-370M-011 pre-flight: cannot read {} ({e})",
494 vocab_path.display()
495 ))
496 })?;
497 let vocab: serde_json::Map<String, serde_json::Value> = serde_json::from_str(&vocab_json)
498 .map_err(|e| {
499 CliError::ValidationFailed(format!(
500 "GATE-ARCH-370M-011 pre-flight: {} is not a valid vocab.json: {e}",
501 vocab_path.display()
502 ))
503 })?;
504 if init_is_some {
509 assert_tokenizer_vocab_within_model_bound(vocab.len(), target_vocab_size)
510 .map_err(CliError::ValidationFailed)
511 } else {
512 assert_tokenizer_vocab_matches_model(vocab.len(), target_vocab_size)
513 .map_err(CliError::ValidationFailed)
514 }
515}
516
517#[allow(clippy::too_many_arguments)]
525fn drive_real(
526 config: PretrainConfig,
527 dataset: &Path,
528 lr: f32,
529 seq_length: usize,
530 batch_size: usize,
531 seed: u64,
532 device: Device,
533 json_output: bool,
534 init_arch: Option<&TransformerConfig>,
535 init_path: Option<&Path>,
536 val_shard: Option<&Path>,
537) -> Result<RunStatus> {
538 let target_vocab = init_arch
548 .map(|cfg| cfg.vocab_size)
549 .unwrap_or(Llama370MConfig::VOCAB_SIZE);
550 preflight_tokenizer_vocab_matches_target(
551 &config.tokenizer_dir,
552 target_vocab,
553 init_arch.is_some(),
554 )?;
555
556 let mut iter = ShardBatchIter::new(dataset, batch_size, seq_length, 0, 0)
570 .map_err(|e| {
571 CliError::ValidationFailed(format!(
572 "dataset shard iterator init failed: {e} (path={})",
573 dataset.display()
574 ))
575 })?
576 .with_wrap_around(true)
577 .with_warn_on_wrap_around(true);
583
584 let held_out: Vec<LMBatch> = if let Some(val_dir) = val_shard {
596 let mut val_iter = ShardBatchIter::new(val_dir, batch_size, seq_length, 0, 0)
597 .map_err(|e| {
598 CliError::ValidationFailed(format!(
599 "FALSIFY-PRETRAIN-VAL-SHARD-001: --val-shard iterator init failed: {e} \
600 (path={})",
601 val_dir.display()
602 ))
603 })?
604 .with_wrap_around(false);
609 let mut batches: Vec<LMBatch> = Vec::with_capacity(HELD_OUT_BATCHES);
610 for _ in 0..HELD_OUT_BATCHES {
611 match val_iter.next() {
612 Some(b) => batches.push(b),
613 None => break,
614 }
615 }
616 if batches.is_empty() {
617 return Err(CliError::ValidationFailed(format!(
618 "FALSIFY-PRETRAIN-VAL-SHARD-003: --val-shard {} is too small to yield any \
619 held-out batches at batch_size={} seq_length={}",
620 val_dir.display(),
621 batch_size,
622 seq_length
623 )));
624 }
625 if !json_output {
626 eprintln!(
627 "[P2-F] held-out val source = --val-shard {} ({} batches)",
628 val_dir.display(),
629 batches.len()
630 );
631 }
632 batches
633 } else {
634 let mut batches: Vec<LMBatch> = Vec::with_capacity(HELD_OUT_BATCHES);
637 for _ in 0..HELD_OUT_BATCHES {
638 match iter.next() {
639 Some(b) => batches.push(b),
640 None => break,
641 }
642 }
643 if batches.is_empty() {
644 return Err(CliError::ValidationFailed(format!(
645 "dataset {} is too small to reserve any held-out batches",
646 dataset.display()
647 )));
648 }
649 batches
650 };
651
652 if device.is_cuda() {
653 drive_real_cuda(
670 config,
671 iter,
672 held_out,
673 lr,
674 seq_length,
675 seed,
676 json_output,
677 init_arch,
678 init_path,
679 )
680 } else {
681 drive_real_cpu(
682 config,
683 iter,
684 held_out,
685 lr,
686 seq_length,
687 seed,
688 json_output,
689 init_arch,
690 init_path,
691 )
692 }
693}
694
695#[allow(clippy::too_many_arguments)]
699fn drive_real_cpu(
700 config: PretrainConfig,
701 iter: entrenar::train::shard_reader::ShardBatchIter,
702 held_out: Vec<LMBatch>,
703 lr: f32,
704 seq_length: usize,
705 seed: u64,
706 json_output: bool,
707 init_arch: Option<&TransformerConfig>,
708 init_path: Option<&Path>,
709) -> Result<RunStatus> {
710 let trainer = if init_arch.is_some() || init_path.is_some() {
715 build_shared_trainer_with_init(lr, seq_length, seed, init_arch, init_path)
716 .map_err(CliError::ValidationFailed)?
717 } else {
718 build_shared_trainer(lr, seq_length, seed)
719 };
720 let step_fn = RealStepFn::new(trainer.clone(), Box::new(iter));
721 let val_fn = RealValFn::new(trainer.clone(), held_out);
722 let (ckpt_name, ckpt_arch) = checkpoint_name_and_arch(init_arch);
723 let ckpt: Box<dyn CheckpointFn> =
724 Box::new(AprCheckpointFn::new(trainer, &ckpt_name, &ckpt_arch));
725 run_and_report(config, step_fn, val_fn, Some(ckpt), json_output)
726}
727
/// CUDA backend (compiled only with `--features cuda`).
///
/// Builds the shared CUDA trainer: `build_shared_cuda_trainer_with_init` when
/// `init_arch` or `init_path` is present, otherwise `build_shared_cuda_trainer`.
/// It then wires CUDA step and validation functions and attaches a CUDA APR
/// checkpoint function that also records the tokenizer directory.
///
/// # Errors
///
/// Returns `CliError::ValidationFailed` (GATE-GPUTRAIN-002) when trainer
/// allocation fails, or whatever `run_and_report` returns.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn drive_real_cuda(
    config: PretrainConfig,
    iter: ShardBatchIter,
    held_out: Vec<LMBatch>,
    lr: f32,
    seq_length: usize,
    seed: u64,
    json_output: bool,
    init_arch: Option<&TransformerConfig>,
    init_path: Option<&Path>,
) -> Result<RunStatus> {
    use entrenar::train::pretrain_real_cuda::{
        build_shared_cuda_trainer, build_shared_cuda_trainer_with_init, CudaAprCheckpointFn,
        CudaRealStepFn, CudaRealValFn,
    };

    let trainer = match (init_arch, init_path) {
        (None, None) => build_shared_cuda_trainer(lr, seq_length, seed).map_err(|e| {
            CliError::ValidationFailed(format!(
                "GATE-GPUTRAIN-002: CUDA trainer allocation failed: {e}. \
                 See contracts/entrenar/gpu-training-backend-v1.yaml and \
                 memory/feedback_cuda_feature_footgun.md — this path is \
                 only reachable when the binary was built with `--features cuda`.",
            ))
        })?,
        (arch, path) => build_shared_cuda_trainer_with_init(lr, seq_length, seed, arch, path)
            .map_err(|e| {
                CliError::ValidationFailed(format!(
                    "GATE-GPUTRAIN-002: CUDA trainer allocation (--init path) failed: {e}. \
                     See contracts/entrenar/gpu-training-backend-v1.yaml and \
                     contracts/apr-pretrain-arch-polymorphic-v1.yaml v1.7.0 \
                     §FALSIFY-APR-PRETRAIN-INIT-CUDA-001 — this path is only \
                     reachable when the binary was built with `--features cuda`.",
                ))
            })?,
    };

    let step_fn = CudaRealStepFn::new(trainer.clone(), Box::new(iter));
    let val_fn = CudaRealValFn::new(trainer.clone(), held_out);

    let (model_name, hf_arch) = checkpoint_name_and_arch(init_arch);
    let checkpoint: Box<dyn CheckpointFn> = Box::new(
        CudaAprCheckpointFn::new(trainer, &model_name, &hf_arch)
            .with_tokenizer_dir(&config.tokenizer_dir),
    );

    run_and_report(config, step_fn, val_fn, Some(checkpoint), json_output)
}
790
791#[cfg(not(feature = "cuda"))]
799#[allow(clippy::too_many_arguments)]
800fn drive_real_cuda(
801 _config: PretrainConfig,
802 _iter: entrenar::train::shard_reader::ShardBatchIter,
803 _held_out: Vec<LMBatch>,
804 _lr: f32,
805 _seq_length: usize,
806 _seed: u64,
807 _json_output: bool,
808 _init_arch: Option<&TransformerConfig>,
809 _init_path: Option<&Path>,
810) -> Result<RunStatus> {
811 Err(CliError::ValidationFailed(
812 "GATE-GPUTRAIN-002: --device cuda was requested but this `apr` \
813 binary was built WITHOUT the `cuda` feature. \
814 Rebuild with `cargo build --release --features cuda` or use \
815 `--device cpu`. See memory/feedback_cuda_feature_footgun.md \
816 (contract gpu-training-backend-v1 / task #132 Phase 2)."
817 .into(),
818 ))
819}
820
821fn run_and_report<S: StepFn, V: ValFn>(
826 config: PretrainConfig,
827 step_fn: S,
828 val_fn: V,
829 checkpoint_fn: Option<Box<dyn CheckpointFn>>,
830 json_output: bool,
831) -> Result<RunStatus> {
832 let mut loop_ = PretrainLoop::new(config, step_fn, val_fn);
833 if let Some(ckpt) = checkpoint_fn {
834 loop_ = loop_.with_checkpoint_fn(ckpt);
835 }
836 let status = loop_.run();
837 report(&status, &loop_, json_output)?;
838 Ok(status)
839}
840
841fn abort_to_err(abort: &PretrainAbort) -> CliError {
842 match abort {
843 PretrainAbort::Divergence { .. } | PretrainAbort::DivergenceAtEpochZero { .. } => {
844 CliError::ValidationFailed(format!(
845 "GATE-TRAIN-005 ship-blocker fired: {abort}. See \
846 contracts/training-loop-pretrain-v1.yaml and \
847 memory/project_ship_two_001_model1_qlora_divergence.md"
848 ))
849 }
850 PretrainAbort::NumericalInstability { .. } => {
851 CliError::ValidationFailed(format!("GATE-TRAIN-007 NaN/Inf guard fired: {abort}"))
852 }
853 PretrainAbort::ThroughputOutOfRange { .. } => CliError::ValidationFailed(format!(
854 "GATE-TRAIN-008 throughput-range guard fired: {abort}"
855 )),
856 }
857}
858
859fn print_header(cfg: &PretrainConfig) {
860 output::header("apr pretrain — SHIP-TWO-001 MODEL-2 training loop");
861 println!();
862 output::section("Configuration");
863 output::kv(" Dataset", cfg.dataset_path.display().to_string());
864 output::kv(" Tokenizer", cfg.tokenizer_dir.display().to_string());
865 output::kv(" Run dir", cfg.run_dir.display().to_string());
866 output::kv(" LR max", format!("{:.2e}", cfg.lr_max));
867 output::kv(" Total steps", cfg.total_steps.to_string());
868 output::kv(" Warmup steps", cfg.warmup_steps.to_string());
869 output::kv(
870 " Batch × seq",
871 format!("{} × {}", cfg.batch_size, cfg.seq_length),
872 );
873 output::kv(" Steps / epoch", cfg.steps_per_epoch.to_string());
874 output::kv(" Seed", cfg.seed.to_string());
875 output::kv(" Target val_loss", format!("{:.2}", cfg.target_val_loss));
876 println!();
877}
878
879fn report<S: entrenar::train::pretrain::StepFn, V: entrenar::train::pretrain::ValFn>(
880 status: &RunStatus,
881 loop_: &PretrainLoop<S, V>,
882 json_output: bool,
883) -> Result<()> {
884 if json_output {
885 let report = PretrainReport::from(status, loop_);
886 let json = serde_json::to_string_pretty(&report)
887 .map_err(|e| CliError::InvalidFormat(e.to_string()))?;
888 println!("{json}");
889 return Ok(());
890 }
891
892 output::section("Run Result");
893 match status {
894 RunStatus::Ok {
895 final_val_loss,
896 epochs_completed,
897 } => {
898 println!(
899 " {} CONVERGED final val_loss={:.4} after {} epoch(s)",
900 "OK".green().bold(),
901 final_val_loss,
902 epochs_completed
903 );
904 }
905 RunStatus::EarlyStop {
906 best_val_loss,
907 epochs_completed,
908 } => {
909 println!(
910 " {} EARLY_STOP best val_loss={:.4} after {} epoch(s)",
911 "OK".yellow().bold(),
912 best_val_loss,
913 epochs_completed
914 );
915 }
916 RunStatus::Aborted(abort) => {
917 println!(" {} ABORTED {}", "FAIL".red().bold(), abort);
918 }
919 }
920 output::kv(" Steps recorded", loop_.step_metrics().len().to_string());
921 output::kv(
922 " Epochs recorded",
923 loop_.epoch_artifacts().len().to_string(),
924 );
925 println!();
926 Ok(())
927}
928
/// JSON payload printed by `report` when `json_output` is set; built by
/// `PretrainReport::from`. Field order here is the key order in the output.
#[derive(serde::Serialize)]
struct PretrainReport {
    /// One of `"OK"`, `"EARLY_STOP"`, `"ABORTED"`.
    status: String,
    /// Display text of the `PretrainAbort`; `None` unless the run aborted.
    detail: Option<String>,
    /// Final val loss for `OK`, best val loss for `EARLY_STOP`, `None` when aborted.
    final_val_loss: Option<f32>,
    /// Epoch count from the run status. For aborted runs it is the number of
    /// recorded epoch artifacts instead.
    epochs_completed: usize,
    /// Number of per-step metric records.
    steps_recorded: usize,
    /// Copy of the loop's validation-loss history.
    val_loss_history: Vec<f32>,
    /// Copy of every per-step metrics record.
    per_step_metrics: Vec<entrenar::train::pretrain::StepMetrics>,
}
948
949impl PretrainReport {
950 fn from<S: entrenar::train::pretrain::StepFn, V: entrenar::train::pretrain::ValFn>(
951 status: &RunStatus,
952 loop_: &PretrainLoop<S, V>,
953 ) -> Self {
954 let (status_name, detail, final_val_loss, epochs_completed) = match status {
955 RunStatus::Ok {
956 final_val_loss,
957 epochs_completed,
958 } => (
959 "OK".to_string(),
960 None,
961 Some(*final_val_loss),
962 *epochs_completed,
963 ),
964 RunStatus::EarlyStop {
965 best_val_loss,
966 epochs_completed,
967 } => (
968 "EARLY_STOP".to_string(),
969 None,
970 Some(*best_val_loss),
971 *epochs_completed,
972 ),
973 RunStatus::Aborted(abort) => (
974 "ABORTED".to_string(),
975 Some(abort.to_string()),
976 None,
977 loop_.epoch_artifacts().len(),
978 ),
979 };
980 PretrainReport {
981 status: status_name,
982 detail,
983 final_val_loss,
984 epochs_completed,
985 steps_recorded: loop_.step_metrics().len(),
986 val_loss_history: loop_.val_loss_history().to_vec(),
987 per_step_metrics: loop_.step_metrics().to_vec(),
988 }
989 }
990}
991
992#[cfg(test)]
993mod tests {
994 use super::*;
995 use tempfile::TempDir;
996
997 #[test]
1001 fn checkpoint_name_and_arch_default_when_no_init() {
1002 let (name, arch) = checkpoint_name_and_arch(None);
1003 assert_eq!(name, "llama-370m-pretrain");
1004 assert_eq!(arch, "LlamaForCausalLM");
1005 }
1006
1007 #[test]
1011 fn checkpoint_name_and_arch_qwen2_init() {
1012 let mut cfg = TransformerConfig::llama2_7b();
1013 cfg.hf_architecture = Some("Qwen2ForCausalLM".to_string());
1014 cfg.hf_model_type = Some("qwen2".to_string());
1015 let (name, arch) = checkpoint_name_and_arch(Some(&cfg));
1016 assert_eq!(name, "qwen2-pretrain");
1017 assert_eq!(arch, "Qwen2ForCausalLM");
1018 }
1019
1020 #[test]
1025 fn checkpoint_name_and_arch_init_without_hf_fields() {
1026 let cfg = TransformerConfig::llama2_7b();
1027 let (name, arch) = checkpoint_name_and_arch(Some(&cfg));
1029 assert_eq!(name, "model-pretrain");
1030 assert_eq!(arch, "LlamaForCausalLM");
1031 }
1032
1033 fn stage_vocab_json(dir: &std::path::Path, n: usize) {
1038 std::fs::create_dir_all(dir).expect("mkdir tokenizer dir");
1039 let mut obj = serde_json::Map::with_capacity(n);
1040 for i in 0..n {
1041 obj.insert(format!("t{i}"), serde_json::Value::from(i as u64));
1042 }
1043 let json = serde_json::to_string(&obj).expect("serialize");
1044 std::fs::write(dir.join("vocab.json"), json).expect("write vocab.json");
1045 }
1046
1047 #[test]
1051 fn estimate_param_count_qwen2_05b_within_2x() {
1052 let mut cfg = TransformerConfig::llama2_7b();
1053 cfg.hidden_size = 896;
1054 cfg.num_hidden_layers = 24;
1055 cfg.num_attention_heads = 14;
1056 cfg.num_kv_heads = 2;
1057 cfg.intermediate_size = 4864;
1058 cfg.vocab_size = 151936;
1059 let n = estimate_param_count(&cfg);
1060 let ref_params: u64 = 494_000_000;
1063 assert!(
1064 n > ref_params / 2 && n < ref_params * 2,
1065 "Qwen2.5-0.5B estimate {n} should be within 2× of 494M",
1066 );
1067 }
1068
1069 #[test]
1071 fn estimate_param_count_scales_with_layers() {
1072 let mut cfg = TransformerConfig::llama2_7b();
1073 cfg.hidden_size = 512;
1074 cfg.num_hidden_layers = 1;
1075 cfg.intermediate_size = 2048;
1076 cfg.vocab_size = 32000;
1077 let n1 = estimate_param_count(&cfg);
1078 cfg.num_hidden_layers = 24;
1079 let n24 = estimate_param_count(&cfg);
1080 assert!(
1083 n24 > n1 * 4,
1084 "24-layer model {n24} should be at least 4× 1-layer model {n1}",
1085 );
1086 }
1087
1088 fn chinchilla_gate_check(
1101 arch: &TransformerConfig,
1102 num_steps: usize,
1103 batch_size: usize,
1104 seq_length: usize,
1105 force_under_provisioned: bool,
1106 ) -> Option<f64> {
1107 let n_params = estimate_param_count(arch);
1108 let d_tokens = (num_steps as u64)
1109 .saturating_mul(batch_size as u64)
1110 .saturating_mul(seq_length as u64);
1111 let ratio = d_tokens as f64 / n_params as f64;
1112 if ratio < 10.0 && !force_under_provisioned {
1113 Some(ratio)
1114 } else {
1115 None
1116 }
1117 }
1118
1119 fn qwen_05b_config() -> TransformerConfig {
1120 let mut cfg = TransformerConfig::llama2_7b();
1121 cfg.hidden_size = 896;
1122 cfg.num_hidden_layers = 24;
1123 cfg.num_attention_heads = 14;
1124 cfg.num_kv_heads = 2;
1125 cfg.intermediate_size = 4864;
1126 cfg.vocab_size = 151936;
1127 cfg.hf_architecture = Some("Qwen2ForCausalLM".to_string());
1128 cfg.hf_model_type = Some("qwen2".to_string());
1129 cfg
1130 }
1131
1132 #[test]
1136 fn chinchilla_hard_gate_rejects_under_provisioned() {
1137 let cfg = qwen_05b_config();
1138 let verdict = chinchilla_gate_check(&cfg, 5000, 16, 512, false);
1139 assert!(verdict.is_some(), "0.083× should be rejected");
1140 let ratio = verdict.expect("ratio");
1141 assert!(ratio < 0.1, "expected ratio < 0.1, got {ratio}");
1142 }
1143
1144 #[test]
1147 fn chinchilla_hard_gate_bypasses_with_force_flag() {
1148 let cfg = qwen_05b_config();
1149 let verdict = chinchilla_gate_check(&cfg, 5000, 16, 512, true);
1150 assert!(verdict.is_none(), "force_under_provisioned must bypass");
1151 }
1152
1153 #[test]
1158 fn chinchilla_hard_gate_boundary_10x() {
1159 let cfg = qwen_05b_config();
1160 let n = estimate_param_count(&cfg);
1161 let bs = 16u64;
1162 let sl = 512u64;
1163 let target_d = 10 * n;
1164 let bs_sl = bs * sl;
1165 let exact_steps = (target_d + bs_sl - 1) / bs_sl;
1167 let verdict_exact =
1168 chinchilla_gate_check(&cfg, exact_steps as usize, bs as usize, sl as usize, false);
1169 assert!(
1170 verdict_exact.is_none(),
1171 "ratio ≥ 10.0 should PASS, got verdict={verdict_exact:?}"
1172 );
1173 let verdict_below = chinchilla_gate_check(
1175 &cfg,
1176 (exact_steps - 1) as usize,
1177 bs as usize,
1178 sl as usize,
1179 false,
1180 );
1181 assert!(
1182 verdict_below.is_some(),
1183 "ratio just below 10× should be REJECTED"
1184 );
1185 }
1186
1187 #[test]
1190 fn chinchilla_hard_gate_accepts_well_provisioned() {
1191 let cfg = qwen_05b_config();
1192 let n = estimate_param_count(&cfg);
1193 let bs = 16u64;
1195 let sl = 512u64;
1196 let steps_25x = ((25 * n) / (bs * sl)) as usize;
1197 let verdict = chinchilla_gate_check(&cfg, steps_25x, bs as usize, sl as usize, false);
1198 assert!(verdict.is_none(), "25× should pass");
1199 }
1200
1201 #[test]
1202 fn preflight_accepts_matching_vocab() {
1203 let tmp = TempDir::new().expect("tempdir");
1206 stage_vocab_json(tmp.path(), Llama370MConfig::VOCAB_SIZE);
1207 preflight_tokenizer_vocab_matches_target(tmp.path(), Llama370MConfig::VOCAB_SIZE, false)
1208 .expect("matching vocab must pass GATE-ARCH-370M-011");
1209 }
1210
1211 #[test]
1212 fn preflight_rejects_tokenizer_vocab_mismatch() {
1213 let tmp = TempDir::new().expect("tempdir");
1220 let mismatch = Llama370MConfig::VOCAB_SIZE - 1;
1221 stage_vocab_json(tmp.path(), mismatch);
1222 let err = preflight_tokenizer_vocab_matches_target(
1223 tmp.path(),
1224 Llama370MConfig::VOCAB_SIZE,
1225 false,
1226 )
1227 .expect_err("tokenizer/model vocab mismatch must be rejected");
1228 match err {
1229 CliError::ValidationFailed(msg) => {
1230 assert!(
1231 msg.contains("GATE-ARCH-370M-011"),
1232 "msg must cite gate: {msg}"
1233 );
1234 assert!(
1235 msg.contains(&mismatch.to_string()),
1236 "msg must name tokenizer vocab: {msg}"
1237 );
1238 assert!(
1239 msg.contains(&Llama370MConfig::VOCAB_SIZE.to_string()),
1240 "msg must name model vocab: {msg}"
1241 );
1242 }
1243 other => panic!("unexpected error: {other:?}"),
1244 }
1245 }
1246
1247 #[test]
1248 fn preflight_rejects_missing_vocab_json() {
1249 let tmp = TempDir::new().expect("tempdir");
1253 let err = preflight_tokenizer_vocab_matches_target(
1254 tmp.path(),
1255 Llama370MConfig::VOCAB_SIZE,
1256 false,
1257 )
1258 .expect_err("missing vocab.json must be rejected");
1259 match err {
1260 CliError::ValidationFailed(msg) => {
1261 assert!(
1262 msg.contains("GATE-ARCH-370M-011"),
1263 "msg must cite gate: {msg}"
1264 );
1265 assert!(
1266 msg.contains("cannot read"),
1267 "msg must name I/O failure: {msg}"
1268 );
1269 }
1270 other => panic!("unexpected error: {other:?}"),
1271 }
1272 }
1273
1274 #[test]
1281 fn preflight_qwen_vocab_passes_with_qwen_target() {
1282 const QWEN2_VOCAB_SIZE: usize = 151_936;
1283 let tmp = TempDir::new().expect("tempdir");
1284 stage_vocab_json(tmp.path(), QWEN2_VOCAB_SIZE);
1285 preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN2_VOCAB_SIZE, true).expect(
1289 "Qwen tokenizer (151_936) MUST pass preflight when target is Qwen-shaped — \
1290 this is the load-bearing claim of §49 fine-tune from a Qwen2.5 init checkpoint",
1291 );
1292 }
1293
1294 #[test]
1301 fn preflight_qwen_vocab_fails_with_llama_target() {
1302 const QWEN2_VOCAB_SIZE: usize = 151_936;
1303 let tmp = TempDir::new().expect("tempdir");
1304 stage_vocab_json(tmp.path(), QWEN2_VOCAB_SIZE);
1305 let err = preflight_tokenizer_vocab_matches_target(
1308 tmp.path(),
1309 Llama370MConfig::VOCAB_SIZE,
1310 false,
1311 )
1312 .expect_err(
1313 "Qwen tokenizer (151_936) MUST FAIL preflight when target is Llama370M (50_257) — \
1314 silent-pass would corrupt training",
1315 );
1316 match err {
1317 CliError::ValidationFailed(msg) => {
1318 assert!(
1319 msg.contains(&QWEN2_VOCAB_SIZE.to_string()),
1320 "msg must name Qwen vocab size 151_936: {msg}"
1321 );
1322 assert!(
1323 msg.contains(&Llama370MConfig::VOCAB_SIZE.to_string()),
1324 "msg must name target Llama vocab size 50_257: {msg}"
1325 );
1326 }
1327 other => panic!("unexpected error: {other:?}"),
1328 }
1329 }
1330
1331 #[test]
1336 fn preflight_qwen_reserved_slots_pass_under_polymorphic_init() {
1337 const QWEN_TOKENIZER_EFFECTIVE: usize = 151_665;
1338 const QWEN_DECLARED_VOCAB: usize = 151_936;
1339 let tmp = TempDir::new().expect("tempdir");
1340 stage_vocab_json(tmp.path(), QWEN_TOKENIZER_EFFECTIVE);
1341
1342 preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN_DECLARED_VOCAB, true).expect(
1344 "FALSIFY-APR-PRETRAIN-ARCH-009: HF reserved-slot tokenizer (151_665 ≤ 151_936) \
1345 MUST pass preflight under polymorphic init path (§55 relaxed bound)",
1346 );
1347
1348 let err = preflight_tokenizer_vocab_matches_target(tmp.path(), QWEN_DECLARED_VOCAB, false)
1350 .expect_err(
1351 "FALSIFY-APR-PRETRAIN-ARCH-009 dual: from-scratch path MUST keep strict ==",
1352 );
1353 match err {
1354 CliError::ValidationFailed(msg) => {
1355 assert!(
1356 msg.contains("GATE-ARCH-370M-011")
1357 && msg.contains(&QWEN_TOKENIZER_EFFECTIVE.to_string())
1358 && msg.contains(&QWEN_DECLARED_VOCAB.to_string()),
1359 "strict-mode error must name gate + both sizes: {msg}"
1360 );
1361 }
1362 other => panic!("unexpected error: {other:?}"),
1363 }
1364 }
1365
1366 #[test]
1371 fn preflight_oversized_tokenizer_rejected_even_under_polymorphic_init() {
1372 const QWEN_DECLARED_VOCAB: usize = 151_936;
1373 let oversized = QWEN_DECLARED_VOCAB + 100;
1374 let tmp = TempDir::new().expect("tempdir");
1375 stage_vocab_json(tmp.path(), oversized);
1376
1377 let err = preflight_tokenizer_vocab_matches_target(
1378 tmp.path(),
1379 QWEN_DECLARED_VOCAB,
1380 true, )
1382 .expect_err(
1383 "FALSIFY-APR-PRETRAIN-ARCH-010: oversized tokenizer MUST fail-fast even under \
1384 polymorphic init (OOB safety; relaxed bound is ≤ not <)",
1385 );
1386 match err {
1387 CliError::ValidationFailed(msg) => {
1388 assert!(
1389 msg.contains("RELAXED") && msg.contains("OOB"),
1390 "polymorphic-mode error must cite RELAXED + OOB: {msg}"
1391 );
1392 }
1393 other => panic!("unexpected error: {other:?}"),
1394 }
1395 }
1396
1397 #[test]
1415 fn drive_real_cuda_init_path_wireup_sentinel_pinned() {
1416 let msg = FALSIFY_APR_PRETRAIN_INIT_CUDA_001_MSG;
1417 assert!(
1418 msg.contains("FALSIFY-APR-PRETRAIN-INIT-CUDA-001"),
1419 "sentinel MUST cite the falsifier id (auditability): {msg}"
1420 );
1421 assert!(
1422 msg.contains("is wired for --device cuda"),
1423 "sentinel MUST contain the canonical 'is wired' phrase so \
1424 operators recognize §50.4 step 5f.5 SHIPPED: {msg}"
1425 );
1426 assert!(
1427 msg.contains("build_shared_cuda_trainer_with_init"),
1428 "sentinel MUST name the symmetric builder so future agents \
1429 know which symbol implements the wireup: {msg}"
1430 );
1431 assert!(
1432 msg.contains("5f.5 SHIPPED"),
1433 "sentinel MUST include the 5f.5 SHIPPED status marker so \
1434 grep over the codebase can find the discharge point: {msg}"
1435 );
1436 }
1437
1438 #[test]
1439 fn synthetic_pretrain_end_to_end_happy_path() {
1440 let tmp = TempDir::new().expect("tempdir");
1441 let dataset = tmp.path().join("data.jsonl");
1442 let tokenizer = tmp.path().join("tok");
1443 let run_dir = tmp.path().join("run");
1444
1445 let result = run(
1446 &dataset,
1447 &tokenizer,
1448 &run_dir,
1449 PretrainMode::Finetune,
1450 Some(5.0e-5),
1451 25,
1452 Some(5),
1453 2,
1454 4,
1455 5,
1456 42,
1457 Some(2.2),
1458 50257,
1459 true,
1460 "cpu",
1461 None,
1462 false,
1463 None,
1464 true,
1465 );
1466 assert!(
1467 result.is_ok(),
1468 "synthetic pretrain end-to-end must succeed: got {result:?}"
1469 );
1470 }
1471
1472 #[test]
1473 fn real_mode_empty_dataset_dir_errors() {
1474 let tmp = TempDir::new().expect("tempdir");
1480 let tok_dir = tmp.path().join("tok");
1481 stage_vocab_json(&tok_dir, Llama370MConfig::VOCAB_SIZE);
1482 let err = run(
1483 tmp.path(),
1484 &tok_dir,
1485 tmp.path(),
1486 PretrainMode::Finetune,
1487 Some(5.0e-5),
1488 10,
1489 Some(2),
1490 2,
1491 4,
1492 5,
1493 42,
1494 Some(2.2),
1495 50257,
1496 false,
1497 "cpu",
1498 None,
1499 false,
1500 None,
1501 true,
1502 )
1503 .expect_err("empty dataset dir must fail to initialise the shard iterator");
1504 match err {
1505 CliError::ValidationFailed(msg) => {
1506 assert!(
1507 msg.contains("shard iterator init failed"),
1508 "unexpected message: {msg}"
1509 );
1510 }
1511 other => panic!("unexpected error: {other:?}"),
1512 }
1513 }
1514
1515 #[test]
1516 fn invalid_target_val_loss_rejected() {
1517 let tmp = TempDir::new().expect("tempdir");
1518 let err = run(
1519 tmp.path(),
1520 tmp.path(),
1521 tmp.path(),
1522 PretrainMode::Finetune,
1523 Some(5.0e-5),
1524 10,
1525 Some(2),
1526 2,
1527 4,
1528 5,
1529 42,
1530 Some(-1.0),
1531 50257,
1532 true,
1533 "cpu",
1534 None,
1535 false,
1536 None,
1537 true,
1538 )
1539 .expect_err("negative target_val_loss must be rejected");
1540 assert!(matches!(err, CliError::ValidationFailed(_)));
1541 }
1542
1543 #[test]
1552 fn mode_finetune_is_default_and_matches_contract() {
1553 let hp = mode_defaults(PretrainMode::Finetune, 50257, None, None, None);
1557 assert_eq!(hp.regime, TrainingRegime::Finetune);
1558 assert!(
1559 (hp.lr_max - 5.0e-5).abs() < 1.0e-12,
1560 "lr_max={} must equal finetune default 5e-5",
1561 hp.lr_max
1562 );
1563 assert_eq!(hp.warmup_steps, 100);
1564 assert!(
1565 (hp.target_val_loss - 2.2).abs() < 1.0e-6,
1566 "target_val_loss={} must equal finetune default 2.2",
1567 hp.target_val_loss
1568 );
1569 }
1570
1571 #[test]
1572 fn mode_from_scratch_applies_all_four_defaults() {
1573 let hp = mode_defaults(PretrainMode::FromScratch, 50257, None, None, None);
1577 assert_eq!(hp.regime, TrainingRegime::FromScratch { vocab_size: 50257 });
1578 assert!(
1579 (hp.lr_max - 3.0e-4).abs() < 1.0e-12,
1580 "lr_max={} must equal from_scratch default 3e-4",
1581 hp.lr_max
1582 );
1583 assert_eq!(hp.warmup_steps, 1000);
1584 assert!(
1585 (hp.target_val_loss - 3.0).abs() < 1.0e-6,
1586 "target_val_loss={} must equal from_scratch default 3.0",
1587 hp.target_val_loss
1588 );
1589 }
1590
1591 #[test]
1592 fn mode_from_scratch_honors_explicit_lr_override() {
1593 let hp = mode_defaults(PretrainMode::FromScratch, 50257, Some(1.0e-4), None, None);
1598 assert_eq!(hp.regime, TrainingRegime::FromScratch { vocab_size: 50257 });
1599 assert!(
1600 (hp.lr_max - 1.0e-4).abs() < 1.0e-12,
1601 "lr_max={} must equal explicit override 1e-4",
1602 hp.lr_max
1603 );
1604 assert_eq!(hp.warmup_steps, 1000);
1606 assert!((hp.target_val_loss - 3.0).abs() < 1.0e-6);
1607 }
1608
1609 fn parse_pretrain_synthetic(extra: &[&str]) -> bool {
1621 let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1626 std::thread::Builder::new()
1627 .stack_size(16 * 1024 * 1024)
1628 .spawn(move || {
1629 use clap::Parser;
1630 let mut argv: Vec<String> = vec![
1631 "apr".to_string(),
1632 "pretrain".to_string(),
1633 "--dataset".to_string(),
1634 "/tmp/_gate_train_010/ds".to_string(),
1635 "--tokenizer".to_string(),
1636 "/tmp/_gate_train_010/tok".to_string(),
1637 "--run-dir".to_string(),
1638 "/tmp/_gate_train_010/run".to_string(),
1639 ];
1640 argv.extend(extra);
1641 let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1642 match *cli.command {
1643 crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1644 synthetic,
1645 ..
1646 }) => synthetic,
1647 other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1648 }
1649 })
1650 .expect("spawn parse thread")
1651 .join()
1652 .expect("parse thread must not panic")
1653 }
1654
1655 #[test]
1656 fn cli_pretrain_defaults_to_real_compute() {
1657 assert!(
1660 !parse_pretrain_synthetic(&[]),
1661 "INV-TRAIN-010: `apr pretrain` (no --synthetic) must parse to synthetic=false"
1662 );
1663 }
1664
1665 #[test]
1666 fn cli_pretrain_synthetic_flag_routes_to_synthetic() {
1667 assert!(
1669 parse_pretrain_synthetic(&["--synthetic"]),
1670 "INV-TRAIN-010: `apr pretrain --synthetic` must parse to synthetic=true"
1671 );
1672 }
1673
1674 fn parse_pretrain_device(extra: &[&str]) -> String {
1685 let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1686 std::thread::Builder::new()
1687 .stack_size(16 * 1024 * 1024)
1688 .spawn(move || {
1689 use clap::Parser;
1690 let mut argv: Vec<String> = vec![
1691 "apr".to_string(),
1692 "pretrain".to_string(),
1693 "--dataset".to_string(),
1694 "/tmp/_gputrain_device/ds".to_string(),
1695 "--tokenizer".to_string(),
1696 "/tmp/_gputrain_device/tok".to_string(),
1697 "--run-dir".to_string(),
1698 "/tmp/_gputrain_device/run".to_string(),
1699 ];
1700 argv.extend(extra);
1701 let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1702 match *cli.command {
1703 crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1704 device, ..
1705 }) => device,
1706 other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1707 }
1708 })
1709 .expect("spawn parse thread")
1710 .join()
1711 .expect("parse thread must not panic")
1712 }
1713
1714 #[test]
1715 fn cli_pretrain_device_defaults_to_auto() {
1716 assert_eq!(
1721 parse_pretrain_device(&[]),
1722 "auto",
1723 "gpu-training-backend-v1 INV-GPUTRAIN-002: default --device must be `auto`",
1724 );
1725 }
1726
1727 #[test]
1728 fn cli_pretrain_device_accepts_cpu() {
1729 assert_eq!(parse_pretrain_device(&["--device", "cpu"]), "cpu");
1731 }
1732
1733 #[test]
1734 fn cli_pretrain_device_accepts_cuda_index() {
1735 assert_eq!(parse_pretrain_device(&["--device", "cuda:7"]), "cuda:7");
1738 }
1739
1740 fn parse_pretrain_init(extra: &[&str]) -> Option<std::path::PathBuf> {
1753 let extra: Vec<String> = extra.iter().map(|s| (*s).to_string()).collect();
1754 std::thread::Builder::new()
1755 .stack_size(16 * 1024 * 1024)
1756 .spawn(move || {
1757 use clap::Parser;
1758 let mut argv: Vec<String> = vec![
1759 "apr".to_string(),
1760 "pretrain".to_string(),
1761 "--dataset".to_string(),
1762 "/tmp/_init_flag/ds".to_string(),
1763 "--tokenizer".to_string(),
1764 "/tmp/_init_flag/tok".to_string(),
1765 "--run-dir".to_string(),
1766 "/tmp/_init_flag/run".to_string(),
1767 ];
1768 argv.extend(extra);
1769 let cli = crate::Cli::try_parse_from(&argv).expect("clap parse must succeed");
1770 match *cli.command {
1771 crate::Commands::Extended(crate::ExtendedCommands::Pretrain {
1772 init, ..
1773 }) => init,
1774 other => panic!("expected ExtendedCommands::Pretrain, got {other:?}"),
1775 }
1776 })
1777 .expect("spawn parse thread")
1778 .join()
1779 .expect("parse thread must not panic")
1780 }
1781
1782 #[test]
1784 fn pretrain_init_flag_absent_parses_to_none() {
1785 assert_eq!(
1788 parse_pretrain_init(&[]),
1789 None,
1790 "FALSIFY-APR-PRETRAIN-INIT-001/002: default --init must be None (no silent default)"
1791 );
1792 }
1793
1794 #[test]
1796 fn pretrain_init_flag_parses_path() {
1797 let parsed = parse_pretrain_init(&["--init", "/tmp/foo.apr"]);
1798 assert_eq!(
1799 parsed.as_deref().and_then(|p| p.to_str()),
1800 Some("/tmp/foo.apr"),
1801 "FALSIFY-APR-PRETRAIN-INIT-001: --init <PATH> must round-trip through clap"
1802 );
1803 }
1804
1805 #[test]
1808 fn pretrain_init_missing_file_errors() {
1809 let tmp = TempDir::new().expect("tempdir");
1810 let missing = tmp.path().join("does-not-exist.apr");
1811 let err = run(
1812 tmp.path(),
1813 tmp.path(),
1814 tmp.path(),
1815 PretrainMode::Finetune,
1816 Some(5.0e-5),
1817 10,
1818 Some(2),
1819 2,
1820 4,
1821 5,
1822 42,
1823 Some(2.2),
1824 50257,
1825 true,
1826 "cpu",
1827 Some(&missing),
1828 false,
1829 None,
1830 true,
1831 )
1832 .expect_err("missing --init file must be rejected");
1833 match err {
1834 CliError::ValidationFailed(msg) => {
1835 assert!(
1836 msg.contains("FALSIFY-APR-PRETRAIN-INIT-003"),
1837 "msg must cite falsifier id: {msg}"
1838 );
1839 assert!(
1840 msg.contains("does-not-exist.apr"),
1841 "msg must name the missing path: {msg}"
1842 );
1843 }
1844 other => panic!("unexpected error: {other:?}"),
1845 }
1846 }
1847
1848 #[test]
1850 fn pretrain_init_bad_magic_errors() {
1851 let tmp = TempDir::new().expect("tempdir");
1852 let bad = tmp.path().join("not-an-apr.bin");
1853 std::fs::write(&bad, b"GGUF\x00\x00\x00\x00\x00\x00\x00\x00").expect("write fixture file");
1854 let err = run(
1855 tmp.path(),
1856 tmp.path(),
1857 tmp.path(),
1858 PretrainMode::Finetune,
1859 Some(5.0e-5),
1860 10,
1861 Some(2),
1862 2,
1863 4,
1864 5,
1865 42,
1866 Some(2.2),
1867 50257,
1868 true,
1869 "cpu",
1870 Some(&bad),
1871 false,
1872 None,
1873 true,
1874 )
1875 .expect_err("invalid magic bytes must be rejected");
1876 match err {
1877 CliError::ValidationFailed(msg) => {
1878 assert!(
1879 msg.contains("FALSIFY-APR-PRETRAIN-INIT-004"),
1880 "msg must cite falsifier id: {msg}"
1881 );
1882 assert!(
1883 msg.contains("not a valid APR file"),
1884 "msg must describe magic mismatch: {msg}"
1885 );
1886 }
1887 other => panic!("unexpected error: {other:?}"),
1888 }
1889 }
1890
1891 #[test]
1893 fn pretrain_init_empty_file_errors() {
1894 let tmp = TempDir::new().expect("tempdir");
1895 let empty = tmp.path().join("empty.apr");
1896 std::fs::write(&empty, b"").expect("write empty fixture");
1897 let err = run(
1898 tmp.path(),
1899 tmp.path(),
1900 tmp.path(),
1901 PretrainMode::Finetune,
1902 Some(5.0e-5),
1903 10,
1904 Some(2),
1905 2,
1906 4,
1907 5,
1908 42,
1909 Some(2.2),
1910 50257,
1911 true,
1912 "cpu",
1913 Some(&empty),
1914 false,
1915 None,
1916 true,
1917 )
1918 .expect_err("empty file must be rejected (cannot contain magic bytes)");
1919 assert!(matches!(err, CliError::ValidationFailed(_)));
1920 }
1921
1922 #[test]
1929 fn pretrain_init_valid_magic_but_bogus_metadata_fails_at_arch_extraction() {
1930 let tmp = TempDir::new().expect("tempdir");
1931 let valid = tmp.path().join("v2-valid-magic-bogus-metadata.apr");
1932 std::fs::write(&valid, b"APR\x00\x00\x00\x00\x00\x00\x00\x00\x00")
1935 .expect("write fixture file");
1936 let err = run(
1937 tmp.path(),
1938 tmp.path(),
1939 tmp.path(),
1940 PretrainMode::Finetune,
1941 Some(5.0e-5),
1942 10,
1943 Some(2),
1944 2,
1945 4,
1946 5,
1947 42,
1948 Some(2.2),
1949 50257,
1950 true,
1951 "cpu",
1952 Some(&valid),
1953 false,
1954 None,
1955 true,
1956 )
1957 .expect_err("bogus metadata must NOT silently random-init");
1958 match err {
1959 CliError::ValidationFailed(msg) => {
1960 assert!(
1961 !msg.contains("not yet wired"),
1962 "the legacy step-5-partial guard must be retired: {msg}"
1963 );
1964 }
1968 other => panic!("unexpected error: {other:?}"),
1969 }
1970 }
1971
1972 #[test]
1976 fn pretrain_init_v1_magic_aprn_passes_validate_init_apr_path() {
1977 let tmp = TempDir::new().expect("tempdir");
1978 let v1 = tmp.path().join("v1-aprn.apr");
1979 std::fs::write(&v1, b"APRN\x00\x00\x00\x00").expect("write fixture file");
1980 let result = validate_init_apr_path(&v1);
1981 assert!(
1982 result.is_ok(),
1983 "APRN magic must pass validate_init_apr_path; got {result:?}"
1984 );
1985 }
1986}