1use crate::autograd::matmul_nt;
6use crate::error::{Error, Result};
7use crate::Tensor;
8use provable_contracts_macros::{ensures, requires};
9use std::collections::HashMap;
10use std::path::Path;
11
12use super::block::TransformerBlock;
13use super::config::TransformerConfig;
14use super::embedding::Embedding;
15use super::norm::RMSNorm;
16use super::weights::{load_safetensors_weights, validate_weights, Architecture};
17
/// Transformer language model: token embedding, a stack of
/// [`TransformerBlock`]s, a final [`RMSNorm`], and an optional untied LM head.
pub struct Transformer {
    /// Hyper-parameters this model was built from.
    pub config: TransformerConfig,
    /// Token embedding table (`model.embed_tokens.weight`). When `lm_head` is
    /// `None` this table is also used as the output projection (tied weights).
    pub embed_tokens: Embedding,
    /// Decoder blocks, applied in order by the forward passes.
    pub layers: Vec<TransformerBlock>,
    /// Final normalization applied to the last block's output (`model.norm`).
    pub norm: RMSNorm,
    /// Untied output projection (`lm_head.weight`); `None` means the
    /// projection is tied to `embed_tokens.weight`.
    pub lm_head: Option<Tensor>,
}
31
32impl Transformer {
33 pub fn new(config: &TransformerConfig) -> Self {
35 let layers: Vec<TransformerBlock> =
36 (0..config.num_hidden_layers).map(|i| TransformerBlock::new(config, i)).collect();
37
38 Self {
39 config: config.clone(),
40 embed_tokens: Embedding::new(config.vocab_size, config.hidden_size),
41 layers,
42 norm: RMSNorm::new(config.hidden_size, config.rms_norm_eps),
43 lm_head: None, }
45 }
46
47 pub fn from_params(
55 config: &TransformerConfig,
56 params: &HashMap<String, Tensor>,
57 ) -> Option<Self> {
58 let embed_tokens = Embedding::from_params(
59 params,
60 "model.embed_tokens.weight",
61 config.vocab_size,
62 config.hidden_size,
63 )?;
64
65 let layers: Option<Vec<TransformerBlock>> = (0..config.num_hidden_layers)
66 .map(|i| TransformerBlock::from_params(config, params, i))
67 .collect();
68 let layers = layers?;
69
70 let norm =
71 RMSNorm::from_params(params, "model.norm", config.rms_norm_eps, config.hidden_size)?;
72
73 let lm_head = if let Some(tensor) = params.get("lm_head.weight") {
75 let expected = config.hidden_size * config.vocab_size;
76 if tensor.len() != expected {
77 eprintln!(
78 "[PMAT-329] lm_head.weight: shape mismatch — got {} elements, expected {expected} ({hidden}x{vocab})",
79 tensor.len(),
80 hidden = config.hidden_size,
81 vocab = config.vocab_size,
82 );
83 return None;
84 }
85 Some(tensor.clone())
86 } else {
87 None
88 };
89
90 Some(Self { config: config.clone(), embed_tokens, layers, norm, lm_head })
91 }
92
93 pub fn from_safetensors(
113 model_path: impl AsRef<Path>,
114 config: &TransformerConfig,
115 ) -> Result<Self> {
116 let model_path = model_path.as_ref();
117
118 let weights = load_safetensors_weights(model_path, Architecture::Auto)?;
120
121 validate_weights(&weights, config.num_hidden_layers)?;
123
124 Self::validate_weight_shapes(&weights, config)?;
126
127 Self::validate_weight_values(&weights)?;
129
130 Self::from_params(config, &weights).ok_or_else(|| {
132 Error::ConfigError(
133 "Failed to construct Transformer from loaded weights \
134 (internal from_params returned None after validation passed)"
135 .into(),
136 )
137 })
138 }
139
140 pub fn from_apr(apr_path: impl AsRef<Path>, config: &TransformerConfig) -> Result<Self> {
154 use aprender::serialization::apr::AprReader;
155
156 let apr_path = apr_path.as_ref();
157 let reader = AprReader::open(apr_path).map_err(|e| {
158 Error::ConfigError(format!("Failed to open APR file '{}': {e}", apr_path.display()))
159 })?;
160
161 let is_gguf_names = reader.tensors.iter().any(|t| t.name == "token_embd.weight");
163 if is_gguf_names {
164 eprintln!(
165 "[PMAT-489] Detected GGUF tensor names in APR file, mapping to HF convention"
166 );
167 }
168 let mut weights = HashMap::new();
169 for desc in &reader.tensors {
170 let data = reader.read_tensor_as_f32(&desc.name).map_err(|e| {
171 Error::ConfigError(format!("Failed to read tensor '{}': {e}", desc.name))
172 })?;
173 let mapped_name = if is_gguf_names {
174 super::weights::mapping::map_weight_name(
175 &desc.name,
176 super::weights::Architecture::Gguf,
177 )
178 } else {
179 desc.name.clone()
180 };
181 weights.insert(mapped_name, Tensor::from_vec(data, false));
182 }
183
184 validate_weights(&weights, config.num_hidden_layers)?;
186 Self::validate_weight_shapes(&weights, config)?;
187 Self::validate_weight_values(&weights)?;
188
189 Self::from_params(config, &weights).ok_or_else(|| {
190 Error::ConfigError(
191 "Failed to construct Transformer from APR weights \
192 (from_params returned None after validation passed)"
193 .into(),
194 )
195 })
196 }
197
198 fn validate_weight_shapes(
200 weights: &HashMap<String, Tensor>,
201 config: &TransformerConfig,
202 ) -> Result<()> {
203 let hidden = config.hidden_size;
204 let q_dim = config.q_dim();
205 let kv_hidden = config.num_kv_heads * config.head_dim();
206 let intermediate = config.intermediate_size;
207 let vocab = config.vocab_size;
208
209 let check = |name: &str, expected: usize| -> Result<()> {
211 if let Some(tensor) = weights.get(name) {
212 if tensor.len() != expected {
213 return Err(Error::ConfigError(format!(
214 "Shape mismatch for '{name}': expected {expected} elements, got {}",
215 tensor.len()
216 )));
217 }
218 }
219 Ok(())
221 };
222
223 check("model.embed_tokens.weight", vocab * hidden)?;
225 check("model.norm.weight", hidden)?;
226
227 if weights.contains_key("lm_head.weight") {
229 check("lm_head.weight", vocab * hidden)?;
230 }
231
232 for i in 0..config.num_hidden_layers {
234 let p = format!("model.layers.{i}");
235
236 check(&format!("{p}.input_layernorm.weight"), hidden)?;
238 check(&format!("{p}.post_attention_layernorm.weight"), hidden)?;
239
240 check(&format!("{p}.self_attn.q_proj.weight"), q_dim * hidden)?;
242 check(&format!("{p}.self_attn.k_proj.weight"), kv_hidden * hidden)?;
243 check(&format!("{p}.self_attn.v_proj.weight"), kv_hidden * hidden)?;
244 check(&format!("{p}.self_attn.o_proj.weight"), hidden * q_dim)?;
245
246 check(&format!("{p}.self_attn.q_proj.bias"), q_dim)?;
248 check(&format!("{p}.self_attn.k_proj.bias"), kv_hidden)?;
249 check(&format!("{p}.self_attn.v_proj.bias"), kv_hidden)?;
250
251 check(&format!("{p}.mlp.gate_proj.weight"), hidden * intermediate)?;
253 check(&format!("{p}.mlp.up_proj.weight"), hidden * intermediate)?;
254 check(&format!("{p}.mlp.down_proj.weight"), intermediate * hidden)?;
255 }
256
257 Ok(())
258 }
259
260 fn validate_weight_values(weights: &HashMap<String, Tensor>) -> Result<()> {
262 for (name, tensor) in weights {
263 let data = tensor.data();
264 for (i, &val) in data.iter().enumerate() {
265 if val.is_nan() {
266 return Err(Error::ConfigError(format!(
267 "NaN detected in weight '{name}' at index {i}"
268 )));
269 }
270 if val.is_infinite() {
271 return Err(Error::ConfigError(format!(
272 "Inf detected in weight '{name}' at index {i}"
273 )));
274 }
275 }
276 }
277 Ok(())
278 }
279
280 #[requires(!token_ids.is_empty())]
288 #[ensures(ret.len() == token_ids.len() * self.config.vocab_size)]
289 pub fn forward(&self, token_ids: &[u32]) -> Tensor {
290 contract_pre_embedding_lookup!(token_ids);
291 let seq_len = token_ids.len();
292 let hidden_size = self.config.hidden_size;
293
294 let mut hidden = self.embed_tokens.forward(token_ids);
296
297 for layer in &self.layers {
299 hidden = layer.forward(&hidden, seq_len);
300 }
301
302 let normalized = self.norm.forward_batched(&hidden, seq_len, hidden_size);
304
305 let lm_weight = self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight);
307
308 let result =
310 matmul_nt(&normalized, lm_weight, seq_len, hidden_size, self.config.vocab_size);
311 contract_post_embedding_lookup!(result.data().as_slice().unwrap_or(&[]));
312 result
313 }
314
315 #[requires(!token_ids.is_empty())]
323 #[ensures(ret.len() == token_ids.len() * self.config.hidden_size)]
324 pub fn forward_hidden(&self, token_ids: &[u32]) -> Tensor {
325 contract_pre_embedding_lookup!(token_ids);
326 let seq_len = token_ids.len();
327 let hidden_size = self.config.hidden_size;
328
329 let mut hidden = self.embed_tokens.forward(token_ids);
331
332 for layer in &self.layers {
334 hidden = layer.forward(&hidden, seq_len);
335 }
336
337 let result = self.norm.forward_batched(&hidden, seq_len, hidden_size);
339 contract_post_embedding_lookup!(result.data().as_slice().unwrap_or(&[]));
340 result
341 }
342
343 pub fn forward_hidden_with_lora(
356 &self,
357 token_ids: &[u32],
358 lora_layers: &[crate::lora::LoRALayer],
359 ) -> Tensor {
360 contract_pre_embedding_lookup!(token_ids);
361 let seq_len = token_ids.len();
362 let hidden_size = self.config.hidden_size;
363
364 let mut hidden = self.embed_tokens.forward(token_ids);
365
366 for (layer_idx, layer) in self.layers.iter().enumerate() {
367 let norm1 = layer.input_norm.forward_batched(&hidden, seq_len, hidden_size);
368
369 let q_idx = layer_idx * 2;
371 let v_idx = layer_idx * 2 + 1;
372 let attn_out = if v_idx < lora_layers.len() {
373 layer.self_attn.forward_with_lora(
374 &norm1,
375 seq_len,
376 lora_layers[q_idx].lora_a(),
377 lora_layers[q_idx].lora_b(),
378 lora_layers[v_idx].lora_a(),
379 lora_layers[v_idx].lora_b(),
380 lora_layers[q_idx].rank(),
381 lora_layers[q_idx].scale(),
382 )
383 } else {
384 layer.self_attn.forward(&norm1, seq_len)
385 };
386
387 let residual = crate::autograd::add(&hidden, &attn_out);
388 let norm2 = layer.post_attn_norm.forward_batched(&residual, seq_len, hidden_size);
389 let ffn_out = layer.ffn.forward(&norm2, seq_len);
390 hidden = crate::autograd::add(&residual, &ffn_out);
391 }
392
393 let result = self.norm.forward_batched(&hidden, seq_len, hidden_size);
394 contract_post_embedding_lookup!(result.data().as_slice().unwrap_or(&[]));
395 result
396 }
397
398 pub fn forward_with_lora(
407 &self,
408 token_ids: &[u32],
409 lora_layers: &[crate::lora::LoRALayer],
410 ) -> Tensor {
411 contract_pre_embedding_lookup!(token_ids);
412 let seq_len = token_ids.len();
413 let hidden_size = self.config.hidden_size;
414
415 let hidden = self.forward_hidden_with_lora(token_ids, lora_layers);
416 let lm_weight = self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight);
417 let result = matmul_nt(&hidden, lm_weight, seq_len, hidden_size, self.config.vocab_size);
418 contract_post_embedding_lookup!(result.data().as_slice().unwrap_or(&[]));
419 result
420 }
421
422 pub fn forward_last(&self, token_ids: &[u32]) -> Tensor {
424 contract_pre_embedding_lookup!(token_ids);
425 let logits = self.forward(token_ids);
426 let seq_len = token_ids.len();
427 let vocab_size = self.config.vocab_size;
428
429 let start = (seq_len - 1) * vocab_size;
431 let end = start + vocab_size;
432 let last_logits: Vec<f32> =
433 logits.data().as_slice().expect("logits must be contiguous")[start..end].to_vec();
434
435 let result = Tensor::from_vec(last_logits, logits.requires_grad());
436 contract_post_embedding_lookup!(result.data().as_slice().unwrap_or(&[]));
437 result
438 }
439
440 pub fn parameters(&self) -> Vec<&Tensor> {
442 let mut params = vec![&self.embed_tokens.weight, &self.norm.weight];
443 for layer in &self.layers {
444 params.extend(layer.parameters());
445 }
446 if let Some(lm_head) = &self.lm_head {
447 params.push(lm_head);
448 }
449 params
450 }
451
452 pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
454 let mut params: Vec<&mut Tensor> = Vec::new();
455 params.push(&mut self.embed_tokens.weight);
456 params.push(&mut self.norm.weight);
457 for layer in &mut self.layers {
458 params.extend(layer.parameters_mut());
459 }
460 if let Some(lm_head) = &mut self.lm_head {
461 params.push(lm_head);
462 }
463 params
464 }
465
466 pub fn config(&self) -> &TransformerConfig {
468 &self.config
469 }
470
471 pub fn embed_token(&self, token_id: u32) -> Vec<f32> {
473 let w = self.embed_tokens.weight.data();
474 let data = w.as_slice().expect("contiguous embedding");
475 let h = self.config.hidden_size;
476 let offset = (token_id as usize) * h;
477 data[offset..offset + h].to_vec()
478 }
479
480 pub fn output_norm_weight_slice(&self) -> &[f32] {
482 self.norm.weight.data().as_slice().expect("contiguous norm weight")
483 }
484
485 pub fn lm_head_weight_slice(&self) -> &[f32] {
487 let w = self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight);
488 w.data().as_slice().expect("contiguous lm_head")
489 }
490
491 pub fn lm_head_weight(&self) -> &Tensor {
496 self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight)
497 }
498
499 pub fn named_parameters(&self) -> Vec<(String, &Tensor)> {
505 let mut params = vec![
506 ("model.embed_tokens.weight".to_string(), &self.embed_tokens.weight),
507 ("model.norm.weight".to_string(), &self.norm.weight),
508 ];
509 for layer in &self.layers {
510 params.extend(layer.named_parameters());
511 }
512 if let Some(ref lm_head) = self.lm_head {
513 params.push(("lm_head.weight".to_string(), lm_head));
514 }
515 params
516 }
517
518 pub fn set_named_parameter(&mut self, name: &str, value: Tensor) -> bool {
522 if name == "model.embed_tokens.weight" {
523 self.embed_tokens.weight = value;
524 return true;
525 }
526 if name == "model.norm.weight" {
527 self.norm.weight = value;
528 return true;
529 }
530 if name == "lm_head.weight" {
531 self.lm_head = Some(value);
532 return true;
533 }
534 if let Some(rest) = name.strip_prefix("model.layers.") {
536 if let Some(dot_pos) = rest.find('.') {
537 if let Ok(idx) = rest[..dot_pos].parse::<usize>() {
538 if idx < self.layers.len() {
539 let suffix = &rest[dot_pos + 1..];
540 return self.layers[idx].set_named_parameter(suffix, value);
541 }
542 }
543 }
544 }
545 false
546 }
547}
548
549#[cfg(test)]
550mod tests {
551 use super::*;
552
#[test]
fn test_transformer_tiny_forward() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    // Three tokens in, one vocab-sized row of logits out per token.
    let ids: [u32; 3] = [1, 2, 3];
    let out = model.forward(&ids);
    assert_eq!(out.len(), 3 * cfg.vocab_size);
}
561
#[test]
fn falsify_apr_pretrain_arch_004_gqa_7_1_forward_pass_smoke() {
    // Grouped-query attention with 14 query heads sharing 2 KV heads.
    let config = TransformerConfig {
        hidden_size: 112,
        num_attention_heads: 14,
        num_kv_heads: 2,
        intermediate_size: 64,
        num_hidden_layers: 1,
        vocab_size: 256,
        max_position_embeddings: 512,
        rms_norm_eps: 1e-6,
        rope_theta: 1_000_000.0,
        use_bias: true,
        head_dim_override: None,
        architecture: crate::transformer::config::ModelArchitecture::Decoder,
        hf_architecture: None,
        hf_model_type: None,
        tie_word_embeddings: true,
    };

    let group_ratio = config.num_attention_heads / config.num_kv_heads;
    assert_eq!(group_ratio, 7, "GQA-7:1 ratio must be 14/2=7 (Qwen2.5-0.5B canonical)");

    let model = Transformer::new(&config);
    let prompt = [1u32, 2, 3, 4];
    let logits = model.forward(&prompt);

    assert_eq!(
        logits.len(),
        prompt.len() * config.vocab_size,
        "GQA-7:1 forward must return seq_len * vocab_size logits"
    );

    let all_finite = logits.data().iter().all(|&v| v.is_finite());
    assert!(
        all_finite,
        "GQA-7:1 forward must produce all-finite logits — silent NaN \
         would corrupt the §49 fine-tune trajectory before FALSIFY-006 \
         (init_loss < 6.0) could measure it"
    );
}
633
#[test]
fn test_transformer_tiny_forward_last() {
    let cfg = TransformerConfig::tiny();
    // Only the final position's logits come back.
    let last = Transformer::new(&cfg).forward_last(&[1u32, 2, 3]);
    assert_eq!(last.len(), cfg.vocab_size);
}
642
#[test]
fn test_transformer_parameters() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    // tiny() with a tied head (lm_head is None after new()) exposes 20 tensors.
    let count = model.parameters().len();
    assert_eq!(count, 20);
}
652
#[test]
fn test_transformer_config_accessor() {
    let expected = TransformerConfig::tiny();
    let model = Transformer::new(&expected);
    let actual = model.config();
    assert_eq!(actual.hidden_size, expected.hidden_size);
    assert_eq!(actual.vocab_size, expected.vocab_size);
}
660
#[test]
fn test_transformer_single_token() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    // A one-token sequence yields exactly one row of logits.
    let logits = model.forward(&[42u32]);
    assert_eq!(logits.len(), cfg.vocab_size);
}
669
#[test]
fn test_output_finite_values() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    let ids: Vec<u32> = (1..=5).collect();
    let logits = model.forward(&ids);
    // No NaN or ±Inf anywhere in the output.
    assert!(!logits.data().iter().any(|v| !v.is_finite()));
}
679
#[test]
fn test_transformer_empty_lm_head_uses_tied_weights() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    // A freshly built model has no separate head...
    assert!(model.lm_head.is_none());

    // ...yet still projects onto the full vocabulary through the embedding.
    let ids = [1u32, 2];
    assert_eq!(model.forward(&ids).len(), ids.len() * cfg.vocab_size);
}
691
#[test]
fn test_from_params_returns_none_on_missing() {
    let config = TransformerConfig::tiny();
    // An empty map lacks every required tensor, so construction must fail.
    let params: HashMap<String, Tensor> = HashMap::new();
    let result = Transformer::from_params(&config, &params);
    assert!(result.is_none());
}
699
#[test]
fn test_transformer_from_params_with_lm_head() {
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let vocab = config.vocab_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let inter = config.intermediate_size;

    // Complete tiny checkpoint: norm weights at 1.0, every projection at 0.1.
    let mut params: HashMap<String, Tensor> = HashMap::new();
    let mut put = |name: String, len: usize, value: f32| {
        params.insert(name, Tensor::from_vec(vec![value; len], true));
    };

    put("model.embed_tokens.weight".to_string(), vocab * hidden, 0.1);
    for layer_idx in 0..config.num_hidden_layers {
        let layer_tensors: [(&str, usize, f32); 9] = [
            ("input_layernorm.weight", hidden, 1.0),
            ("self_attn.q_proj.weight", hidden * hidden, 0.1),
            ("self_attn.k_proj.weight", hidden * kv_hidden, 0.1),
            ("self_attn.v_proj.weight", hidden * kv_hidden, 0.1),
            ("self_attn.o_proj.weight", hidden * hidden, 0.1),
            ("post_attention_layernorm.weight", hidden, 1.0),
            ("mlp.gate_proj.weight", hidden * inter, 0.1),
            ("mlp.up_proj.weight", hidden * inter, 0.1),
            ("mlp.down_proj.weight", inter * hidden, 0.1),
        ];
        for (suffix, len, value) in layer_tensors {
            put(format!("model.layers.{layer_idx}.{suffix}"), len, value);
        }
    }
    put("model.norm.weight".to_string(), hidden, 1.0);
    // Correctly sized untied head.
    put("lm_head.weight".to_string(), hidden * vocab, 0.1);

    let transformer = Transformer::from_params(&config, &params);
    assert!(transformer.is_some());
    let transformer = transformer.expect("operation should succeed");
    assert!(transformer.lm_head.is_some());
    assert_eq!(transformer.layers.len(), config.num_hidden_layers);
}
775
#[test]
fn test_transformer_from_params_without_lm_head() {
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let vocab = config.vocab_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let inter = config.intermediate_size;

    // Complete tiny checkpoint minus `lm_head.weight`.
    let mut params: HashMap<String, Tensor> = HashMap::new();
    let mut put = |name: String, len: usize, value: f32| {
        params.insert(name, Tensor::from_vec(vec![value; len], true));
    };

    put("model.embed_tokens.weight".to_string(), vocab * hidden, 0.1);
    for layer_idx in 0..config.num_hidden_layers {
        let layer_tensors: [(&str, usize, f32); 9] = [
            ("input_layernorm.weight", hidden, 1.0),
            ("self_attn.q_proj.weight", hidden * hidden, 0.1),
            ("self_attn.k_proj.weight", hidden * kv_hidden, 0.1),
            ("self_attn.v_proj.weight", hidden * kv_hidden, 0.1),
            ("self_attn.o_proj.weight", hidden * hidden, 0.1),
            ("post_attention_layernorm.weight", hidden, 1.0),
            ("mlp.gate_proj.weight", hidden * inter, 0.1),
            ("mlp.up_proj.weight", hidden * inter, 0.1),
            ("mlp.down_proj.weight", inter * hidden, 0.1),
        ];
        for (suffix, len, value) in layer_tensors {
            put(format!("model.layers.{layer_idx}.{suffix}"), len, value);
        }
    }
    put("model.norm.weight".to_string(), hidden, 1.0);

    let transformer = Transformer::from_params(&config, &params);
    assert!(transformer.is_some());
    let transformer = transformer.expect("operation should succeed");
    // No head in the checkpoint -> tied embeddings.
    assert!(transformer.lm_head.is_none());
}
844
#[test]
fn test_transformer_parameters_with_lm_head() {
    let cfg = TransformerConfig::tiny();
    let mut model = Transformer::new(&cfg);

    let head = Tensor::from_vec(vec![0.1; cfg.hidden_size * cfg.vocab_size], true);
    model.lm_head = Some(head);

    // The 20 tied-model tensors plus the untied head.
    assert_eq!(model.parameters().len(), 21);
}
859
#[test]
fn test_transformer_forward_with_lm_head() {
    let cfg = TransformerConfig::tiny();
    let mut model = Transformer::new(&cfg);
    model.lm_head = Some(Tensor::from_vec(vec![0.1; cfg.hidden_size * cfg.vocab_size], true));

    let ids = [1u32, 2, 3];
    let logits = model.forward(&ids);

    assert_eq!(logits.len(), ids.len() * cfg.vocab_size);
    assert!(logits.data().iter().copied().all(f32::is_finite));
}
874
#[test]
fn falsify_l1e_from_params_rejects_wrong_shape_lm_head() {
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let vocab = config.vocab_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let inter = config.intermediate_size;

    // Otherwise-valid checkpoint whose lm_head has only 50 elements.
    let mut params: HashMap<String, Tensor> = HashMap::new();
    let mut put = |name: String, len: usize, value: f32| {
        params.insert(name, Tensor::from_vec(vec![value; len], true));
    };

    put("model.embed_tokens.weight".to_string(), vocab * hidden, 0.1);
    for layer_idx in 0..config.num_hidden_layers {
        let layer_tensors: [(&str, usize, f32); 9] = [
            ("input_layernorm.weight", hidden, 1.0),
            ("self_attn.q_proj.weight", hidden * hidden, 0.1),
            ("self_attn.k_proj.weight", hidden * kv_hidden, 0.1),
            ("self_attn.v_proj.weight", hidden * kv_hidden, 0.1),
            ("self_attn.o_proj.weight", hidden * hidden, 0.1),
            ("post_attention_layernorm.weight", hidden, 1.0),
            ("mlp.gate_proj.weight", hidden * inter, 0.1),
            ("mlp.up_proj.weight", hidden * inter, 0.1),
            ("mlp.down_proj.weight", inter * hidden, 0.1),
        ];
        for (suffix, len, value) in layer_tensors {
            put(format!("model.layers.{layer_idx}.{suffix}"), len, value);
        }
    }
    put("model.norm.weight".to_string(), hidden, 1.0);
    put("lm_head.weight".to_string(), 50, 0.1);

    let transformer = Transformer::from_params(&config, &params);
    assert!(
        transformer.is_none(),
        "FALSIFY-L1e: PMAT-329 fix — from_params MUST reject wrong-shape lm_head"
    );
}
966
#[test]
fn falsify_l2e_tied_embeddings_produce_correct_logit_dims() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    assert!(model.lm_head.is_none(), "Default should use tied embeddings");

    let ids = [1u32, 2, 3];
    let logits = model.forward(&ids);
    assert_eq!(
        logits.len(),
        ids.len() * cfg.vocab_size,
        "FALSIFY-L2e: Tied embedding logits must be seq_len * vocab_size"
    );

    // Tally NaN and Inf in a single pass.
    let values = logits.data();
    let (nan_count, inf_count) = values.iter().fold((0usize, 0usize), |(nans, infs), v| {
        (nans + usize::from(v.is_nan()), infs + usize::from(v.is_infinite()))
    });
    assert_eq!(nan_count, 0, "FALSIFY-L2e: Tied logits must not contain NaN");
    assert_eq!(inf_count, 0, "FALSIFY-L2e: Tied logits must not contain Inf");
}
992
#[test]
fn falsify_l3e_separate_lm_head_produces_correct_logit_dims() {
    let cfg = TransformerConfig::tiny();
    let mut model = Transformer::new(&cfg);
    let untied_head = Tensor::from_vec(vec![0.1; cfg.hidden_size * cfg.vocab_size], true);
    model.lm_head = Some(untied_head);

    let ids = [1u32, 2, 3];
    let logits = model.forward(&ids);
    assert_eq!(
        logits.len(),
        ids.len() * cfg.vocab_size,
        "FALSIFY-L3e: Separate lm_head logits must be seq_len * vocab_size"
    );

    let non_finite = logits.data().iter().filter(|v| !v.is_finite()).count();
    assert!(non_finite == 0, "FALSIFY-L3e: Separate lm_head logits must all be finite");
}
1014
#[test]
fn falsify_l4e_lm_head_in_parameter_list() {
    let cfg = TransformerConfig::tiny();
    let mut model = Transformer::new(&cfg);
    let tied_count = model.parameters().len();

    model.lm_head = Some(Tensor::from_vec(vec![0.1; cfg.hidden_size * cfg.vocab_size], true));
    let untied_count = model.parameters().len();
    assert_eq!(
        untied_count,
        tied_count + 1,
        "FALSIFY-L4e: lm_head must be included in parameters() — optimizer needs it"
    );

    // The mutable view must expose the same tensors for the update step.
    assert_eq!(
        model.parameters_mut().len(),
        untied_count,
        "FALSIFY-L4e: parameters_mut() must include lm_head for gradient updates"
    );
}
1044
#[test]
fn falsify_l5e_forward_last_correct_size() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);

    let ids: Vec<u32> = (1..=5).collect();
    let last = model.forward_last(&ids);

    assert_eq!(
        last.len(),
        cfg.vocab_size,
        "FALSIFY-L5e: forward_last must return exactly vocab_size logits"
    );
    assert!(
        last.data().iter().all(|v| v.is_finite()),
        "FALSIFY-L5e: forward_last logits must all be finite"
    );
}
1067
#[test]
fn test_causal_lm_loss_backward() {
    use crate::train::{CausalLMLoss, LossFn};

    let vocab_size = 100;
    let seq_len = 3;
    let criterion = CausalLMLoss::new(vocab_size);

    // Smooth, deterministic logits so the loss is non-degenerate.
    let raw: Vec<f32> = (0..seq_len * vocab_size).map(|i| (i as f32 * 0.01).sin()).collect();
    let logits = Tensor::from_vec(raw, true);
    // One target token id per position, stored as f32.
    let targets = Tensor::from_vec(vec![5.0, 10.0, 15.0], false);

    let mut loss = criterion.forward(&logits, &targets);
    crate::autograd::backward(&mut loss, None);

    let loss_value = loss.data()[0];
    assert!(loss_value > 0.0);
    assert!(loss_value.is_finite());

    let grad = logits.grad();
    assert!(grad.is_some());
    let grad = grad.expect("gradient should be available");
    assert!(grad.iter().all(|&v| v.is_finite()));
}
1100
#[test]
fn falsify_emb_003_tied_weight_sharing() {
    let config = TransformerConfig::tiny();
    let transformer = Transformer::new(&config);

    assert!(transformer.lm_head.is_none());

    // Exercise the production accessors. Re-deriving the
    // `lm_head.unwrap_or(embed)` fallback here would only test itself.
    let embed_weight = &transformer.embed_tokens.weight;
    assert!(
        std::ptr::eq(transformer.lm_head_weight(), embed_weight),
        "FALSIFIED EMB-003: tied lm_head must be same object as embed_tokens.weight"
    );

    // The flat-slice accessor must alias the same storage as well.
    let embed_slice = embed_weight.data().as_slice().expect("contiguous embedding");
    assert!(
        std::ptr::eq(transformer.lm_head_weight_slice(), embed_slice),
        "FALSIFIED EMB-003: tied lm_head_weight_slice must alias embed_tokens.weight storage"
    );
}
1139
#[test]
fn falsify_te_001_output_shape() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);

    for seq_len in [1u32, 3, 10] {
        let ids: Vec<u32> = (0..seq_len).collect();
        let expected = ids.len() * cfg.vocab_size;
        assert_eq!(
            model.forward(&ids).len(),
            expected,
            "FALSIFIED TE-001: output shape for seq_len={seq_len}"
        );
    }
}
1156
#[test]
fn falsify_te_002_tied_equivalence() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    let ids = [0u32, 3, 7, 15, 42];

    // Path A: the model's own (tied) output projection.
    let tied_logits = model.forward(&ids);

    // Path B: hidden states projected by hand through a copy of the embedding.
    let hidden = model.forward_hidden(&ids);
    let embedding_copy = model.embed_tokens.weight.clone();
    let explicit_logits =
        matmul_nt(&hidden, &embedding_copy, ids.len(), cfg.hidden_size, cfg.vocab_size);

    let (tied, explicit) = (tied_logits.data(), explicit_logits.data());
    assert_eq!(
        tied.len(),
        explicit.len(),
        "FALSIFIED TE-002: output lengths differ: {} vs {}",
        tied.len(),
        explicit.len()
    );

    for (i, (t, e)) in tied.iter().copied().zip(explicit.iter().copied()).enumerate() {
        assert!((t - e).abs() < 1e-6, "FALSIFIED TE-002: tied[{i}] = {t} != explicit[{i}] = {e}");
    }
}
1196
#[test]
fn falsify_te_003_no_extra_params() {
    let cfg = TransformerConfig::tiny();

    // Parameter count of a fresh model with the given head setting.
    let count_params = |head: Option<Tensor>| {
        let mut model = Transformer::new(&cfg);
        model.lm_head = head;
        model.parameters().len()
    };

    let tied_count = count_params(None);
    let untied_count = count_params(Some(Tensor::from_vec(
        vec![0.1; cfg.hidden_size * cfg.vocab_size],
        true,
    )));

    assert_eq!(
        untied_count,
        tied_count + 1,
        "FALSIFIED TE-003: tied model must have exactly 1 fewer param than untied"
    );
}
1217
#[test]
fn falsify_te_004_finite_output() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    let logits = model.forward(&[0u32, 5, 10, 50, 99]);
    let data = logits.data();

    // NaN and Inf are mutually exclusive, so one branch per value suffices.
    let mut nan_count = 0usize;
    let mut inf_count = 0usize;
    for v in data.iter() {
        if v.is_nan() {
            nan_count += 1;
        } else if v.is_infinite() {
            inf_count += 1;
        }
    }

    assert_eq!(
        nan_count, 0,
        "FALSIFIED TE-004: tied embedding output contains {nan_count} NaN values"
    );
    assert_eq!(
        inf_count, 0,
        "FALSIFIED TE-004: tied embedding output contains {inf_count} Inf values"
    );
}
1239
// Property-based versions of the TE-001 (output shape), TE-002 (tied
// equivalence) and TE-004 (finite output) checks, run over generated
// sequence lengths and token-id vectors.
mod te_proptest_falsify {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]
        // TE-001-prop: for any length in 1..32, forward returns
        // seq_len * vocab_size logits. Ids are wrapped by vocab_size so they
        // always stay in range.
        #[test]
        fn falsify_te_001_prop_output_shape(
            seq_len in 1_usize..32,
        ) {
            let config = TransformerConfig::tiny();
            let transformer = Transformer::new(&config);
            let tokens: Vec<u32> = (0..seq_len).map(|i| (i % config.vocab_size) as u32).collect();
            let logits = transformer.forward(&tokens);
            prop_assert_eq!(
                logits.len(),
                seq_len * config.vocab_size,
                "FALSIFIED TE-001-prop: seq_len={}, got len={}", seq_len, logits.len()
            );
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        // TE-002-prop: the tied forward pass must match, within 1e-5, an
        // explicit matmul_nt of forward_hidden against a clone of the
        // embedding table.
        // NOTE(review): ids are drawn from 0..999 without wrapping — this
        // assumes tiny().vocab_size >= 999; confirm.
        #[test]
        fn falsify_te_002_prop_tied_equivalence(
            token_ids in proptest::collection::vec(0_u32..999, 1..8),
        ) {
            let config = TransformerConfig::tiny();
            let transformer = Transformer::new(&config);

            let tied_logits = transformer.forward(&token_ids);
            let hidden = transformer.forward_hidden(&token_ids);
            let w_clone = transformer.embed_tokens.weight.clone();
            let explicit_logits = matmul_nt(
                &hidden, &w_clone,
                token_ids.len(), config.hidden_size, config.vocab_size,
            );

            let tied_data = tied_logits.data();
            let explicit_data = explicit_logits.data();
            prop_assert_eq!(tied_data.len(), explicit_data.len());

            // prop_assert! returns early from this generated fn, so it must
            // stay in a plain loop rather than move into an iterator closure.
            for (i, (&t, &e)) in tied_data.iter().zip(explicit_data.iter()).enumerate() {
                prop_assert!(
                    (t - e).abs() < 1e-5,
                    "FALSIFIED TE-002-prop: tied[{}]={} != explicit[{}]={}",
                    i, t, i, e
                );
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(30))]
        // TE-004-prop: every logit is finite for generated token sequences of
        // length 1..16 (same 0..999 id-range assumption as TE-002-prop).
        #[test]
        fn falsify_te_004_prop_finite(
            token_ids in proptest::collection::vec(0_u32..999, 1..16),
        ) {
            let config = TransformerConfig::tiny();
            let transformer = Transformer::new(&config);
            let logits = transformer.forward(&token_ids);
            let data = logits.data();

            for (i, &v) in data.iter().enumerate() {
                prop_assert!(
                    v.is_finite(),
                    "FALSIFIED TE-004-prop: logits[{}]={} non-finite (n_tokens={})",
                    i, v, token_ids.len()
                );
            }
        }
    }
}
1334
/// PIPE-001: end-to-end check of the tied-embedding forward pass followed by softmax.
///
/// Re-verifies TE-001 (logit count), TE-004 (finiteness) and, per row, the softmax
/// contracts SM-001 (probabilities sum to 1), SM-002 (non-negative) and SM-003
/// (softmax preserves the argmax of the logits).
#[test]
fn falsify_pipe_001_embed_tied_softmax_pipeline() {
    /// Index of the largest element; on ties the last one wins, as with `Iterator::max_by`.
    fn argmax(values: &[f32]) -> usize {
        values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
            .expect("operation should succeed")
            .0
    }

    let config = TransformerConfig::tiny();
    let vocab_size = config.vocab_size;
    let tokens = vec![0u32, 3, 7, 15, 42];
    let seq_len = tokens.len();

    let transformer = Transformer::new(&config);
    let logits = transformer.forward(&tokens);
    let logits_data = logits.data();

    assert_eq!(
        logits_data.len(),
        seq_len * vocab_size,
        "FALSIFIED PIPE-001/TE-001: logits len={} != seq_len({seq_len}) * vocab({vocab_size})",
        logits_data.len()
    );

    for (i, &l) in logits_data.iter().enumerate() {
        assert!(l.is_finite(), "FALSIFIED PIPE-001/TE-004: logits[{i}] = {l} not finite");
    }

    let logits_slice = logits_data.as_slice().expect("operation should succeed");
    // The length assertion above guarantees exactly `seq_len` rows of `vocab_size`.
    for (row, row_logits) in logits_slice.chunks_exact(vocab_size).enumerate() {
        // Numerically stable softmax: shift by the row maximum before exponentiating.
        let max_val = row_logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row_logits.iter().map(|&x| (x - max_val).exp()).collect();
        let denom: f32 = exps.iter().sum();
        let probs: Vec<f32> = exps.into_iter().map(|e| e / denom).collect();

        let prob_sum: f32 = probs.iter().sum();
        assert!(
            (prob_sum - 1.0).abs() < 1e-4,
            "FALSIFIED PIPE-001/SM-001: row {row} prob sum={prob_sum}"
        );

        for (i, &p) in probs.iter().enumerate() {
            assert!(p >= 0.0, "FALSIFIED PIPE-001/SM-002: row {row} prob[{i}]={p} negative");
        }

        let logit_argmax = argmax(row_logits);
        let prob_argmax = argmax(&probs);
        assert_eq!(
            logit_argmax, prob_argmax,
            "FALSIFIED PIPE-001/SM-003: row {row} argmax changed {logit_argmax} → {prob_argmax}"
        );
    }
}
1425
1426 mod safetensors_tests {
1435 use super::*;
1436 use safetensors::serialize;
1437 use safetensors::tensor::{Dtype, TensorView};
1438 use tempfile::TempDir;
1439
1440 fn create_tiny_safetensors(dir: &std::path::Path) -> std::path::PathBuf {
1443 let config = TransformerConfig::tiny();
1444 let hidden = config.hidden_size;
1445 let kv_hidden = config.num_kv_heads * config.head_dim();
1446 let intermediate = config.intermediate_size;
1447 let vocab = config.vocab_size;
1448
1449 let mut tensors_data: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
1450
1451 let make_f32 = |n: usize, val: f32| -> Vec<u8> {
1453 std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
1454 };
1455
1456 tensors_data.push((
1458 "model.embed_tokens.weight".to_string(),
1459 make_f32(vocab * hidden, 0.01),
1460 vec![vocab, hidden],
1461 ));
1462
1463 tensors_data.push((
1465 "model.norm.weight".to_string(),
1466 make_f32(hidden, 1.0),
1467 vec![hidden],
1468 ));
1469
1470 for i in 0..config.num_hidden_layers {
1472 let p = format!("model.layers.{i}");
1473
1474 tensors_data.push((
1476 format!("{p}.input_layernorm.weight"),
1477 make_f32(hidden, 1.0),
1478 vec![hidden],
1479 ));
1480 tensors_data.push((
1481 format!("{p}.post_attention_layernorm.weight"),
1482 make_f32(hidden, 1.0),
1483 vec![hidden],
1484 ));
1485
1486 tensors_data.push((
1488 format!("{p}.self_attn.q_proj.weight"),
1489 make_f32(hidden * hidden, 0.01),
1490 vec![hidden, hidden],
1491 ));
1492 tensors_data.push((
1493 format!("{p}.self_attn.k_proj.weight"),
1494 make_f32(hidden * kv_hidden, 0.01),
1495 vec![kv_hidden, hidden],
1496 ));
1497 tensors_data.push((
1498 format!("{p}.self_attn.v_proj.weight"),
1499 make_f32(hidden * kv_hidden, 0.01),
1500 vec![kv_hidden, hidden],
1501 ));
1502 tensors_data.push((
1503 format!("{p}.self_attn.o_proj.weight"),
1504 make_f32(hidden * hidden, 0.01),
1505 vec![hidden, hidden],
1506 ));
1507
1508 tensors_data.push((
1510 format!("{p}.mlp.gate_proj.weight"),
1511 make_f32(hidden * intermediate, 0.01),
1512 vec![intermediate, hidden],
1513 ));
1514 tensors_data.push((
1515 format!("{p}.mlp.up_proj.weight"),
1516 make_f32(hidden * intermediate, 0.01),
1517 vec![intermediate, hidden],
1518 ));
1519 tensors_data.push((
1520 format!("{p}.mlp.down_proj.weight"),
1521 make_f32(intermediate * hidden, 0.01),
1522 vec![hidden, intermediate],
1523 ));
1524 }
1525
1526 let views: Vec<TensorView<'_>> = tensors_data
1528 .iter()
1529 .map(|(_, bytes, shape)| {
1530 TensorView::new(Dtype::F32, shape.clone(), bytes).expect("valid tensor view")
1531 })
1532 .collect();
1533
1534 let named_views: Vec<(&str, &TensorView<'_>)> = tensors_data
1535 .iter()
1536 .zip(views.iter())
1537 .map(|((name, _, _), view)| (name.as_str(), view))
1538 .collect();
1539
1540 let file_path = dir.join("model.safetensors");
1541 let serialized =
1542 serialize(named_views, None::<std::collections::HashMap<String, String>>)
1543 .expect("serialize safetensors");
1544 std::fs::write(&file_path, serialized).expect("write safetensors file");
1545 file_path
1546 }
1547
1548 fn create_tiny_bf16_safetensors(dir: &std::path::Path) -> std::path::PathBuf {
1550 let config = TransformerConfig::tiny();
1551 let hidden = config.hidden_size;
1552 let kv_hidden = config.num_kv_heads * config.head_dim();
1553 let intermediate = config.intermediate_size;
1554 let vocab = config.vocab_size;
1555
1556 let mut tensors_data: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
1557
1558 let make_bf16 = |n: usize, val: f32| -> Vec<u8> {
1560 std::iter::repeat_n(half::bf16::from_f32(val), n)
1561 .flat_map(half::bf16::to_le_bytes)
1562 .collect()
1563 };
1564
1565 tensors_data.push((
1567 "model.embed_tokens.weight".to_string(),
1568 make_bf16(vocab * hidden, 0.01),
1569 vec![vocab, hidden],
1570 ));
1571
1572 tensors_data.push((
1574 "model.norm.weight".to_string(),
1575 make_bf16(hidden, 1.0),
1576 vec![hidden],
1577 ));
1578
1579 for i in 0..config.num_hidden_layers {
1581 let p = format!("model.layers.{i}");
1582
1583 tensors_data.push((
1584 format!("{p}.input_layernorm.weight"),
1585 make_bf16(hidden, 1.0),
1586 vec![hidden],
1587 ));
1588 tensors_data.push((
1589 format!("{p}.post_attention_layernorm.weight"),
1590 make_bf16(hidden, 1.0),
1591 vec![hidden],
1592 ));
1593 tensors_data.push((
1594 format!("{p}.self_attn.q_proj.weight"),
1595 make_bf16(hidden * hidden, 0.01),
1596 vec![hidden, hidden],
1597 ));
1598 tensors_data.push((
1599 format!("{p}.self_attn.k_proj.weight"),
1600 make_bf16(hidden * kv_hidden, 0.01),
1601 vec![kv_hidden, hidden],
1602 ));
1603 tensors_data.push((
1604 format!("{p}.self_attn.v_proj.weight"),
1605 make_bf16(hidden * kv_hidden, 0.01),
1606 vec![kv_hidden, hidden],
1607 ));
1608 tensors_data.push((
1609 format!("{p}.self_attn.o_proj.weight"),
1610 make_bf16(hidden * hidden, 0.01),
1611 vec![hidden, hidden],
1612 ));
1613 tensors_data.push((
1614 format!("{p}.mlp.gate_proj.weight"),
1615 make_bf16(hidden * intermediate, 0.01),
1616 vec![intermediate, hidden],
1617 ));
1618 tensors_data.push((
1619 format!("{p}.mlp.up_proj.weight"),
1620 make_bf16(hidden * intermediate, 0.01),
1621 vec![intermediate, hidden],
1622 ));
1623 tensors_data.push((
1624 format!("{p}.mlp.down_proj.weight"),
1625 make_bf16(intermediate * hidden, 0.01),
1626 vec![hidden, intermediate],
1627 ));
1628 }
1629
1630 let views: Vec<TensorView<'_>> = tensors_data
1631 .iter()
1632 .map(|(_, bytes, shape)| {
1633 TensorView::new(Dtype::BF16, shape.clone(), bytes).expect("valid tensor view")
1634 })
1635 .collect();
1636
1637 let named_views: Vec<(&str, &TensorView<'_>)> = tensors_data
1638 .iter()
1639 .zip(views.iter())
1640 .map(|((name, _, _), view)| (name.as_str(), view))
1641 .collect();
1642
1643 let file_path = dir.join("model.safetensors");
1644 let serialized =
1645 serialize(named_views, None::<std::collections::HashMap<String, String>>)
1646 .expect("serialize safetensors");
1647 std::fs::write(&file_path, serialized).expect("write safetensors file");
1648 file_path
1649 }
1650
/// Loading a complete F32 tiny checkpoint from its directory succeeds and yields
/// the configured number of layers with a tied (absent) `lm_head`.
#[test]
fn test_ssc024_from_safetensors_f32_success() {
    let dir = TempDir::new().expect("create temp dir");
    create_tiny_safetensors(dir.path());
    let config = TransformerConfig::tiny();

    let transformer = match Transformer::from_safetensors(dir.path(), &config) {
        Ok(model) => model,
        Err(e) => panic!("from_safetensors should succeed: {e}"),
    };

    assert_eq!(transformer.layers.len(), config.num_hidden_layers);
    // The fixture ships no `lm_head.weight`, so the loaded model must stay tied.
    assert!(transformer.lm_head.is_none());
}
1672
/// A BF16 checkpoint loads successfully and the resulting model produces a
/// correctly sized, fully finite set of logits.
#[test]
fn test_ssc024_from_safetensors_bf16_conversion() {
    let dir = TempDir::new().expect("create temp dir");
    create_tiny_bf16_safetensors(dir.path());
    let config = TransformerConfig::tiny();

    let transformer = match Transformer::from_safetensors(dir.path(), &config) {
        Ok(model) => model,
        Err(e) => panic!("BF16 loading should succeed: {e}"),
    };
    assert_eq!(transformer.layers.len(), config.num_hidden_layers);

    // Smoke-test a short forward pass on the converted weights.
    let tokens = vec![1u32, 2, 3];
    let logits = transformer.forward(&tokens);
    assert_eq!(logits.len(), tokens.len() * config.vocab_size);
    assert!(
        !logits.data().iter().any(|v| !v.is_finite()),
        "BF16-loaded model should produce finite outputs"
    );
}
1698
/// `from_safetensors` also accepts the path of the `.safetensors` file itself,
/// not only its containing directory.
#[test]
fn test_ssc024_from_safetensors_single_file_path() {
    let dir = TempDir::new().expect("create temp dir");
    let file_path = create_tiny_safetensors(dir.path());
    let config = TransformerConfig::tiny();

    if let Err(e) = Transformer::from_safetensors(&file_path, &config) {
        panic!("Direct file path should work: {e}");
    }
}
1713
/// A model loaded from the F32 fixture produces `tokens * vocab` logits with
/// no NaN and no Inf values.
#[test]
fn test_ssc024_loaded_model_forward_produces_finite() {
    let dir = TempDir::new().expect("create temp dir");
    create_tiny_safetensors(dir.path());
    let config = TransformerConfig::tiny();
    let transformer =
        Transformer::from_safetensors(dir.path(), &config).expect("loading should succeed");

    let tokens = vec![0u32, 5, 42, 99];
    let logits = transformer.forward(&tokens);
    assert_eq!(logits.len(), tokens.len() * config.vocab_size);

    // Count NaN and +/-Inf in one pass over the logits.
    let data = logits.data();
    let (nan_count, inf_count) = data.iter().fold((0usize, 0usize), |(nans, infs), v| {
        (nans + usize::from(v.is_nan()), infs + usize::from(v.is_infinite()))
    });
    assert_eq!(nan_count, 0, "Loaded model output must not contain NaN");
    assert_eq!(inf_count, 0, "Loaded model output must not contain Inf");
}
1734
/// Loading from an empty directory fails with an error saying no SafeTensors
/// files were found.
#[test]
fn test_ssc024_from_safetensors_no_files() {
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();

    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err());
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();

    assert!(
        err_msg.contains("No SafeTensors files"),
        "Error should mention missing files: {err_msg}"
    );
}
1755
/// A checkpoint whose embedding table is a flat 42-element vector instead of
/// `[vocab, hidden]` must be rejected with a shape-related error.
#[test]
fn test_ssc024_from_safetensors_wrong_embedding_shape() {
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;

    // Little-endian F32 bytes for `n` copies of `val`.
    let fill = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
    };

    // The embedding is deliberately malformed; everything else is well-formed.
    let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        ("model.embed_tokens.weight".to_string(), fill(42, 0.01), vec![42]),
        ("model.norm.weight".to_string(), fill(hidden, 1.0), vec![hidden]),
    ];
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        let layer: [(&str, Vec<usize>, f32); 9] = [
            ("input_layernorm", vec![hidden], 1.0),
            ("post_attention_layernorm", vec![hidden], 1.0),
            ("self_attn.q_proj", vec![hidden, hidden], 0.01),
            ("self_attn.k_proj", vec![kv_hidden, hidden], 0.01),
            ("self_attn.v_proj", vec![kv_hidden, hidden], 0.01),
            ("self_attn.o_proj", vec![hidden, hidden], 0.01),
            ("mlp.gate_proj", vec![intermediate, hidden], 0.01),
            ("mlp.up_proj", vec![intermediate, hidden], 0.01),
            ("mlp.down_proj", vec![hidden, intermediate], 0.01),
        ];
        for (suffix, shape, val) in layer {
            let bytes = fill(shape.iter().product(), val);
            td.push((format!("{p}.{suffix}.weight"), bytes, shape));
        }
    }

    let views: Vec<TensorView<'_>> = td
        .iter()
        .map(|(_, bytes, shape)| TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view"))
        .collect();
    let named: Vec<(&str, &TensorView<'_>)> =
        td.iter().zip(&views).map(|((n, _, _), v)| (n.as_str(), v)).collect();

    let file_path = dir.path().join("model.safetensors");
    let serialized =
        serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
    std::fs::write(&file_path, serialized).expect("write");

    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err(), "Wrong embedding shape should fail");
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(
        err_msg.contains("Shape mismatch") || err_msg.contains("embed_tokens"),
        "Error should indicate shape issue: {err_msg}"
    );
}
1861
/// A single NaN planted in an otherwise valid embedding table must make
/// loading fail with an error that mentions NaN.
#[test]
fn test_ssc024_from_safetensors_nan_detection() {
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;

    // Little-endian F32 bytes for `n` copies of `val`.
    let fill = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
    };

    // Poison flat element 42 of the embedding table.
    let mut embed_vals = vec![0.01_f32; vocab * hidden];
    embed_vals[42] = f32::NAN;
    let embed_bytes: Vec<u8> = embed_vals.into_iter().flat_map(f32::to_le_bytes).collect();

    let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        ("model.embed_tokens.weight".to_string(), embed_bytes, vec![vocab, hidden]),
        ("model.norm.weight".to_string(), fill(hidden, 1.0), vec![hidden]),
    ];
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        let layer: [(&str, Vec<usize>, f32); 9] = [
            ("input_layernorm", vec![hidden], 1.0),
            ("post_attention_layernorm", vec![hidden], 1.0),
            ("self_attn.q_proj", vec![hidden, hidden], 0.01),
            ("self_attn.k_proj", vec![kv_hidden, hidden], 0.01),
            ("self_attn.v_proj", vec![kv_hidden, hidden], 0.01),
            ("self_attn.o_proj", vec![hidden, hidden], 0.01),
            ("mlp.gate_proj", vec![intermediate, hidden], 0.01),
            ("mlp.up_proj", vec![intermediate, hidden], 0.01),
            ("mlp.down_proj", vec![hidden, intermediate], 0.01),
        ];
        for (suffix, shape, val) in layer {
            let bytes = fill(shape.iter().product(), val);
            td.push((format!("{p}.{suffix}.weight"), bytes, shape));
        }
    }

    let views: Vec<TensorView<'_>> = td
        .iter()
        .map(|(_, bytes, shape)| TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view"))
        .collect();
    let named: Vec<(&str, &TensorView<'_>)> =
        td.iter().zip(&views).map(|((n, _, _), v)| (n.as_str(), v)).collect();

    let file_path = dir.path().join("model.safetensors");
    let serialized =
        serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
    std::fs::write(&file_path, serialized).expect("write");

    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err(), "NaN in weights should fail");
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(err_msg.contains("NaN"), "Error should mention NaN: {err_msg}");
}
1960
/// A +Inf planted in the final norm weight must make loading fail with an
/// error that mentions Inf.
#[test]
fn test_ssc024_from_safetensors_inf_detection() {
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;

    // Little-endian F32 bytes for `n` copies of `val`.
    let fill = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
    };

    // Poison the first element of `model.norm.weight`.
    let mut norm_vals = vec![1.0_f32; hidden];
    norm_vals[0] = f32::INFINITY;
    let norm_bytes: Vec<u8> = norm_vals.into_iter().flat_map(f32::to_le_bytes).collect();

    let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        (
            "model.embed_tokens.weight".to_string(),
            fill(vocab * hidden, 0.01),
            vec![vocab, hidden],
        ),
        ("model.norm.weight".to_string(), norm_bytes, vec![hidden]),
    ];
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        let layer: [(&str, Vec<usize>, f32); 9] = [
            ("input_layernorm", vec![hidden], 1.0),
            ("post_attention_layernorm", vec![hidden], 1.0),
            ("self_attn.q_proj", vec![hidden, hidden], 0.01),
            ("self_attn.k_proj", vec![kv_hidden, hidden], 0.01),
            ("self_attn.v_proj", vec![kv_hidden, hidden], 0.01),
            ("self_attn.o_proj", vec![hidden, hidden], 0.01),
            ("mlp.gate_proj", vec![intermediate, hidden], 0.01),
            ("mlp.up_proj", vec![intermediate, hidden], 0.01),
            ("mlp.down_proj", vec![hidden, intermediate], 0.01),
        ];
        for (suffix, shape, val) in layer {
            let bytes = fill(shape.iter().product(), val);
            td.push((format!("{p}.{suffix}.weight"), bytes, shape));
        }
    }

    let views: Vec<TensorView<'_>> = td
        .iter()
        .map(|(_, bytes, shape)| TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view"))
        .collect();
    let named: Vec<(&str, &TensorView<'_>)> =
        td.iter().zip(&views).map(|((n, _, _), v)| (n.as_str(), v)).collect();

    let file_path = dir.path().join("model.safetensors");
    let serialized =
        serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
    std::fs::write(&file_path, serialized).expect("write");

    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err(), "Inf in weights should fail");
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(err_msg.contains("Inf"), "Error should mention Inf: {err_msg}");
}
2063
/// Requesting more layers than the checkpoint contains must fail with an error
/// naming the missing layer.
#[test]
fn test_ssc024_from_safetensors_missing_layer() {
    let dir = TempDir::new().expect("create temp dir");
    create_tiny_safetensors(dir.path());

    // The fixture is written for the stock tiny config; asking for 3 layers means
    // `model.layers.2.*` is expected but absent.
    let config = TransformerConfig { num_hidden_layers: 3, ..TransformerConfig::tiny() };

    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err(), "Missing layer 2 should fail");
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(
        err_msg.contains("Missing") || err_msg.contains("layers.2"),
        "Error should mention missing layer: {err_msg}"
    );
}
2089
/// A malformed `q_proj` in layer 0 (a flat 7-element vector) must be rejected
/// with a shape-mismatch error that names `q_proj`.
#[test]
fn test_ssc024_from_safetensors_wrong_q_proj_shape() {
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let q_dim = config.q_dim();
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;

    // Little-endian F32 bytes for `n` copies of `val`.
    let fill = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
    };

    let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        (
            "model.embed_tokens.weight".to_string(),
            fill(vocab * hidden, 0.01),
            vec![vocab, hidden],
        ),
        ("model.norm.weight".to_string(), fill(hidden, 1.0), vec![hidden]),
    ];
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        // Only layer 0 gets the bad q_proj; later layers stay well-formed.
        let q_shape = if i == 0 { vec![7] } else { vec![q_dim, hidden] };
        let layer: [(&str, Vec<usize>, f32); 9] = [
            ("input_layernorm", vec![hidden], 1.0),
            ("post_attention_layernorm", vec![hidden], 1.0),
            ("self_attn.q_proj", q_shape, 0.01),
            ("self_attn.k_proj", vec![kv_hidden, hidden], 0.01),
            ("self_attn.v_proj", vec![kv_hidden, hidden], 0.01),
            ("self_attn.o_proj", vec![hidden, q_dim], 0.01),
            ("mlp.gate_proj", vec![intermediate, hidden], 0.01),
            ("mlp.up_proj", vec![intermediate, hidden], 0.01),
            ("mlp.down_proj", vec![hidden, intermediate], 0.01),
        ];
        for (suffix, shape, val) in layer {
            let bytes = fill(shape.iter().product(), val);
            td.push((format!("{p}.{suffix}.weight"), bytes, shape));
        }
    }

    let views: Vec<TensorView<'_>> = td
        .iter()
        .map(|(_, bytes, shape)| TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view"))
        .collect();
    let named: Vec<(&str, &TensorView<'_>)> =
        td.iter().zip(&views).map(|((n, _, _), v)| (n.as_str(), v)).collect();

    let file_path = dir.path().join("model.safetensors");
    let serialized =
        serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
    std::fs::write(&file_path, serialized).expect("write");

    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err(), "Wrong q_proj shape should fail");
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(
        err_msg.contains("Shape mismatch") && err_msg.contains("q_proj"),
        "Error should mention q_proj shape mismatch: {err_msg}"
    );
}
2197
/// A complete, correctly shaped weight map for the tiny config passes
/// `validate_weight_shapes`.
#[test]
fn test_ssc024_validate_weight_shapes_success() {
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    // q/o projections are sized by q_dim, not hidden_size. That is what validation
    // checks when they differ (see test_gh262_qwen3_4b_weight_shapes_q_dim_ne_hidden),
    // so this test no longer silently depends on the tiny config having q_dim == hidden.
    let q_dim = config.q_dim();
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;

    let mut weights: HashMap<String, Tensor> = HashMap::new();
    weights.insert(
        "model.embed_tokens.weight".to_string(),
        Tensor::from_vec(vec![0.1; vocab * hidden], true),
    );
    weights.insert("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; hidden], true));

    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        // (name suffix, element count, fill value)
        let layer: [(&str, usize, f32); 9] = [
            ("input_layernorm", hidden, 1.0),
            ("post_attention_layernorm", hidden, 1.0),
            ("self_attn.q_proj", q_dim * hidden, 0.1),
            ("self_attn.k_proj", kv_hidden * hidden, 0.1),
            ("self_attn.v_proj", kv_hidden * hidden, 0.1),
            ("self_attn.o_proj", hidden * q_dim, 0.1),
            ("mlp.gate_proj", intermediate * hidden, 0.1),
            ("mlp.up_proj", intermediate * hidden, 0.1),
            ("mlp.down_proj", hidden * intermediate, 0.1),
        ];
        for (suffix, len, fill) in layer {
            weights.insert(format!("{p}.{suffix}.weight"), Tensor::from_vec(vec![fill; len], true));
        }
    }

    if let Err(e) = Transformer::validate_weight_shapes(&weights, &config) {
        panic!("Valid shapes should pass: {e}");
    }
}
2265
/// A final norm weight with the wrong length is rejected, and the error names
/// `model.norm.weight`.
#[test]
fn test_ssc024_validate_weight_shapes_wrong_norm() {
    let config = TransformerConfig::tiny();
    let embed_len = config.vocab_size * config.hidden_size;

    // Correct embedding table, but a 3-element norm where hidden_size elements are expected.
    let weights: HashMap<String, Tensor> = [
        ("model.embed_tokens.weight".to_string(), Tensor::from_vec(vec![0.1; embed_len], true)),
        ("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; 3], true)),
    ]
    .into_iter()
    .collect();

    let result = Transformer::validate_weight_shapes(&weights, &config);
    assert!(result.is_err());
    let Err(e) = result else { panic!("expected error") };
    assert!(e.to_string().contains("model.norm.weight"));
}
2288
/// Weights containing only finite values pass `validate_weight_values`.
#[test]
fn test_ssc024_validate_weight_values_clean() {
    let weights: HashMap<String, Tensor> = [("a", vec![0.1, 0.2, 0.3]), ("b", vec![1.0, -1.0, 0.0])]
        .into_iter()
        .map(|(name, values)| (name.to_string(), Tensor::from_vec(values, true)))
        .collect();

    assert!(Transformer::validate_weight_values(&weights).is_ok());
}
2302
/// A NaN in any tensor fails validation, and the error names both the problem
/// (NaN) and the offending tensor.
#[test]
fn test_ssc024_validate_weight_values_nan() {
    let weights: HashMap<String, Tensor> =
        [("clean", vec![0.1, 0.2]), ("poisoned", vec![0.1, f32::NAN, 0.3])]
            .into_iter()
            .map(|(name, values)| (name.to_string(), Tensor::from_vec(values, true)))
            .collect();

    let result = Transformer::validate_weight_values(&weights);
    assert!(result.is_err());
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(err_msg.contains("NaN"));
    assert!(err_msg.contains("poisoned"));
}
2319
/// A negative infinity in a tensor fails validation with an error mentioning Inf.
#[test]
fn test_ssc024_validate_weight_values_inf() {
    let mut weights: HashMap<String, Tensor> = HashMap::new();
    weights.insert("w".to_string(), Tensor::from_vec(vec![f32::NEG_INFINITY, 0.2], true));

    let result = Transformer::validate_weight_values(&weights);
    assert!(result.is_err());
    let Err(e) = result else { panic!("expected error") };
    assert!(e.to_string().contains("Inf"));
}
2333
/// GH-262: a config whose attention width (`q_dim`) differs from `hidden_size`
/// must validate and construct when q/o projections use q_dim-based shapes.
#[test]
fn test_gh262_qwen3_4b_weight_shapes_q_dim_ne_hidden() {
    // Scaled-down Qwen3-4B-like geometry: `head_dim_override` decouples q_dim from
    // hidden_size. The assert_ne! below guards that precondition.
    let config = TransformerConfig {
        hidden_size: 80,
        num_attention_heads: 4,
        num_kv_heads: 2,
        intermediate_size: 128,
        num_hidden_layers: 1,
        vocab_size: 256,
        max_position_embeddings: 512,
        rms_norm_eps: 1e-6,
        rope_theta: 10000.0,
        use_bias: false,
        head_dim_override: Some(32),
        architecture: crate::transformer::config::ModelArchitecture::Decoder,
        hf_architecture: None,
        hf_model_type: None,
        tie_word_embeddings: false,
    };

    let hidden = config.hidden_size;
    let q_dim = config.q_dim();
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    assert_ne!(q_dim, hidden, "test requires q_dim != hidden_size");

    // (tensor name, element count, fill value) for a single-layer checkpoint.
    // q_proj is q_dim x hidden and o_proj is hidden x q_dim: the shapes under test.
    let p = "model.layers.0";
    let entries: [(String, usize, f32); 11] = [
        ("model.embed_tokens.weight".to_string(), vocab * hidden, 0.1),
        ("model.norm.weight".to_string(), hidden, 1.0),
        (format!("{p}.input_layernorm.weight"), hidden, 1.0),
        (format!("{p}.post_attention_layernorm.weight"), hidden, 1.0),
        (format!("{p}.self_attn.q_proj.weight"), q_dim * hidden, 0.1),
        (format!("{p}.self_attn.k_proj.weight"), kv_hidden * hidden, 0.1),
        (format!("{p}.self_attn.v_proj.weight"), kv_hidden * hidden, 0.1),
        (format!("{p}.self_attn.o_proj.weight"), hidden * q_dim, 0.1),
        (format!("{p}.mlp.gate_proj.weight"), hidden * intermediate, 0.1),
        (format!("{p}.mlp.up_proj.weight"), hidden * intermediate, 0.1),
        (format!("{p}.mlp.down_proj.weight"), intermediate * hidden, 0.1),
    ];
    let weights: HashMap<String, Tensor> = entries
        .into_iter()
        .map(|(name, len, fill)| (name, Tensor::from_vec(vec![fill; len], true)))
        .collect();

    let result = Transformer::validate_weight_shapes(&weights, &config);
    assert!(
        result.is_ok(),
        "Qwen3-like shapes (q_dim={q_dim} != hidden={hidden}) should validate: {:?}",
        result.err()
    );

    let model = Transformer::from_params(&config, &weights);
    assert!(model.is_some(), "Qwen3-like model with q_dim != hidden should construct");
}
2430
#[test]
fn test_gh262_wrong_q_proj_size_hidden_instead_of_q_dim() {
    // Same Qwen3-like config as the positive GH-262 test: q_dim != hidden_size.
    let config = TransformerConfig {
        hidden_size: 80,
        num_attention_heads: 4,
        num_kv_heads: 2,
        intermediate_size: 128,
        num_hidden_layers: 1,
        vocab_size: 256,
        max_position_embeddings: 512,
        rms_norm_eps: 1e-6,
        rope_theta: 10000.0,
        use_bias: false,
        head_dim_override: Some(32),
        architecture: crate::transformer::config::ModelArchitecture::Decoder,
        hf_architecture: None,
        hf_model_type: None,
        tie_word_embeddings: false,
    };

    let hidden = config.hidden_size;
    let q_dim = config.q_dim();
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    // A hidden*hidden q_proj is only wrong when the two widths differ.
    assert_ne!(q_dim, hidden, "test requires q_dim != hidden_size");

    let mut weights = HashMap::new();
    weights.insert(
        "model.embed_tokens.weight".to_string(),
        Tensor::from_vec(vec![0.1; vocab * hidden], true),
    );
    weights.insert("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; hidden], true));

    let p = "model.layers.0";
    weights.insert(
        format!("{p}.input_layernorm.weight"),
        Tensor::from_vec(vec![1.0; hidden], true),
    );
    weights.insert(
        format!("{p}.post_attention_layernorm.weight"),
        Tensor::from_vec(vec![1.0; hidden], true),
    );
    // Deliberately wrong: hidden x hidden instead of q_dim x hidden.
    weights.insert(
        format!("{p}.self_attn.q_proj.weight"),
        Tensor::from_vec(vec![0.1; hidden * hidden], true),
    );
    weights.insert(
        format!("{p}.self_attn.k_proj.weight"),
        Tensor::from_vec(vec![0.1; kv_hidden * hidden], true),
    );
    weights.insert(
        format!("{p}.self_attn.v_proj.weight"),
        Tensor::from_vec(vec![0.1; kv_hidden * hidden], true),
    );
    // Correctly sized (hidden x q_dim) so q_proj is the only malformed tensor and the
    // assertion below does not depend on the order in which projections are validated.
    weights.insert(
        format!("{p}.self_attn.o_proj.weight"),
        Tensor::from_vec(vec![0.1; hidden * q_dim], true),
    );
    weights.insert(
        format!("{p}.mlp.gate_proj.weight"),
        Tensor::from_vec(vec![0.1; hidden * intermediate], true),
    );
    weights.insert(
        format!("{p}.mlp.up_proj.weight"),
        Tensor::from_vec(vec![0.1; hidden * intermediate], true),
    );
    weights.insert(
        format!("{p}.mlp.down_proj.weight"),
        Tensor::from_vec(vec![0.1; intermediate * hidden], true),
    );

    let result = Transformer::validate_weight_shapes(&weights, &config);
    assert!(result.is_err(), "hidden*hidden q_proj should fail when q_dim != hidden");
    let err_msg = result.err().map(|e| e.to_string()).unwrap_or_default();
    assert!(
        err_msg.contains("q_proj") && err_msg.contains("Shape mismatch"),
        "Error should mention q_proj shape mismatch, got: {err_msg}"
    );
}
2513
#[test]
fn test_ssc024_from_safetensors_with_extra_bias_tensors() {
    // A checkpoint that carries q/k/v bias tensors (not required by this config)
    // must still load: unexpected extra tensors are tolerated, not rejected.
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    // Size q/o projections by q_dim rather than hidden so the fixture stays valid even
    // if the tiny config ever gets a head_dim where q_dim != hidden (see GH-262 tests).
    let q_dim = config.q_dim();
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;

    // `n` copies of `val`, encoded as little-endian f32 bytes.
    let make_f32 = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
    };

    // (tensor name, raw bytes, shape)
    let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        (
            "model.embed_tokens.weight".to_string(),
            make_f32(vocab * hidden, 0.01),
            vec![vocab, hidden],
        ),
        ("model.norm.weight".to_string(), make_f32(hidden, 1.0), vec![hidden]),
    ];

    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        td.push((format!("{p}.input_layernorm.weight"), make_f32(hidden, 1.0), vec![hidden]));
        td.push((
            format!("{p}.post_attention_layernorm.weight"),
            make_f32(hidden, 1.0),
            vec![hidden],
        ));
        td.push((
            format!("{p}.self_attn.q_proj.weight"),
            make_f32(q_dim * hidden, 0.01),
            vec![q_dim, hidden],
        ));
        td.push((
            format!("{p}.self_attn.k_proj.weight"),
            make_f32(hidden * kv_hidden, 0.01),
            vec![kv_hidden, hidden],
        ));
        td.push((
            format!("{p}.self_attn.v_proj.weight"),
            make_f32(hidden * kv_hidden, 0.01),
            vec![kv_hidden, hidden],
        ));
        td.push((
            format!("{p}.self_attn.o_proj.weight"),
            make_f32(hidden * q_dim, 0.01),
            vec![hidden, q_dim],
        ));
        td.push((
            format!("{p}.mlp.gate_proj.weight"),
            make_f32(hidden * intermediate, 0.01),
            vec![intermediate, hidden],
        ));
        td.push((
            format!("{p}.mlp.up_proj.weight"),
            make_f32(hidden * intermediate, 0.01),
            vec![intermediate, hidden],
        ));
        td.push((
            format!("{p}.mlp.down_proj.weight"),
            make_f32(intermediate * hidden, 0.01),
            vec![hidden, intermediate],
        ));

        // The extra tensors under test: attention biases the config does not ask for.
        td.push((format!("{p}.self_attn.q_proj.bias"), make_f32(q_dim, 0.0), vec![q_dim]));
        td.push((
            format!("{p}.self_attn.k_proj.bias"),
            make_f32(kv_hidden, 0.0),
            vec![kv_hidden],
        ));
        td.push((
            format!("{p}.self_attn.v_proj.bias"),
            make_f32(kv_hidden, 0.0),
            vec![kv_hidden],
        ));
    }

    // Borrowing views over `td`, then pair each with its name for serialization.
    let views: Vec<TensorView<'_>> = td
        .iter()
        .map(|(_, bytes, shape)| TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view"))
        .collect();
    let named: Vec<(&str, &TensorView<'_>)> =
        td.iter().zip(views.iter()).map(|((n, _, _), v)| (n.as_str(), v)).collect();

    let file_path = dir.path().join("model.safetensors");
    let serialized =
        serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
    std::fs::write(&file_path, serialized).expect("write");

    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(
        result.is_ok(),
        "Extra bias tensors should not cause failure: {}",
        result.as_ref().err().map_or(String::new(), std::string::ToString::to_string)
    );
}
2631 }
2632}