1use crate::models::llama_370m::Llama370MConfig;
23use crate::train::pretrain::{CheckpointFn, EpochArtifact, StepFn, ValFn};
24use crate::train::transformer_trainer::{LMBatch, TransformerTrainConfig, TransformerTrainer};
25use crate::transformer::{ModelArchitecture, Transformer, TransformerConfig};
26use crate::Tensor;
27use std::cell::RefCell;
28use std::collections::BTreeMap;
29use std::path::Path;
30use std::rc::Rc;
31
/// Single-threaded, reference-counted handle to one [`TransformerTrainer`]
/// with interior mutability (`Rc<RefCell<_>>`).
///
/// The step, validation and checkpoint adapters in this file each hold one of
/// these handles; borrows are checked at runtime by the `RefCell`.
pub type SharedTrainer = Rc<RefCell<TransformerTrainer>>;
36
37pub fn load_init_tensors_from_apr(
65 path: impl AsRef<Path>,
66) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String> {
67 let path_ref = path.as_ref();
68 aprender::format::converter::load_model_tensors(path_ref).map_err(|e| {
69 format!(
70 "FALSIFY-APR-PRETRAIN-INIT-006: failed to load init tensors from APR file {}: {e}",
71 path_ref.display()
72 )
73 })
74}
75
76pub fn validate_pretrain_init_arch_compatible(cfg: &TransformerConfig) -> Result<(), String> {
98 match cfg.architecture {
99 ModelArchitecture::Decoder => Ok(()),
100 ModelArchitecture::Encoder => Err(format!(
101 "FALSIFY-APR-PRETRAIN-ARCH-007: --init checkpoint has architecture=Encoder \
102 (e.g., BERT/RoBERTa/CodeBERT) but the pretrain trainer is decoder-only \
103 (Llama/Qwen-class causal LMs). Loading encoder weights into a decoder \
104 trainer would produce nonsense gradients. Architectural details: \
105 hidden_size={}, num_layers={}, vocab_size={}, hf_architecture={:?}",
106 cfg.hidden_size, cfg.num_hidden_layers, cfg.vocab_size, cfg.hf_architecture
107 )),
108 }
109}
110
111pub fn populate_trainer_from_init_tensors(
142 transformer: &mut Transformer,
143 init_tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
144) -> Result<usize, String> {
145 let expected: Vec<(String, usize)> = transformer
146 .named_parameters()
147 .into_iter()
148 .map(|(name, t)| (name, t.len()))
149 .collect();
150 let mut populated = 0usize;
151 let mut errors: Vec<String> = Vec::new();
152
153 for (name, expected_len) in &expected {
154 match init_tensors.get(name) {
155 Some((data, _shape)) => {
156 if data.len() != *expected_len {
157 errors.push(format!(
158 "{name}: init length {} != trainer expected {expected_len}",
159 data.len()
160 ));
161 continue;
162 }
163 let tensor = Tensor::from_vec(data.clone(), true);
164 if !transformer.set_named_parameter(name, tensor) {
165 errors.push(format!(
166 "{name}: set_named_parameter rejected the assignment"
167 ));
168 continue;
169 }
170 populated += 1;
171 }
172 None => {
173 errors.push(format!("{name}: not present in init APR tensors"));
174 }
175 }
176 }
177
178 if !errors.is_empty() {
179 let total = errors.len();
180 let head = errors.iter().take(5).cloned().collect::<Vec<_>>().join("; ");
181 return Err(format!(
182 "FALSIFY-APR-PRETRAIN-INIT-007: populate_trainer_from_init_tensors \
183 failed for {total} parameter(s); first {} of {total}: {head}",
184 errors.len().min(5)
185 ));
186 }
187
188 Ok(populated)
189}
190
191pub fn llama_370m_transformer_config() -> TransformerConfig {
194 TransformerConfig {
195 hidden_size: Llama370MConfig::HIDDEN_DIM,
196 num_attention_heads: Llama370MConfig::NUM_HEADS,
197 num_kv_heads: Llama370MConfig::NUM_KV_HEADS,
198 intermediate_size: Llama370MConfig::INTERMEDIATE_DIM,
199 num_hidden_layers: Llama370MConfig::NUM_LAYERS,
200 vocab_size: Llama370MConfig::VOCAB_SIZE,
201 max_position_embeddings: Llama370MConfig::MAX_POSITION_EMBEDDINGS,
202 rms_norm_eps: Llama370MConfig::RMS_NORM_EPS,
203 rope_theta: Llama370MConfig::ROPE_THETA,
204 use_bias: false,
205 head_dim_override: None,
206 architecture: ModelArchitecture::Decoder,
207 hf_architecture: Some("LlamaForCausalLM".into()),
208 hf_model_type: Some("llama".into()),
209 tie_word_embeddings: true,
210 }
211}
212
213pub fn build_transformer_config(init: Option<&TransformerConfig>) -> TransformerConfig {
231 match init {
232 None => llama_370m_transformer_config(),
233 Some(cfg) => cfg.clone(),
234 }
235}
236
237pub fn llama_370m_train_config(lr: f32, seq_length: usize, seed: u64) -> TransformerTrainConfig {
240 let model_cfg = llama_370m_transformer_config();
241 let mut cfg = TransformerTrainConfig::new(model_cfg);
242 cfg.lr = lr;
243 cfg.max_seq_len = seq_length;
244 cfg.seed = seed;
245 cfg
246}
247
/// [`StepFn`] adapter that trains a shared [`TransformerTrainer`] on batches
/// pulled one at a time from an iterator.
pub struct RealStepFn {
    /// Shared trainer; mutably borrowed while a batch is being trained.
    trainer: SharedTrainer,
    /// Stream of training batches, consumed one per `step` call.
    batches: Box<dyn Iterator<Item = LMBatch>>,
}
255
256impl RealStepFn {
257 pub fn new(trainer: SharedTrainer, batches: Box<dyn Iterator<Item = LMBatch>>) -> Self {
258 Self { trainer, batches }
259 }
260}
261
262impl StepFn for RealStepFn {
263 fn step(&mut self, _step: u64, _lr: f32, _batch_tokens: u64) -> (f32, f32) {
264 let Some(batch) = self.batches.next() else {
270 return (1.0, 1.0);
271 };
272 let mut trainer = self.trainer.borrow_mut();
273 let loss = trainer.train_batch(&batch);
274 let grad_norm = 1.0_f32;
279 (loss, grad_norm)
280 }
281
282 fn optimizer_state_sha256(&self) -> Option<String> {
284 Some(self.trainer.borrow().optimizer_state_sha256())
285 }
286}
287
/// [`ValFn`] adapter that scores a fixed set of held-out batches with a
/// shared [`TransformerTrainer`].
pub struct RealValFn {
    /// Shared trainer; immutably borrowed while validating.
    trainer: SharedTrainer,
    /// Held-out batches evaluated on every `validate` call.
    held_out: Vec<LMBatch>,
}
294
295impl RealValFn {
296 pub fn new(trainer: SharedTrainer, held_out: Vec<LMBatch>) -> Self {
297 Self { trainer, held_out }
298 }
299}
300
301impl ValFn for RealValFn {
302 fn validate(&mut self, _epoch: usize) -> f32 {
303 if self.held_out.is_empty() {
304 return f32::NAN;
305 }
306 let trainer = self.trainer.borrow();
307 let mut total_loss = 0.0_f32;
308 let mut total_items = 0_usize;
309 for batch in &self.held_out {
310 for i in 0..batch.batch_size {
311 let Some(inp) = batch.get_input(i) else {
312 continue;
313 };
314 let Some(tgt) = batch.get_target(i) else {
315 continue;
316 };
317 let (loss_val, _loss_tensor, _logits) = trainer.forward_single(inp, tgt);
318 total_loss += loss_val;
319 total_items += 1;
320 }
321 }
322 if total_items == 0 {
323 f32::NAN
324 } else {
325 total_loss / total_items as f32
326 }
327 }
328}
329
/// [`CheckpointFn`] adapter that writes the shared trainer's state through
/// `save_apr`.
pub struct AprCheckpointFn {
    /// Shared trainer; immutably borrowed while saving.
    trainer: SharedTrainer,
    /// Model name passed to `save_apr` on every save.
    model_name: String,
    /// Architecture label passed to `save_apr` on every save.
    architecture: String,
}
340
341impl AprCheckpointFn {
342 pub fn new(
343 trainer: SharedTrainer,
344 model_name: impl Into<String>,
345 architecture: impl Into<String>,
346 ) -> Self {
347 Self { trainer, model_name: model_name.into(), architecture: architecture.into() }
348 }
349}
350
351impl CheckpointFn for AprCheckpointFn {
352 fn save(&mut self, _epoch: usize, artifact: &EpochArtifact) -> Result<(), String> {
353 let trainer = self.trainer.borrow();
354 trainer
355 .save_apr(&artifact.checkpoint_path, &self.model_name, &self.architecture)
356 .map_err(|e| format!("save_apr failed: {e}"))
357 }
358}
359
360pub fn build_shared_trainer(lr: f32, seq_length: usize, seed: u64) -> SharedTrainer {
363 let cfg = llama_370m_train_config(lr, seq_length, seed);
364 let trainer = TransformerTrainer::new(cfg);
365 #[cfg(debug_assertions)]
370 {
371 let param_count: usize = trainer.model().parameters().iter().map(|t| t.len()).sum();
372 debug_assert!(
373 (366_000_000..=374_000_000).contains(¶m_count),
374 "INV-ARCH-370M-001: parameter count {param_count} outside [366M, 374M] band",
375 );
376 }
377 Rc::new(RefCell::new(trainer))
378}
379
380pub fn build_shared_trainer_with_init(
407 lr: f32,
408 seq_length: usize,
409 seed: u64,
410 init_arch: Option<&TransformerConfig>,
411 init_path: Option<&Path>,
412) -> Result<SharedTrainer, String> {
413 if init_arch.is_some() != init_path.is_some() {
414 return Err(format!(
415 "build_shared_trainer_with_init: init_arch and init_path must both be Some \
416 or both None (caller bug; init_arch.is_some()={}, init_path.is_some()={})",
417 init_arch.is_some(),
418 init_path.is_some()
419 ));
420 }
421
422 if let Some(cfg) = init_arch {
423 validate_pretrain_init_arch_compatible(cfg)?;
424 }
425
426 let model_cfg = build_transformer_config(init_arch);
427 let mut train_cfg = TransformerTrainConfig::new(model_cfg);
428 train_cfg.lr = lr;
429 train_cfg.max_seq_len = seq_length;
430 train_cfg.seed = seed;
431 let mut trainer = TransformerTrainer::new(train_cfg);
432
433 if let Some(path) = init_path {
440 let tensors = load_init_tensors_from_apr(path)?;
441 populate_trainer_from_init_tensors(trainer.model_mut(), &tensors)?;
442 }
443
444 Ok(Rc::new(RefCell::new(trainer)))
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::transformer_trainer::LMBatch;

    /// Return type the init-tensor loader must keep.
    type InitTensorResult = Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String>;

    /// Shared trainer over a very small config derived from `llama2_7b`, so
    /// adapter tests do not need a full-size model.
    fn tiny_shared_trainer() -> SharedTrainer {
        let tiny = TransformerConfig {
            hidden_size: 64,
            num_attention_heads: 4,
            num_kv_heads: 4,
            num_hidden_layers: 2,
            intermediate_size: 128,
            vocab_size: 256,
            ..TransformerConfig::llama2_7b()
        };
        Rc::new(RefCell::new(TransformerTrainer::new(
            TransformerTrainConfig::new(tiny),
        )))
    }

    /// Bare transformer over an even smaller config, used by the populate
    /// tests.
    fn tiny_test_transformer() -> Transformer {
        let tiny = TransformerConfig {
            hidden_size: 32,
            num_attention_heads: 2,
            num_kv_heads: 2,
            num_hidden_layers: 2,
            intermediate_size: 64,
            vocab_size: 16,
            ..TransformerConfig::llama2_7b()
        };
        Transformer::new(&tiny)
    }

    /// Builds an init-tensor map with one correctly sized entry per named
    /// parameter of `transformer`, filled with a deterministic ramp.
    fn tensors_map_from_transformer(
        transformer: &Transformer,
    ) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
        transformer
            .named_parameters()
            .into_iter()
            .map(|(name, t)| {
                let len = t.len();
                let data: Vec<f32> = (0..len).map(|i| i as f32 * 0.001).collect();
                (name, (data, vec![len]))
            })
            .collect()
    }

    // ---- load_init_tensors_from_apr -------------------------------------

    #[test]
    fn load_init_tensors_missing_file_errors_with_falsifier_id() {
        let scratch = tempfile::TempDir::new().expect("tempdir");
        let absent = scratch.path().join("does-not-exist.apr");

        let err =
            load_init_tensors_from_apr(&absent).expect_err("missing init APR file MUST fail-fast");

        assert!(
            err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
            "error must cite falsifier id (auditability): {err}"
        );
        assert!(
            err.contains("does-not-exist.apr"),
            "error must name the missing path (operator-experience): {err}"
        );
    }

    #[test]
    fn load_init_tensors_signature_compile_bind() {
        // Compile-time only: the call below type-checks iff the loader still
        // accepts `&Path` and returns the expected map/error types.
        fn _check_signature<F>(_f: F)
        where
            F: Fn(&Path) -> InitTensorResult,
        {
        }
        _check_signature(|p| load_init_tensors_from_apr(p));
    }

    // ---- config builders ------------------------------------------------

    #[test]
    fn transformer_config_matches_llama_370m_constants() {
        let cfg = llama_370m_transformer_config();

        assert_eq!(cfg.hidden_size, Llama370MConfig::HIDDEN_DIM);
        assert_eq!(cfg.num_hidden_layers, Llama370MConfig::NUM_LAYERS);
        assert_eq!(cfg.num_attention_heads, Llama370MConfig::NUM_HEADS);
        assert_eq!(cfg.num_kv_heads, Llama370MConfig::NUM_KV_HEADS);
        assert_eq!(cfg.intermediate_size, Llama370MConfig::INTERMEDIATE_DIM);
        assert_eq!(cfg.vocab_size, Llama370MConfig::VOCAB_SIZE);

        let rope_delta = (cfg.rope_theta - Llama370MConfig::ROPE_THETA).abs();
        assert!(rope_delta < f32::EPSILON);
        let eps_delta = (cfg.rms_norm_eps - Llama370MConfig::RMS_NORM_EPS).abs();
        assert!(eps_delta < f32::EPSILON);

        assert!(!cfg.use_bias, "INV-ARCH-370M-008: no bias");
        assert!(cfg.tie_word_embeddings, "INV-ARCH-370M-004: tied embeddings");
    }

    #[test]
    fn build_transformer_config_no_init_matches_llama370m() {
        let result = build_transformer_config(None);
        let baseline = llama_370m_transformer_config();

        // Integer dimensions.
        assert_eq!(result.hidden_size, baseline.hidden_size);
        assert_eq!(result.num_attention_heads, baseline.num_attention_heads);
        assert_eq!(result.num_kv_heads, baseline.num_kv_heads);
        assert_eq!(result.intermediate_size, baseline.intermediate_size);
        assert_eq!(result.num_hidden_layers, baseline.num_hidden_layers);
        assert_eq!(result.vocab_size, baseline.vocab_size);
        assert_eq!(
            result.max_position_embeddings,
            baseline.max_position_embeddings
        );

        // Floats, compared with a tolerance.
        assert!((result.rms_norm_eps - baseline.rms_norm_eps).abs() < f32::EPSILON);
        assert!((result.rope_theta - baseline.rope_theta).abs() < f32::EPSILON);

        // Flags and identity.
        assert_eq!(result.use_bias, baseline.use_bias);
        assert_eq!(result.tie_word_embeddings, baseline.tie_word_embeddings);
        assert_eq!(result.architecture, baseline.architecture);
        assert_eq!(result.hf_architecture, baseline.hf_architecture);
        assert_eq!(result.hf_model_type, baseline.hf_model_type);
    }

    #[test]
    fn build_transformer_config_qwen_init_matches_input() {
        let qwen = TransformerConfig::qwen2_0_5b();
        let result = build_transformer_config(Some(&qwen));

        assert_eq!(result.hidden_size, qwen.hidden_size, "hidden_size");
        assert_eq!(
            result.num_attention_heads, qwen.num_attention_heads,
            "num_attention_heads"
        );
        assert_eq!(result.num_kv_heads, qwen.num_kv_heads, "num_kv_heads");
        assert_eq!(
            result.intermediate_size, qwen.intermediate_size,
            "intermediate_size"
        );
        assert_eq!(
            result.num_hidden_layers, qwen.num_hidden_layers,
            "num_hidden_layers"
        );
        assert_eq!(result.vocab_size, qwen.vocab_size, "vocab_size");
        assert_eq!(
            result.max_position_embeddings, qwen.max_position_embeddings,
            "max_position_embeddings"
        );
        assert_eq!(result.use_bias, qwen.use_bias, "use_bias");
        assert_eq!(
            result.tie_word_embeddings, qwen.tie_word_embeddings,
            "tie_word_embeddings"
        );
        assert_eq!(result.architecture, qwen.architecture, "architecture");

        let gqa_ratio = result.num_attention_heads / result.num_kv_heads;
        assert_eq!(
            gqa_ratio, 7,
            "GQA ratio must preserve as 7:1 (Qwen2.5-0.5B canonical)"
        );
    }

    #[test]
    fn build_transformer_config_dispatch_mutually_exclusive() {
        let qwen = TransformerConfig::qwen2_0_5b();
        let some_result = build_transformer_config(Some(&qwen));
        let none_result = build_transformer_config(None);

        assert_ne!(
            none_result.hidden_size, some_result.hidden_size,
            "dispatch must differentiate None vs Some — Llama370M hidden=1024 vs Qwen=896"
        );
        assert_ne!(
            none_result.vocab_size, some_result.vocab_size,
            "dispatch must differentiate None vs Some — Llama370M vocab=50257 vs Qwen=151936"
        );
    }

    // ---- validate_pretrain_init_arch_compatible -------------------------

    #[test]
    fn validate_pretrain_init_arch_accepts_decoder() {
        let qwen = TransformerConfig::qwen2_0_5b();
        assert_eq!(qwen.architecture, ModelArchitecture::Decoder);

        validate_pretrain_init_arch_compatible(&qwen)
            .expect("decoder-family config (Qwen2.5-0.5B) MUST pass arch-compat gate");
    }

    #[test]
    fn validate_pretrain_init_arch_rejects_encoder() {
        let bert = TransformerConfig {
            hidden_size: 768,
            num_attention_heads: 12,
            num_kv_heads: 12,
            intermediate_size: 3072,
            num_hidden_layers: 12,
            vocab_size: 50265,
            max_position_embeddings: 514,
            rms_norm_eps: 1e-12,
            rope_theta: 10_000.0,
            use_bias: true,
            head_dim_override: None,
            architecture: ModelArchitecture::Encoder,
            hf_architecture: Some(String::from("RobertaModel")),
            hf_model_type: Some(String::from("roberta")),
            tie_word_embeddings: false,
        };

        let err = validate_pretrain_init_arch_compatible(&bert).expect_err(
            "encoder-family config (CodeBERT/RoBERTa) MUST fail arch-compat gate — \
             silent acceptance would corrupt §49 fine-tune trajectory before any \
             FALSIFY-006 check could measure it",
        );

        assert!(
            err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
            "error must cite falsifier id: {err}"
        );
        assert!(
            err.contains("Encoder"),
            "error must name the architecture family: {err}"
        );
        assert!(
            err.contains("decoder-only"),
            "error must explain why this is wrong (decoder trainer): {err}"
        );
        assert!(
            err.contains("RobertaModel"),
            "error must name the offending hf_architecture: {err}"
        );
    }

    #[test]
    fn validate_pretrain_init_arch_accepts_llama370m_baseline() {
        let llama = llama_370m_transformer_config();
        assert_eq!(
            llama.architecture,
            ModelArchitecture::Decoder,
            "Llama370M baseline MUST be Decoder (regression-free)"
        );

        validate_pretrain_init_arch_compatible(&llama)
            .expect("Llama370M baseline (Decoder) MUST pass arch-compat gate");
    }

    // ---- step / validation adapters -------------------------------------

    #[test]
    fn real_step_fn_exhausted_iterator_returns_finite_placeholder() {
        let no_batches = Box::new(std::iter::empty::<LMBatch>());
        let mut step = RealStepFn::new(tiny_shared_trainer(), no_batches);

        let (loss, grad_norm) = step.step(0, 1.0e-4, 128);

        assert!(loss.is_finite(), "exhausted iter must return finite loss");
        assert!(grad_norm.is_finite(), "grad_norm must be finite");
        assert!(grad_norm >= 0.0, "INV-TRAIN-008: grad_norm non-negative");
    }

    #[test]
    fn real_val_fn_empty_held_out_returns_nan() {
        let mut val = RealValFn::new(tiny_shared_trainer(), Vec::new());

        let loss = val.validate(0);

        assert!(loss.is_nan(), "empty held_out must surface as NaN to the guard");
    }

    // ---- populate_trainer_from_init_tensors -----------------------------

    #[test]
    fn populate_trainer_from_init_tensors_happy_path() {
        let mut transformer = tiny_test_transformer();
        let expected_count = transformer.named_parameters().len();
        let init_tensors = tensors_map_from_transformer(&transformer);

        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);

        assert!(
            result.is_ok(),
            "happy-path populate must succeed: {result:?}"
        );
        assert_eq!(
            result.unwrap(),
            expected_count,
            "populated count must equal named_parameters().len()"
        );
    }

    #[test]
    fn populate_trainer_from_init_tensors_extra_entries_silently_ignored() {
        let mut transformer = tiny_test_transformer();
        let expected_count = transformer.named_parameters().len();
        let mut init_tensors = tensors_map_from_transformer(&transformer);
        // A tensor the model has no parameter for.
        init_tensors.insert(
            String::from("model.layers.999.fictitious.weight"),
            (vec![0.0; 4], vec![4]),
        );

        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);

        assert!(
            result.is_ok(),
            "extra init entries must NOT cause Err: {result:?}"
        );
        assert_eq!(result.unwrap(), expected_count);
    }

    #[test]
    fn populate_trainer_from_init_tensors_rejects_length_mismatch() {
        let mut transformer = tiny_test_transformer();
        let mut init_tensors = tensors_map_from_transformer(&transformer);
        let any_name = transformer.named_parameters()[0].0.clone();
        // Replace one entry with a tensor of the wrong element count.
        init_tensors.insert(any_name.clone(), (vec![0.0; 7], vec![7]));

        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);

        assert!(
            result.is_err(),
            "length-mismatch must Err, not silently truncate"
        );
        let err = result.unwrap_err();
        assert!(
            err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
            "error must cite falsifier id; got: {err}"
        );
        assert!(
            err.contains(&any_name),
            "error must name the offending parameter; got: {err}"
        );
        assert!(
            err.contains("init length 7"),
            "error must report the actual init length; got: {err}"
        );
    }

    #[test]
    fn populate_trainer_from_init_tensors_rejects_missing_required_param() {
        let mut transformer = tiny_test_transformer();
        let mut init_tensors = tensors_map_from_transformer(&transformer);
        let any_name = transformer.named_parameters()[0].0.clone();
        // Drop one required entry.
        init_tensors.remove(&any_name);

        let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);

        assert!(
            result.is_err(),
            "missing-required must Err, not silently leave random init"
        );
        let err = result.unwrap_err();
        assert!(
            err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
            "error must cite falsifier id; got: {err}"
        );
        assert!(
            err.contains(&any_name),
            "error must name the missing parameter; got: {err}"
        );
        assert!(
            err.contains("not present in init APR"),
            "error must say what was missing; got: {err}"
        );
    }

    // ---- build_shared_trainer_with_init ---------------------------------

    #[test]
    fn build_shared_trainer_with_init_none_uses_llama370m_shape() {
        let shared = build_shared_trainer_with_init(1.0e-4, 128, 42, None, None)
            .expect("None case must succeed");
        let guard = shared.borrow();

        let expected_embed_len = Llama370MConfig::VOCAB_SIZE * Llama370MConfig::HIDDEN_DIM;
        let embed_len = guard.model().named_parameters()[0].1.len();

        assert_eq!(
            embed_len, expected_embed_len,
            "init=None must produce Llama370M-shaped embedding (vocab={} × hidden={})",
            Llama370MConfig::VOCAB_SIZE,
            Llama370MConfig::HIDDEN_DIM
        );
    }

    #[test]
    fn build_shared_trainer_with_init_rejects_unpaired_args() {
        let cfg = TransformerConfig::qwen2_0_5b();

        let arch_only = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), None);
        assert!(
            arch_only.is_err(),
            "unpaired (arch=Some, path=None) must Err"
        );

        let path_only =
            build_shared_trainer_with_init(1.0e-4, 128, 42, None, Some(Path::new("/dev/null")));
        assert!(
            path_only.is_err(),
            "unpaired (arch=None, path=Some) must Err"
        );
    }

    #[test]
    fn build_shared_trainer_with_init_rejects_encoder_family() {
        let encoder_cfg = TransformerConfig {
            architecture: ModelArchitecture::Encoder,
            ..TransformerConfig::qwen2_0_5b()
        };
        let dummy_path = Path::new("/nonexistent/encoder.apr");

        let result =
            build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&encoder_cfg), Some(dummy_path));

        // `let else` rather than `expect_err`, which would require the Ok
        // type to implement `Debug`.
        let Err(err) = result else {
            panic!("encoder family must be rejected before tensor load");
        };
        assert!(
            err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
            "error must cite falsifier id; got: {err}"
        );
    }

    #[test]
    fn build_shared_trainer_with_init_decoder_family_proceeds_to_tensor_load() {
        let cfg = TransformerConfig::qwen2_0_5b();
        let dummy_path = Path::new("/nonexistent/decoder.apr");

        let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), Some(dummy_path));

        let Err(err) = result else {
            panic!("missing tensor path must Err");
        };
        assert!(
            err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
            "decoder family proceeds to tensor load; failure cites INIT-006 not ARCH-007; got: {err}"
        );
        assert!(
            !err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
            "decoder family must NOT trigger encoder-rejection; got: {err}"
        );
    }
}