1use crate::models::llama_370m::Llama370MConfig;
23use crate::train::pretrain::{CheckpointFn, EpochArtifact, StepFn, ValFn};
24use crate::train::transformer_trainer::{LMBatch, TransformerTrainConfig, TransformerTrainer};
25use crate::transformer::{ModelArchitecture, Transformer, TransformerConfig};
26use crate::Tensor;
27use std::cell::RefCell;
28use std::collections::BTreeMap;
29use std::path::Path;
30use std::rc::Rc;
31
/// Shared, single-threaded (`Rc`, so not `Send`) handle to one
/// `TransformerTrainer`. The step, validation, and checkpoint adapters in this
/// module each hold a clone of it and borrow the trainer through the `RefCell`.
pub type SharedTrainer = Rc<RefCell<TransformerTrainer>>;
36
37pub fn load_init_tensors_from_apr(
65 path: impl AsRef<Path>,
66) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String> {
67 let path_ref = path.as_ref();
68 aprender::format::converter::load_model_tensors(path_ref).map_err(|e| {
69 format!(
70 "FALSIFY-APR-PRETRAIN-INIT-006: failed to load init tensors from APR file {}: {e}",
71 path_ref.display()
72 )
73 })
74}
75
76pub fn validate_pretrain_init_arch_compatible(cfg: &TransformerConfig) -> Result<(), String> {
98 match cfg.architecture {
99 ModelArchitecture::Decoder => Ok(()),
100 ModelArchitecture::Encoder => Err(format!(
101 "FALSIFY-APR-PRETRAIN-ARCH-007: --init checkpoint has architecture=Encoder \
102 (e.g., BERT/RoBERTa/CodeBERT) but the pretrain trainer is decoder-only \
103 (Llama/Qwen-class causal LMs). Loading encoder weights into a decoder \
104 trainer would produce nonsense gradients. Architectural details: \
105 hidden_size={}, num_layers={}, vocab_size={}, hf_architecture={:?}",
106 cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size, cfg.hf_architecture
107 )),
108 }
109}
110
/// Infer a coarse architecture family purely from checkpoint tensor names.
///
/// Returns one of `"mamba"`, `"rwkv"`, `"gpt-neox"`, `"opt"`, `"bert"`,
/// `"qwen3"`, `"qwen2"`, `"llama"`, `"gpt2"`, or `"unknown"`. Rules are tried in
/// that priority order and the first match wins. Within the `model.layers`
/// layout, the `q_norm` check (qwen3) outranks the `q_proj.bias` /
/// `qkv_proj.weight` check (qwen2). Names containing `blk.` (presumably
/// GGUF-style) are deliberately left as `"unknown"`.
///
/// NOTE(review): `qkv_proj.weight` is also the fused-QKV tensor name used by
/// Phi-3 checkpoints, which `normalize_metadata_arch_family` maps to `"llama"`.
/// Confirm this heuristic does not misclassify genuine Phi-3 files.
#[must_use]
pub fn family_from_tensor_names<'a, I>(names: I) -> &'static str
where
    I: IntoIterator<Item = &'a str>,
{
    // Materialise once: every rule below may rescan the whole name list.
    let all: Vec<&str> = names.into_iter().collect();
    let contains = |needle: &str| all.iter().any(|name| name.contains(needle));
    let starts = |prefix: &str| all.iter().any(|name| name.starts_with(prefix));

    if contains("mixer.in_proj") || contains("mixer.out_proj") {
        "mamba"
    } else if starts("rwkv.blocks.") || contains("blocks.0.att.") {
        "rwkv"
    } else if starts("gpt_neox.") {
        "gpt-neox"
    } else if starts("model.decoder.layers.") {
        "opt"
    } else if starts("bert.") {
        "bert"
    } else if contains("model.layers") {
        // Llama-style `model.layers.N.*` layout: refine by attention details.
        if contains("self_attn.q_norm.weight") {
            "qwen3"
        } else if contains("self_attn.q_proj.bias") || contains("qkv_proj.weight") {
            "qwen2"
        } else {
            "llama"
        }
    } else if contains("transformer.h")
        || all.iter().any(|name| name.starts_with("h.") && name.contains(".attn."))
    {
        "gpt2"
    } else {
        // Everything else, including `blk.`-style names, is unclassified.
        "unknown"
    }
}
178
/// Map an APR metadata `architecture` string to the family tag used by
/// `family_from_tensor_names`.
///
/// Accepts HF class names (`"Qwen2ForCausalLM"`, `"GPT2LMHeadModel"`, …) and a
/// fixed set of short aliases in the spellings listed below. Matching is exact
/// and case-sensitive. Anything not listed, e.g. `"LLAMA"`, yields `None`.
/// Mistral and Phi variants are deliberately folded into the `"llama"` family.
#[must_use]
pub fn normalize_metadata_arch_family(arch: &str) -> Option<&'static str> {
    let family = match arch {
        "Qwen2ForCausalLM" | "Qwen2.5ForCausalLM" | "qwen2" | "qwen2.5" | "qwen" | "Qwen2"
        | "Qwen2.5" | "Qwen" => "qwen2",
        "Qwen3ForCausalLM" | "Qwen3MoeForCausalLM" | "qwen3" | "qwen3_5" | "qwen3.5"
        | "Qwen3" => "qwen3",
        "LlamaForCausalLM" | "MistralForCausalLM" | "Phi3ForCausalLM" | "PhiForCausalLM"
        | "llama" | "mistral" | "phi" | "phi3" | "phi4" | "Llama" | "Mistral" | "Phi"
        | "Phi3" | "Phi4" => "llama",
        "GPT2LMHeadModel" | "gpt2" | "Gpt2" | "GPT2" => "gpt2",
        "GPTNeoXForCausalLM" | "gpt-neox" | "gpt_neox" | "gptneox" | "pythia" => "gpt-neox",
        "MambaForCausalLM" | "mamba" => "mamba",
        "RwkvForCausalLM" | "Rwkv6ForCausalLM" | "rwkv" => "rwkv",
        "BertModel" | "BertForMaskedLM" | "bert" => "bert",
        "OPTForCausalLM" | "opt" => "opt",
        _ => return None,
    };
    Some(family)
}
228
229pub fn validate_init_arch_matches_tensor_evidence(
263 metadata_arch: Option<&str>,
264 init_tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
265) -> Result<(), String> {
266 let Some(metadata_family) = metadata_arch.and_then(normalize_metadata_arch_family) else {
269 return Ok(());
270 };
271
272 let tensor_family = family_from_tensor_names(init_tensors.keys().map(String::as_str));
273
274 if tensor_family == "unknown" {
278 return Ok(());
279 }
280
281 if metadata_family != tensor_family {
282 return Err(format!(
283 "FALSIFY-INIT-ARCH-MATCH-001: --init APR metadata claims architecture \
284 family `{metadata_family}` (from `{}`) but tensor naming implies \
285 family `{tensor_family}`. This is the SPEC §86 silent-failure pattern: \
286 pre-P0-K APRs with the §82 P0-H \"LlamaForCausalLM\" fallback stamp + \
287 Qwen2 tensors load as random-init and train from scratch. Salvage with \
288 `apr stamp <input.apr> --architecture {tensor_family} --hf-architecture \
289 {} -o <stamped.apr>` (see PR #1757 / SPEC §86.4) then re-run \
290 `apr pretrain --init <stamped.apr>`.",
291 metadata_arch.unwrap_or("?"),
292 match tensor_family {
294 "qwen2" => "Qwen2ForCausalLM",
295 "qwen3" => "Qwen3ForCausalLM",
296 "llama" => "LlamaForCausalLM",
297 "gpt2" => "GPT2LMHeadModel",
298 "gpt-neox" => "GPTNeoXForCausalLM",
299 "mamba" => "MambaForCausalLM",
300 "rwkv" => "RwkvForCausalLM",
301 "bert" => "BertModel",
302 "opt" => "OPTForCausalLM",
303 other => other,
304 }
305 ));
306 }
307
308 Ok(())
309}
310
311pub fn populate_trainer_from_init_tensors(
342 transformer: &mut Transformer,
343 init_tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
344) -> Result<usize, String> {
345 let expected: Vec<(String, usize)> =
346 transformer.named_parameters().into_iter().map(|(name, t)| (name, t.len())).collect();
347 let mut populated = 0usize;
348 let mut errors: Vec<String> = Vec::new();
349
350 for (name, expected_len) in &expected {
351 match init_tensors.get(name) {
352 Some((data, _shape)) => {
353 if data.len() != *expected_len {
354 errors.push(format!(
355 "{name}: init length {} != trainer expected {expected_len}",
356 data.len()
357 ));
358 continue;
359 }
360 let tensor = Tensor::from_vec(data.clone(), true);
361 if !transformer.set_named_parameter(name, tensor) {
362 errors.push(format!("{name}: set_named_parameter rejected the assignment"));
363 continue;
364 }
365 populated += 1;
366 }
367 None => {
368 errors.push(format!("{name}: not present in init APR tensors"));
369 }
370 }
371 }
372
373 if !errors.is_empty() {
374 let total = errors.len();
375 let head = errors.iter().take(5).cloned().collect::<Vec<_>>().join("; ");
376 return Err(format!(
377 "FALSIFY-APR-PRETRAIN-INIT-007: populate_trainer_from_init_tensors \
378 failed for {total} parameter(s); first {} of {total}: {head}",
379 errors.len().min(5)
380 ));
381 }
382
383 Ok(populated)
384}
385
386pub fn llama_370m_transformer_config() -> TransformerConfig {
389 TransformerConfig {
390 hidden_size: Llama370MConfig::HIDDEN_DIM,
391 num_attention_heads: Llama370MConfig::NUM_HEADS,
392 num_kv_heads: Llama370MConfig::NUM_KV_HEADS,
393 intermediate_size: Llama370MConfig::INTERMEDIATE_DIM,
394 num_hidden_layers: Llama370MConfig::NUM_LAYERS,
395 vocab_size: Llama370MConfig::VOCAB_SIZE,
396 max_position_embeddings: Llama370MConfig::MAX_POSITION_EMBEDDINGS,
397 rms_norm_eps: Llama370MConfig::RMS_NORM_EPS,
398 rope_theta: Llama370MConfig::ROPE_THETA,
399 use_bias: false,
400 head_dim_override: None,
401 architecture: ModelArchitecture::Decoder,
402 hf_architecture: Some("LlamaForCausalLM".into()),
403 hf_model_type: Some("llama".into()),
404 tie_word_embeddings: true,
405 }
406}
407
408pub fn build_transformer_config(init: Option<&TransformerConfig>) -> TransformerConfig {
426 match init {
427 None => llama_370m_transformer_config(),
428 Some(cfg) => cfg.clone(),
429 }
430}
431
432pub fn llama_370m_train_config(lr: f32, seq_length: usize, seed: u64) -> TransformerTrainConfig {
435 let model_cfg = llama_370m_transformer_config();
436 let mut cfg = TransformerTrainConfig::new(model_cfg);
437 cfg.lr = lr;
438 cfg.max_seq_len = seq_length;
439 cfg.seed = seed;
440 cfg
441}
442
443pub struct RealStepFn {
447 trainer: SharedTrainer,
448 batches: Box<dyn Iterator<Item = LMBatch>>,
449}
450
451impl RealStepFn {
452 pub fn new(trainer: SharedTrainer, batches: Box<dyn Iterator<Item = LMBatch>>) -> Self {
453 Self { trainer, batches }
454 }
455}
456
457impl StepFn for RealStepFn {
458 fn step(&mut self, _step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
459 let Some(batch) = self.batches.next() else {
465 return (1.0, 1.0);
466 };
467 let mut trainer = self.trainer.borrow_mut();
468 let loss = trainer.train_batch(&batch);
469 let grad_norm = 1.0_f32;
474 (loss, grad_norm)
475 }
476
477 fn optimizer_state_sha256(&self) -> Option<String> {
479 Some(self.trainer.borrow().optimizer_state_sha256())
480 }
481}
482
483pub struct RealValFn {
486 trainer: SharedTrainer,
487 held_out: Vec<LMBatch>,
488}
489
490impl RealValFn {
491 pub fn new(trainer: SharedTrainer, held_out: Vec<LMBatch>) -> Self {
492 Self { trainer, held_out }
493 }
494}
495
496impl ValFn for RealValFn {
497 fn validate(&mut self, _epoch: usize) -> f32 {
498 if self.held_out.is_empty() {
499 return f32::NAN;
500 }
501 let trainer = self.trainer.borrow();
502 let mut total_loss = 0.0_f32;
503 let mut total_items = 0_usize;
504 for batch in &self.held_out {
505 for i in 0..batch.batch_size {
506 let Some(inp) = batch.get_input(i) else {
507 continue;
508 };
509 let Some(tgt) = batch.get_target(i) else {
510 continue;
511 };
512 let (loss_val, _loss_tensor, _logits) = trainer.forward_single(inp, tgt);
513 total_loss += loss_val;
514 total_items += 1;
515 }
516 }
517 if total_items == 0 {
518 f32::NAN
519 } else {
520 total_loss / total_items as f32
521 }
522 }
523}
524
525pub struct AprCheckpointFn {
531 trainer: SharedTrainer,
532 model_name: String,
533 architecture: String,
534}
535
536impl AprCheckpointFn {
537 pub fn new(
538 trainer: SharedTrainer,
539 model_name: impl Into<String>,
540 architecture: impl Into<String>,
541 ) -> Self {
542 Self { trainer, model_name: model_name.into(), architecture: architecture.into() }
543 }
544}
545
546impl CheckpointFn for AprCheckpointFn {
547 fn save(&mut self, _epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
548 let trainer = self.trainer.borrow();
549 trainer
550 .save_apr(&artifact.checkpoint_path, &self.model_name, &self.architecture)
551 .map_err(|e| format!("save_apr failed: {e}"))
552 }
553}
554
555pub fn build_shared_trainer(lr: f32, seq_length: usize, seed: u64) -> SharedTrainer {
558 let cfg = llama_370m_train_config(lr, seq_length, seed);
559 let trainer = TransformerTrainer::new(cfg);
560 #[cfg(debug_assertions)]
565 {
566 let param_count: usize = trainer.model().parameters().iter().map(|t| t.len()).sum();
567 debug_assert!(
568 (366_000_000..=374_000_000).contains(¶m_count),
569 "INV-ARCH-370M-001: parameter count {param_count} outside [366M, 374M] band",
570 );
571 }
572 Rc::new(RefCell::new(trainer))
573}
574
575pub fn build_shared_trainer_with_init(
602 lr: f32,
603 seq_length: usize,
604 seed: u64,
605 init_arch: Option<&TransformerConfig>,
606 init_path: Option<&Path>,
607) -> Result<SharedTrainer, String> {
608 if init_arch.is_some() != init_path.is_some() {
609 return Err(format!(
610 "build_shared_trainer_with_init: init_arch and init_path must both be Some \
611 or both None (caller bug; init_arch.is_some()={}, init_path.is_some()={})",
612 init_arch.is_some(),
613 init_path.is_some()
614 ));
615 }
616
617 if let Some(cfg) = init_arch {
618 validate_pretrain_init_arch_compatible(cfg)?;
619 }
620
621 let model_cfg = build_transformer_config(init_arch);
622 let mut train_cfg = TransformerTrainConfig::new(model_cfg);
623 train_cfg.lr = lr;
624 train_cfg.max_seq_len = seq_length;
625 train_cfg.seed = seed;
626 let mut trainer = TransformerTrainer::new(train_cfg);
627
628 if let Some(path) = init_path {
635 let tensors = load_init_tensors_from_apr(path)?;
636 let raw_metadata_arch = read_apr_metadata_architecture_string(path);
643 validate_init_arch_matches_tensor_evidence(raw_metadata_arch.as_deref(), &tensors)?;
644 populate_trainer_from_init_tensors(trainer.model_mut(), &tensors)?;
645 }
646
647 Ok(Rc::new(RefCell::new(trainer)))
648}
649
650fn read_apr_metadata_architecture_string(path: &Path) -> Option<String> {
659 use aprender::format::v2::{AprV2Header, AprV2Metadata, HEADER_SIZE_V2, MAGIC_V2};
660 use std::io::{Read, Seek, SeekFrom};
661 let mut file = std::fs::File::open(path).ok()?;
662 let mut header_buf = [0u8; HEADER_SIZE_V2];
663 file.read_exact(&mut header_buf).ok()?;
664 if header_buf[..4] != MAGIC_V2 {
665 return None;
666 }
667 let header = AprV2Header::from_bytes(&header_buf).ok()?;
668 file.seek(SeekFrom::Start(header.metadata_offset)).ok()?;
669 let mut meta_buf = vec![0u8; header.metadata_size as usize];
670 file.read_exact(&mut meta_buf).ok()?;
671 let metadata = AprV2Metadata::from_json(&meta_buf).ok()?;
672 metadata.architecture
673}
674
675#[cfg(test)]
676mod tests {
677 use super::*;
678 use crate::train::transformer_trainer::LMBatch;
679
680 #[test]
686 fn load_init_tensors_missing_file_errors_with_falsifier_id() {
687 let tmp = tempfile::TempDir::new().expect("tempdir");
688 let missing = tmp.path().join("does-not-exist.apr");
689 let err =
690 load_init_tensors_from_apr(&missing).expect_err("missing init APR file MUST fail-fast");
691 assert!(
692 err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
693 "error must cite falsifier id (auditability): {err}"
694 );
695 assert!(
696 err.contains("does-not-exist.apr"),
697 "error must name the missing path (operator-experience): {err}"
698 );
699 }
700
701 #[test]
710 fn load_init_tensors_signature_compile_bind() {
711 fn _check_signature<F>(_f: F)
715 where
716 F: Fn(&Path) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String>,
717 {
718 }
719 _check_signature(|p| load_init_tensors_from_apr(p));
720 }
721
722 #[test]
723 fn transformer_config_matches_llama_370m_constants() {
724 let cfg = llama_370m_transformer_config();
725 assert_eq!(cfg.hidden_size, Llama370MConfig::HIDDEN_DIM);
726 assert_eq!(cfg.num_hidden_layers, Llama370MConfig::NUM_LAYERS);
727 assert_eq!(cfg.num_attention_heads, Llama370MConfig::NUM_HEADS);
728 assert_eq!(cfg.num_kv_heads, Llama370MConfig::NUM_KV_HEADS);
729 assert_eq!(cfg.intermediate_size, Llama370MConfig::INTERMEDIATE_DIM);
730 assert_eq!(cfg.vocab_size, Llama370MConfig::VOCAB_SIZE);
731 assert!((cfg.rope_theta - Llama370MConfig::ROPE_THETA).abs() < f32::EPSILON);
732 assert!((cfg.rms_norm_eps - Llama370MConfig::RMS_NORM_EPS).abs() < f32::EPSILON);
733 assert!(!cfg.use_bias, "INV-ARCH-370M-008: no bias");
734 assert!(cfg.tie_word_embeddings, "INV-ARCH-370M-004: tied embeddings");
735 }
736
737 #[test]
743 fn build_transformer_config_no_init_matches_llama370m() {
744 let baseline = llama_370m_transformer_config();
745 let result = build_transformer_config(None);
746 assert_eq!(result.hidden_size, baseline.hidden_size);
747 assert_eq!(result.num_attention_heads, baseline.num_attention_heads);
748 assert_eq!(result.num_kv_heads, baseline.num_kv_heads);
749 assert_eq!(result.intermediate_size, baseline.intermediate_size);
750 assert_eq!(result.num_hidden_layers, baseline.num_hidden_layers);
751 assert_eq!(result.vocab_size, baseline.vocab_size);
752 assert_eq!(result.max_position_embeddings, baseline.max_position_embeddings);
753 assert!((result.rms_norm_eps - baseline.rms_norm_eps).abs() < f32::EPSILON);
754 assert!((result.rope_theta - baseline.rope_theta).abs() < f32::EPSILON);
755 assert_eq!(result.use_bias, baseline.use_bias);
756 assert_eq!(result.tie_word_embeddings, baseline.tie_word_embeddings);
757 assert_eq!(result.architecture, baseline.architecture);
758 assert_eq!(result.hf_architecture, baseline.hf_architecture);
759 assert_eq!(result.hf_model_type, baseline.hf_model_type);
760 }
761
762 #[test]
769 fn build_transformer_config_qwen_init_matches_input() {
770 let qwen = TransformerConfig::qwen2_0_5b();
771 let result = build_transformer_config(Some(&qwen));
772 assert_eq!(result.hidden_size, qwen.hidden_size, "hidden_size");
773 assert_eq!(result.num_attention_heads, qwen.num_attention_heads, "num_attention_heads");
774 assert_eq!(result.num_kv_heads, qwen.num_kv_heads, "num_kv_heads");
775 assert_eq!(result.intermediate_size, qwen.intermediate_size, "intermediate_size");
776 assert_eq!(result.num_hidden_layers, qwen.num_hidden_layers, "num_hidden_layers");
777 assert_eq!(result.vocab_size, qwen.vocab_size, "vocab_size");
778 assert_eq!(
779 result.max_position_embeddings, qwen.max_position_embeddings,
780 "max_position_embeddings"
781 );
782 assert_eq!(result.use_bias, qwen.use_bias, "use_bias");
783 assert_eq!(result.tie_word_embeddings, qwen.tie_word_embeddings, "tie_word_embeddings");
784 assert_eq!(result.architecture, qwen.architecture, "architecture");
785 assert_eq!(
787 result.num_attention_heads / result.num_kv_heads,
788 7,
789 "GQA ratio must preserve as 7:1 (Qwen2.5-0.5B canonical)"
790 );
791 }
792
793 #[test]
800 fn build_transformer_config_dispatch_mutually_exclusive() {
801 let qwen = TransformerConfig::qwen2_0_5b();
802 let none_result = build_transformer_config(None);
803 let some_result = build_transformer_config(Some(&qwen));
804 assert_ne!(
806 none_result.hidden_size, some_result.hidden_size,
807 "dispatch must differentiate None vs Some — Llama370M hidden=1024 vs Qwen=896"
808 );
809 assert_ne!(
810 none_result.vocab_size, some_result.vocab_size,
811 "dispatch must differentiate None vs Some — Llama370M vocab=50257 vs Qwen=151936"
812 );
813 }
814
815 #[test]
820 fn validate_pretrain_init_arch_accepts_decoder() {
821 let qwen = TransformerConfig::qwen2_0_5b();
822 assert_eq!(qwen.architecture, ModelArchitecture::Decoder);
823 validate_pretrain_init_arch_compatible(&qwen)
824 .expect("decoder-family config (Qwen2.5-0.5B) MUST pass arch-compat gate");
825 }
826
827 #[test]
835 fn validate_pretrain_init_arch_rejects_encoder() {
836 let bert = TransformerConfig {
838 hidden_size: 768,
839 num_attention_heads: 12,
840 num_kv_heads: 12,
841 intermediate_size: 3072,
842 num_hidden_layers: 12,
843 vocab_size: 50265,
844 max_position_embeddings: 514,
845 rms_norm_eps: 1e-12,
846 rope_theta: 10_000.0,
847 use_bias: true,
848 head_dim_override: None,
849 architecture: ModelArchitecture::Encoder,
850 hf_architecture: Some("RobertaModel".to_string()),
851 hf_model_type: Some("roberta".to_string()),
852 tie_word_embeddings: false,
853 };
854 let err = validate_pretrain_init_arch_compatible(&bert).expect_err(
855 "encoder-family config (CodeBERT/RoBERTa) MUST fail arch-compat gate — \
856 silent acceptance would corrupt §49 fine-tune trajectory before any \
857 FALSIFY-006 check could measure it",
858 );
859 assert!(
860 err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
861 "error must cite falsifier id: {err}"
862 );
863 assert!(err.contains("Encoder"), "error must name the architecture family: {err}");
864 assert!(
865 err.contains("decoder-only"),
866 "error must explain why this is wrong (decoder trainer): {err}"
867 );
868 assert!(
869 err.contains("RobertaModel"),
870 "error must name the offending hf_architecture: {err}"
871 );
872 }
873
874 #[test]
878 fn validate_pretrain_init_arch_accepts_llama370m_baseline() {
879 let llama = llama_370m_transformer_config();
880 assert_eq!(
881 llama.architecture,
882 ModelArchitecture::Decoder,
883 "Llama370M baseline MUST be Decoder (regression-free)"
884 );
885 validate_pretrain_init_arch_compatible(&llama)
886 .expect("Llama370M baseline (Decoder) MUST pass arch-compat gate");
887 }
888
889 #[test]
890 fn real_step_fn_exhausted_iterator_returns_finite_placeholder() {
891 let mut tiny = TransformerConfig::llama2_7b();
899 tiny.hidden_size = 64;
900 tiny.num_attention_heads = 4;
901 tiny.num_kv_heads = 4;
902 tiny.num_hidden_layers = 2;
903 tiny.intermediate_size = 128;
904 tiny.vocab_size = 256;
905 let cfg = TransformerTrainConfig::new(tiny);
906 let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
907 let empty_iter: Box<dyn Iterator<Item = LMBatch>> = Box::new(std::iter::empty::<LMBatch>());
908 let mut step = RealStepFn::new(trainer, empty_iter);
909 let (loss, grad_norm) = step.step(0, 1.0e-4, 128);
910 assert!(loss.is_finite(), "exhausted iter must return finite loss");
911 assert!(grad_norm.is_finite(), "grad_norm must be finite");
912 assert!(grad_norm >= 0.0, "INV-TRAIN-008: grad_norm non-negative");
913 }
914
915 #[test]
916 fn real_val_fn_empty_held_out_returns_nan() {
917 let mut tiny = TransformerConfig::llama2_7b();
918 tiny.hidden_size = 64;
919 tiny.num_attention_heads = 4;
920 tiny.num_kv_heads = 4;
921 tiny.num_hidden_layers = 2;
922 tiny.intermediate_size = 128;
923 tiny.vocab_size = 256;
924 let cfg = TransformerTrainConfig::new(tiny);
925 let trainer = Rc::new(RefCell::new(TransformerTrainer::new(cfg)));
926 let mut val = RealValFn::new(trainer, Vec::new());
927 let loss = val.validate(0);
928 assert!(loss.is_nan(), "empty held_out must surface as NaN to the guard");
929 }
930
931 fn tiny_test_transformer() -> Transformer {
935 let mut tiny = TransformerConfig::llama2_7b();
936 tiny.hidden_size = 32;
937 tiny.num_attention_heads = 2;
938 tiny.num_kv_heads = 2;
939 tiny.num_hidden_layers = 2;
940 tiny.intermediate_size = 64;
941 tiny.vocab_size = 16;
942 Transformer::new(&tiny)
943 }
944
945 fn tensors_map_from_transformer(
949 transformer: &Transformer,
950 ) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
951 let mut map = BTreeMap::new();
952 for (name, t) in transformer.named_parameters() {
953 let len = t.len();
954 let data: Vec<f32> = (0..len).map(|i| i as f32 * 0.001).collect();
955 map.insert(name, (data, vec![len]));
956 }
957 map
958 }
959
960 #[test]
963 fn populate_trainer_from_init_tensors_happy_path() {
964 let mut transformer = tiny_test_transformer();
965 let init_tensors = tensors_map_from_transformer(&transformer);
966 let expected_count = transformer.named_parameters().len();
967 let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
968 assert!(result.is_ok(), "happy-path populate must succeed: {result:?}");
969 assert_eq!(
970 result.unwrap(),
971 expected_count,
972 "populated count must equal named_parameters().len()"
973 );
974 }
975
976 #[test]
981 fn populate_trainer_from_init_tensors_extra_entries_silently_ignored() {
982 let mut transformer = tiny_test_transformer();
983 let mut init_tensors = tensors_map_from_transformer(&transformer);
984 init_tensors
986 .insert("model.layers.999.fictitious.weight".to_string(), (vec![0.0; 4], vec![4]));
987 let expected_count = transformer.named_parameters().len();
988 let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
989 assert!(result.is_ok(), "extra init entries must NOT cause Err: {result:?}");
990 assert_eq!(result.unwrap(), expected_count);
991 }
992
993 #[test]
997 fn populate_trainer_from_init_tensors_rejects_length_mismatch() {
998 let mut transformer = tiny_test_transformer();
999 let mut init_tensors = tensors_map_from_transformer(&transformer);
1000 let any_name = transformer.named_parameters()[0].0.clone();
1002 init_tensors.insert(any_name.clone(), (vec![0.0; 7], vec![7]));
1003 let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
1004 assert!(result.is_err(), "length-mismatch must Err, not silently truncate");
1005 let err = result.unwrap_err();
1006 assert!(
1007 err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
1008 "error must cite falsifier id; got: {err}"
1009 );
1010 assert!(err.contains(&any_name), "error must name the offending parameter; got: {err}");
1011 assert!(
1012 err.contains("init length 7"),
1013 "error must report the actual init length; got: {err}"
1014 );
1015 }
1016
1017 #[test]
1023 fn populate_trainer_from_init_tensors_rejects_missing_required_param() {
1024 let mut transformer = tiny_test_transformer();
1025 let mut init_tensors = tensors_map_from_transformer(&transformer);
1026 let any_name = transformer.named_parameters()[0].0.clone();
1028 init_tensors.remove(&any_name);
1029 let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
1030 assert!(result.is_err(), "missing-required must Err, not silently leave random init");
1031 let err = result.unwrap_err();
1032 assert!(
1033 err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
1034 "error must cite falsifier id; got: {err}"
1035 );
1036 assert!(err.contains(&any_name), "error must name the missing parameter; got: {err}");
1037 assert!(
1038 err.contains("not present in init APR"),
1039 "error must say what was missing; got: {err}"
1040 );
1041 }
1042
1043 #[test]
1050 fn build_shared_trainer_with_init_none_uses_llama370m_shape() {
1051 let trainer = build_shared_trainer_with_init(1.0e-4, 128, 42, None, None)
1052 .expect("None case must succeed");
1053 let model = trainer.borrow();
1054 let embed_len = model.model().named_parameters()[0].1.len();
1057 let expected_embed_len = Llama370MConfig::VOCAB_SIZE * Llama370MConfig::HIDDEN_DIM;
1058 assert_eq!(
1059 embed_len,
1060 expected_embed_len,
1061 "init=None must produce Llama370M-shaped embedding (vocab={} × hidden={})",
1062 Llama370MConfig::VOCAB_SIZE,
1063 Llama370MConfig::HIDDEN_DIM
1064 );
1065 }
1066
1067 #[test]
1071 fn build_shared_trainer_with_init_rejects_unpaired_args() {
1072 let cfg = TransformerConfig::qwen2_0_5b();
1074 let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), None);
1075 assert!(result.is_err(), "unpaired (arch=Some, path=None) must Err");
1076 let dummy_path = std::path::PathBuf::from("/dev/null");
1078 let result = build_shared_trainer_with_init(1.0e-4, 128, 42, None, Some(&dummy_path));
1079 assert!(result.is_err(), "unpaired (arch=None, path=Some) must Err");
1080 }
1081
1082 #[test]
1086 fn build_shared_trainer_with_init_rejects_encoder_family() {
1087 let mut encoder_cfg = TransformerConfig::qwen2_0_5b();
1088 encoder_cfg.architecture = ModelArchitecture::Encoder;
1089 let dummy_path = std::path::PathBuf::from("/nonexistent/encoder.apr");
1090 let result =
1091 build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&encoder_cfg), Some(&dummy_path));
1092 let err = match result {
1093 Ok(_) => panic!("encoder family must be rejected before tensor load"),
1094 Err(e) => e,
1095 };
1096 assert!(
1097 err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
1098 "error must cite falsifier id; got: {err}"
1099 );
1100 }
1101
1102 #[test]
1107 fn build_shared_trainer_with_init_decoder_family_proceeds_to_tensor_load() {
1108 let cfg = TransformerConfig::qwen2_0_5b();
1109 let dummy_path = std::path::PathBuf::from("/nonexistent/decoder.apr");
1110 let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), Some(&dummy_path));
1111 let err = match result {
1112 Ok(_) => panic!("missing tensor path must Err"),
1113 Err(e) => e,
1114 };
1115 assert!(
1116 err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
1117 "decoder family proceeds to tensor load; failure cites INIT-006 not ARCH-007; got: {err}"
1118 );
1119 assert!(
1120 !err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
1121 "decoder family must NOT trigger encoder-rejection; got: {err}"
1122 );
1123 }
1124
1125 #[test]
1152 fn falsify_h4_init_stats_qwen_embed_norm_sensible() {
1153 let fresh = std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-fresh.apr");
1154 let legacy =
1155 std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-instruct-fp16.apr");
1156 let path = if fresh.exists() {
1157 fresh
1158 } else if legacy.exists() {
1159 legacy
1160 } else {
1161 eprintln!("[falsify-h4-init-stats-001] skipping: host lacks Qwen 0.5B APR");
1162 return;
1163 };
1164 let _ = path; if !path.exists() {
1166 eprintln!("[falsify-h4-init-stats-001] skipping: host lacks {}", path.display());
1167 return;
1168 }
1169 {
1173 use aprender::format::v2::AprV2Reader;
1174 let bytes = std::fs::read(path).expect("read APR");
1175 let reader = AprV2Reader::from_bytes(&bytes).expect("parse APR v2");
1176 for name in ["model.layers.0.self_attn.q_proj.bias", "model.norm.weight"] {
1177 if let Some(entry) = reader.get_tensor(name) {
1178 eprintln!(
1179 "[h4-init-dtype] {name}: dtype={:?} shape={:?}",
1180 entry.dtype, entry.shape
1181 );
1182 }
1183 }
1184 }
1185 let tensors = match load_init_tensors_from_apr(path) {
1186 Ok(t) => t,
1187 Err(e) => {
1188 panic!("FALSIFY-H4-INIT-STATS-001: load_init_tensors_from_apr failed: {e}");
1189 }
1190 };
1191
1192 let embed = tensors
1194 .get("model.embed_tokens.weight")
1195 .unwrap_or_else(|| panic!("missing model.embed_tokens.weight in init APR"));
1196 let norm = tensors
1197 .get("model.norm.weight")
1198 .unwrap_or_else(|| panic!("missing model.norm.weight in init APR"));
1199
1200 let stats = |name: &str, data: &[f32]| -> (f64, f64, f32, f32) {
1201 let n = data.len() as f64;
1202 let mean = data.iter().map(|&v| v as f64).sum::<f64>() / n;
1203 let var = data
1204 .iter()
1205 .map(|&v| {
1206 let d = v as f64 - mean;
1207 d * d
1208 })
1209 .sum::<f64>()
1210 / n;
1211 let std = var.sqrt();
1212 let min = data.iter().copied().fold(f32::INFINITY, f32::min);
1213 let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
1214 eprintln!(
1215 "[h4-init-stats] {name}: n={n} mean={mean:.5} std={std:.5} min={min:.4} max={max:.4}"
1216 );
1217 (mean, std, min, max)
1218 };
1219 {
1223 let q = tensors.get("model.layers.0.self_attn.q_proj.bias").unwrap();
1224 eprintln!(
1225 "[h4-dtype-mislabel] q_proj.bias L0[0..6] (aprender F16-decoded): {:?}",
1226 &q.0[..6]
1227 );
1228 let n = tensors.get("model.norm.weight").unwrap();
1229 eprintln!(
1230 "[h4-dtype-mislabel] model.norm.weight[0..6] (aprender F16-decoded): {:?}",
1231 &n.0[..6]
1232 );
1233 }
1234
1235 let (em, es, _, _) = stats("model.embed_tokens.weight", &embed.0);
1236 let (nm, ns, _, _) = stats("model.norm.weight", &norm.0);
1237
1238 for layer_idx in [0_usize, 5, 11, 23] {
1242 for kind in ["input_layernorm", "post_attention_layernorm"] {
1243 let key = format!("model.layers.{layer_idx}.{kind}.weight");
1244 if let Some(t) = tensors.get(&key) {
1245 stats(&key, &t.0);
1246 }
1247 }
1248 }
1249 for kind in [
1250 "model.layers.0.self_attn.q_proj.weight",
1251 "model.layers.0.self_attn.q_proj.bias",
1252 "model.layers.0.mlp.gate_proj.weight",
1253 "model.layers.0.mlp.down_proj.weight",
1254 ] {
1255 if let Some(t) = tensors.get(kind) {
1256 stats(kind, &t.0);
1257 }
1258 }
1259
1260 assert!(
1264 em.abs() < 0.05,
1265 "FALSIFY-H4-INIT-STATS-001: embed mean={em} > 0.05; weights are not centered. \
1266 Possible f16→f32 sign-bit corruption or wrong byte-order."
1267 );
1268 assert!(
1269 (0.005..=0.5).contains(&es),
1270 "FALSIFY-H4-INIT-STATS-001: embed std={es} outside [0.005, 0.5]; weights are not \
1271 distributed like trained transformer init. Possible f16 mantissa misread or \
1272 scale corruption."
1273 );
1274
1275 assert!(
1278 nm > 0.01 && nm < 100.0,
1279 "FALSIFY-H4-INIT-STATS-001: norm mean={nm} outside [0.01, 100]; RMSNorm scale \
1280 load is corrupt. Trained pretrained values are typically near 1.0."
1281 );
1282 assert!(
1283 ns < 100.0,
1284 "FALSIFY-H4-INIT-STATS-001: norm std={ns} > 100; RMSNorm has explosive variance. \
1285 Tensor load is corrupt."
1286 );
1287 }
1288
1289 #[test]
1296 fn falsify_h4_cpu_forward_qwen_logits_sensible() {
1297 let fresh = std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-fresh.apr");
1298 let legacy =
1299 std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-instruct-fp16.apr");
1300 let path = if fresh.exists() {
1301 fresh
1302 } else if legacy.exists() {
1303 legacy
1304 } else {
1305 eprintln!("[falsify-h4-cpu-forward-001] skipping: host lacks Qwen 0.5B APR");
1306 return;
1307 };
1308
1309 let tensors = load_init_tensors_from_apr(path).expect("load_init_tensors_from_apr");
1310 let cfg = TransformerConfig::qwen2_0_5b();
1311 let mut transformer = Transformer::new(&cfg);
1312 let populated = populate_trainer_from_init_tensors(&mut transformer, &tensors)
1313 .expect("populate_trainer_from_init_tensors");
1314 eprintln!("[falsify-h4-cpu-forward-001] populated {populated} tensors");
1315
1316 let token_ids = vec![100_u32];
1317 let logits = transformer.forward(&token_ids);
1318 let data = logits.data();
1319 let slice = data.as_slice().expect("logits contiguous");
1320
1321 let mut nan_count = 0usize;
1322 let mut inf_count = 0usize;
1323 let mut min = f32::INFINITY;
1324 let mut max = f32::NEG_INFINITY;
1325 let mut sum = 0.0_f64;
1326 let mut sum_sq = 0.0_f64;
1327 let mut argmax_idx = 0_usize;
1328 for (i, &v) in slice.iter().enumerate() {
1329 if v.is_nan() {
1330 nan_count += 1;
1331 } else if v.is_infinite() {
1332 inf_count += 1;
1333 } else {
1334 if v < min {
1335 min = v;
1336 }
1337 if v > max {
1338 max = v;
1339 argmax_idx = i;
1340 }
1341 sum += v as f64;
1342 sum_sq += (v as f64) * (v as f64);
1343 }
1344 }
1345 let n = slice.len() as f64;
1346 let mean = sum / n;
1347 let std = (sum_sq / n - mean * mean).sqrt();
1348
1349 eprintln!(
1350 "[falsify-h4-cpu-forward-001] token=100 logits: n={} nan={nan_count} inf={inf_count} \
1351 min={min:.4} max={max:.4} mean={mean:.4} std={std:.4} argmax={argmax_idx}",
1352 slice.len()
1353 );
1354
1355 assert_eq!(nan_count, 0, "logits contain NaN — forward corruption");
1356 assert_eq!(inf_count, 0, "logits contain Inf — forward corruption");
1357 assert!(
1358 std > 0.01,
1359 "FALSIFY-H4-CPU-FORWARD-001: logits std={std} < 0.01 — essentially constant"
1360 );
1361 let peak_to_mean = (max as f64 - mean).abs() / std.max(1e-9);
1362 assert!(
1363 peak_to_mean > 1.5,
1364 "FALSIFY-H4-CPU-FORWARD-001: peak-to-mean ratio = {peak_to_mean} < 1.5 — \
1365 logits are essentially uniform"
1366 );
1367 assert!(
1368 (argmax_idx as u32) < cfg.vocab_size as u32,
1369 "FALSIFY-H4-CPU-FORWARD-001: argmax_idx={argmax_idx} >= vocab_size={}",
1370 cfg.vocab_size
1371 );
1372 }
1373
1374 fn qwen2_tensor_names() -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
1379 let mut m = BTreeMap::new();
1381 m.insert("model.layers.0.self_attn.q_proj.bias".to_string(), (vec![0.0_f32; 4], vec![4]));
1382 m.insert(
1383 "model.layers.0.self_attn.q_proj.weight".to_string(),
1384 (vec![0.0_f32; 16], vec![4, 4]),
1385 );
1386 m
1387 }
1388
1389 fn llama_tensor_names() -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
1390 let mut m = BTreeMap::new();
1392 m.insert(
1393 "model.layers.0.self_attn.q_proj.weight".to_string(),
1394 (vec![0.0_f32; 16], vec![4, 4]),
1395 );
1396 m.insert("model.layers.0.input_layernorm.weight".to_string(), (vec![1.0_f32; 4], vec![4]));
1397 m
1398 }
1399
1400 #[test]
1404 fn inv_init_arch_match_001_rejects_llama_stamped_qwen2_tensors() {
1405 let tensors = qwen2_tensor_names();
1406 let err = validate_init_arch_matches_tensor_evidence(Some("LlamaForCausalLM"), &tensors)
1407 .expect_err("§86 case MUST be rejected");
1408 assert!(
1409 err.contains("FALSIFY-INIT-ARCH-MATCH-001"),
1410 "error must cite falsifier id; got: {err}"
1411 );
1412 assert!(
1413 err.contains("llama") && err.contains("qwen2"),
1414 "error must name both claimed and inferred families; got: {err}"
1415 );
1416 assert!(
1417 err.contains("apr stamp"),
1418 "error must include the §86.4 salvage recipe; got: {err}"
1419 );
1420 }
1421
1422 #[test]
1425 fn inv_init_arch_match_001_rejects_qwen2_stamped_llama_tensors() {
1426 let tensors = llama_tensor_names();
1427 let err = validate_init_arch_matches_tensor_evidence(Some("Qwen2ForCausalLM"), &tensors)
1428 .expect_err("inverse §86 case MUST be rejected");
1429 assert!(err.contains("FALSIFY-INIT-ARCH-MATCH-001"));
1430 assert!(err.contains("qwen2") && err.contains("llama"));
1431 }
1432
1433 #[test]
1435 fn inv_init_arch_match_001_accepts_matching_qwen2() {
1436 let tensors = qwen2_tensor_names();
1437 validate_init_arch_matches_tensor_evidence(Some("Qwen2ForCausalLM"), &tensors)
1438 .expect("matching qwen2 + qwen2 must pass");
1439 validate_init_arch_matches_tensor_evidence(Some("qwen2"), &tensors)
1440 .expect("matching qwen2 slug + qwen2 tensors must pass");
1441 }
1442
1443 #[test]
1445 fn inv_init_arch_match_001_accepts_matching_llama() {
1446 let tensors = llama_tensor_names();
1447 validate_init_arch_matches_tensor_evidence(Some("LlamaForCausalLM"), &tensors)
1448 .expect("matching llama + llama must pass");
1449 validate_init_arch_matches_tensor_evidence(Some("llama"), &tensors)
1450 .expect("matching llama slug + llama tensors must pass");
1451 }
1452
1453 #[test]
1456 fn inv_init_arch_match_001_skips_when_metadata_absent() {
1457 let tensors = qwen2_tensor_names();
1458 validate_init_arch_matches_tensor_evidence(None, &tensors)
1459 .expect("absent metadata claim must skip check");
1460 }
1461
1462 #[test]
1465 fn inv_init_arch_match_001_skips_unmappable_metadata() {
1466 let tensors = qwen2_tensor_names();
1467 validate_init_arch_matches_tensor_evidence(Some("WeirdNovelArchForCausalLM"), &tensors)
1468 .expect("unmappable metadata MUST skip check (no false-positive on novel arch)");
1469 }
1470
1471 #[test]
1474 fn inv_init_arch_match_001_trusts_metadata_when_tensors_unknown() {
1475 let mut tensors = BTreeMap::new();
1476 tensors.insert("blk.0.attn_q.weight".to_string(), (vec![0.0_f32; 16], vec![4, 4]));
1477 validate_init_arch_matches_tensor_evidence(Some("LlamaForCausalLM"), &tensors)
1479 .expect("unknown tensor family must skip check (trust metadata)");
1480 }
1481
1482 #[test]
1485 fn family_from_tensor_names_distinguishes_qwen2_from_llama() {
1486 let qwen2: Vec<&str> = vec![
1487 "model.layers.0.self_attn.q_proj.weight",
1488 "model.layers.0.self_attn.q_proj.bias", ];
1490 assert_eq!(family_from_tensor_names(qwen2.iter().copied()), "qwen2");
1491
1492 let llama: Vec<&str> =
1493 vec!["model.layers.0.self_attn.q_proj.weight", "model.layers.0.input_layernorm.weight"];
1494 assert_eq!(family_from_tensor_names(llama.iter().copied()), "llama");
1495 }
1496
1497 #[test]
1499 fn normalize_metadata_arch_family_handles_three_forms() {
1500 assert_eq!(normalize_metadata_arch_family("Qwen2ForCausalLM"), Some("qwen2"));
1502 assert_eq!(normalize_metadata_arch_family("LlamaForCausalLM"), Some("llama"));
1503 assert_eq!(normalize_metadata_arch_family("qwen2"), Some("qwen2"));
1505 assert_eq!(normalize_metadata_arch_family("llama"), Some("llama"));
1506 assert_eq!(normalize_metadata_arch_family("Qwen2"), Some("qwen2"));
1508 assert_eq!(normalize_metadata_arch_family("unknown"), None);
1510 assert_eq!(normalize_metadata_arch_family("WeirdNovelArch"), None);
1511 }
1512}