Skip to main content

entrenar/transformer/
model.rs

1//! Complete transformer model module
2//!
3//! This module provides the full transformer model for language modeling.
4
5use crate::autograd::matmul_nt;
6use crate::error::{Error, Result};
7use crate::Tensor;
8use provable_contracts_macros::{ensures, requires};
9use std::collections::HashMap;
10use std::path::Path;
11
12use super::block::TransformerBlock;
13use super::config::TransformerConfig;
14use super::embedding::Embedding;
15use super::norm::RMSNorm;
16use super::weights::{load_safetensors_weights, validate_weights, Architecture};
17
18/// Complete transformer model
19pub struct Transformer {
20    /// Configuration
21    pub config: TransformerConfig,
22    /// Token embedding layer
23    pub embed_tokens: Embedding,
24    /// Transformer layers
25    pub layers: Vec<TransformerBlock>,
26    /// Final layer normalization
27    pub norm: RMSNorm,
28    /// Language model head (tied to embeddings or separate)
29    pub lm_head: Option<Tensor>,
30}
31
32impl Transformer {
33    /// Create new transformer with initialized weights
34    pub fn new(config: &TransformerConfig) -> Self {
35        let layers: Vec<TransformerBlock> =
36            (0..config.num_hidden_layers).map(|i| TransformerBlock::new(config, i)).collect();
37
38        Self {
39            config: config.clone(),
40            embed_tokens: Embedding::new(config.vocab_size, config.hidden_size),
41            layers,
42            norm: RMSNorm::new(config.hidden_size, config.rms_norm_eps),
43            lm_head: None, // Use tied embeddings by default
44        }
45    }
46
47    /// Create transformer from parameter map
48    ///
49    /// Expected parameter names (following HuggingFace LLaMA convention):
50    /// - `model.embed_tokens.weight`
51    /// - `model.layers.{i}.*`
52    /// - `model.norm.weight`
53    /// - `lm_head.weight` (optional, uses tied embeddings if not present)
54    pub fn from_params(
55        config: &TransformerConfig,
56        params: &HashMap<String, Tensor>,
57    ) -> Option<Self> {
58        let embed_tokens = Embedding::from_params(
59            params,
60            "model.embed_tokens.weight",
61            config.vocab_size,
62            config.hidden_size,
63        )?;
64
65        let layers: Option<Vec<TransformerBlock>> = (0..config.num_hidden_layers)
66            .map(|i| TransformerBlock::from_params(config, params, i))
67            .collect();
68        let layers = layers?;
69
70        let norm =
71            RMSNorm::from_params(params, "model.norm", config.rms_norm_eps, config.hidden_size)?;
72
73        // PMAT-329: Validate lm_head shape if present
74        let lm_head = if let Some(tensor) = params.get("lm_head.weight") {
75            let expected = config.hidden_size * config.vocab_size;
76            if tensor.len() != expected {
77                eprintln!(
78                    "[PMAT-329] lm_head.weight: shape mismatch — got {} elements, expected {expected} ({hidden}x{vocab})",
79                    tensor.len(),
80                    hidden = config.hidden_size,
81                    vocab = config.vocab_size,
82                );
83                return None;
84            }
85            Some(tensor.clone())
86        } else {
87            None
88        };
89
90        Some(Self { config: config.clone(), embed_tokens, layers, norm, lm_head })
91    }
92
93    /// Load transformer from SafeTensors file(s)
94    ///
95    /// Reads SafeTensors weights from `model_path`, converts BF16/F16 to F32,
96    /// validates shapes against `config`, checks for NaN/Inf, and constructs
97    /// the complete `Transformer`.
98    ///
99    /// # Arguments
100    ///
101    /// * `model_path` - Path to model directory or single SafeTensors file
102    /// * `config` - Transformer configuration specifying model dimensions
103    ///
104    /// # Errors
105    ///
106    /// Returns `Error::ConfigError` if:
107    /// - No SafeTensors files found
108    /// - Required weight tensors are missing
109    /// - Weight shapes do not match config dimensions
110    /// - Weights contain NaN or Inf values
111    /// - Layer count does not match config
112    pub fn from_safetensors(
113        model_path: impl AsRef<Path>,
114        config: &TransformerConfig,
115    ) -> Result<Self> {
116        let model_path = model_path.as_ref();
117
118        // Load and convert all weights from SafeTensors files
119        let weights = load_safetensors_weights(model_path, Architecture::Auto)?;
120
121        // Structural validation: all required keys present
122        validate_weights(&weights, config.num_hidden_layers)?;
123
124        // Shape validation against config dimensions
125        Self::validate_weight_shapes(&weights, config)?;
126
127        // NaN/Inf validation
128        Self::validate_weight_values(&weights)?;
129
130        // Build transformer from validated weights
131        Self::from_params(config, &weights).ok_or_else(|| {
132            Error::ConfigError(
133                "Failed to construct Transformer from loaded weights \
134                 (internal from_params returned None after validation passed)"
135                    .into(),
136            )
137        })
138    }
139
140    /// Load transformer from APR file (.apr format)
141    ///
142    /// Reads tensor data from an APR binary file, dequantizing from any stored
143    /// dtype (F16, Q4K, etc.) to F32. Uses the same validation pipeline as
144    /// `from_safetensors`: structural, shape, and NaN/Inf checks.
145    ///
146    /// # Arguments
147    /// * `apr_path` - Path to the .apr model file
148    /// * `config` - Transformer configuration (typically read from APR metadata)
149    ///
150    /// # Errors
151    /// Returns `Error::ConfigError` if tensors are missing, shapes mismatch, or
152    /// weights contain NaN/Inf values.
153    pub fn from_apr(apr_path: impl AsRef<Path>, config: &TransformerConfig) -> Result<Self> {
154        use aprender::serialization::apr::AprReader;
155
156        let apr_path = apr_path.as_ref();
157        let reader = AprReader::open(apr_path).map_err(|e| {
158            Error::ConfigError(format!("Failed to open APR file '{}': {e}", apr_path.display()))
159        })?;
160
161        // Build weight map from APR tensors — map GGUF names to HF convention (PMAT-489)
162        let is_gguf_names = reader.tensors.iter().any(|t| t.name == "token_embd.weight");
163        if is_gguf_names {
164            eprintln!(
165                "[PMAT-489] Detected GGUF tensor names in APR file, mapping to HF convention"
166            );
167        }
168        let mut weights = HashMap::new();
169        for desc in &reader.tensors {
170            let data = reader.read_tensor_as_f32(&desc.name).map_err(|e| {
171                Error::ConfigError(format!("Failed to read tensor '{}': {e}", desc.name))
172            })?;
173            let mapped_name = if is_gguf_names {
174                super::weights::mapping::map_weight_name(
175                    &desc.name,
176                    super::weights::Architecture::Gguf,
177                )
178            } else {
179                desc.name.clone()
180            };
181            weights.insert(mapped_name, Tensor::from_vec(data, false));
182        }
183
184        // Same validation pipeline as from_safetensors
185        validate_weights(&weights, config.num_hidden_layers)?;
186        Self::validate_weight_shapes(&weights, config)?;
187        Self::validate_weight_values(&weights)?;
188
189        Self::from_params(config, &weights).ok_or_else(|| {
190            Error::ConfigError(
191                "Failed to construct Transformer from APR weights \
192                 (from_params returned None after validation passed)"
193                    .into(),
194            )
195        })
196    }
197
198    /// Validate that all weight tensor shapes match the config dimensions
199    fn validate_weight_shapes(
200        weights: &HashMap<String, Tensor>,
201        config: &TransformerConfig,
202    ) -> Result<()> {
203        let hidden = config.hidden_size;
204        let q_dim = config.q_dim();
205        let kv_hidden = config.num_kv_heads * config.head_dim();
206        let intermediate = config.intermediate_size;
207        let vocab = config.vocab_size;
208
209        // Helper closure for shape checking
210        let check = |name: &str, expected: usize| -> Result<()> {
211            if let Some(tensor) = weights.get(name) {
212                if tensor.len() != expected {
213                    return Err(Error::ConfigError(format!(
214                        "Shape mismatch for '{name}': expected {expected} elements, got {}",
215                        tensor.len()
216                    )));
217                }
218            }
219            // Missing keys are caught by validate_weights
220            Ok(())
221        };
222
223        // Global weights
224        check("model.embed_tokens.weight", vocab * hidden)?;
225        check("model.norm.weight", hidden)?;
226
227        // Optional lm_head
228        if weights.contains_key("lm_head.weight") {
229            check("lm_head.weight", vocab * hidden)?;
230        }
231
232        // Per-layer weights
233        for i in 0..config.num_hidden_layers {
234            let p = format!("model.layers.{i}");
235
236            // Layer norms
237            check(&format!("{p}.input_layernorm.weight"), hidden)?;
238            check(&format!("{p}.post_attention_layernorm.weight"), hidden)?;
239
240            // Attention projections: Q/O use q_dim, K/V use kv_hidden
241            check(&format!("{p}.self_attn.q_proj.weight"), q_dim * hidden)?;
242            check(&format!("{p}.self_attn.k_proj.weight"), kv_hidden * hidden)?;
243            check(&format!("{p}.self_attn.v_proj.weight"), kv_hidden * hidden)?;
244            check(&format!("{p}.self_attn.o_proj.weight"), hidden * q_dim)?;
245
246            // Optional attention biases (Qwen2 etc.)
247            check(&format!("{p}.self_attn.q_proj.bias"), q_dim)?;
248            check(&format!("{p}.self_attn.k_proj.bias"), kv_hidden)?;
249            check(&format!("{p}.self_attn.v_proj.bias"), kv_hidden)?;
250
251            // MLP projections
252            check(&format!("{p}.mlp.gate_proj.weight"), hidden * intermediate)?;
253            check(&format!("{p}.mlp.up_proj.weight"), hidden * intermediate)?;
254            check(&format!("{p}.mlp.down_proj.weight"), intermediate * hidden)?;
255        }
256
257        Ok(())
258    }
259
260    /// Validate that no weight tensors contain NaN or Inf values
261    fn validate_weight_values(weights: &HashMap<String, Tensor>) -> Result<()> {
262        for (name, tensor) in weights {
263            let data = tensor.data();
264            for (i, &val) in data.iter().enumerate() {
265                if val.is_nan() {
266                    return Err(Error::ConfigError(format!(
267                        "NaN detected in weight '{name}' at index {i}"
268                    )));
269                }
270                if val.is_infinite() {
271                    return Err(Error::ConfigError(format!(
272                        "Inf detected in weight '{name}' at index {i}"
273                    )));
274                }
275            }
276        }
277        Ok(())
278    }
279
280    /// Forward pass for language modeling
281    ///
282    /// # Arguments
283    /// * `token_ids` - Input token IDs
284    ///
285    /// # Returns
286    /// Logits tensor (seq_len * vocab_size, flattened)
287    #[requires(!token_ids.is_empty())]
288    #[ensures(ret.len() == token_ids.len() * self.config.vocab_size)]
289    pub fn forward(&self, token_ids: &[u32]) -> Tensor {
290        contract_pre_embedding_lookup!(token_ids);
291        let seq_len = token_ids.len();
292        let hidden_size = self.config.hidden_size;
293
294        // Embed tokens
295        let mut hidden = self.embed_tokens.forward(token_ids);
296
297        // Pass through transformer layers
298        for layer in &self.layers {
299            hidden = layer.forward(&hidden, seq_len);
300        }
301
302        // Final normalization
303        let normalized = self.norm.forward_batched(&hidden, seq_len, hidden_size);
304
305        // Language model head
306        let lm_weight = self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight);
307
308        // lm_head / tied embed_tokens is [vocab_size, hidden_size] in HF (ENT-269)
309        let result =
310            matmul_nt(&normalized, lm_weight, seq_len, hidden_size, self.config.vocab_size);
311        contract_post_embedding_lookup!(result.data().as_slice().unwrap_or(&[]));
312        result
313    }
314
315    /// Forward pass returning hidden states (before lm_head)
316    ///
317    /// # Arguments
318    /// * `token_ids` - Input token IDs
319    ///
320    /// # Returns
321    /// Hidden states tensor (seq_len * hidden_size, flattened)
322    #[requires(!token_ids.is_empty())]
323    #[ensures(ret.len() == token_ids.len() * self.config.hidden_size)]
324    pub fn forward_hidden(&self, token_ids: &[u32]) -> Tensor {
325        contract_pre_embedding_lookup!(token_ids);
326        let seq_len = token_ids.len();
327        let hidden_size = self.config.hidden_size;
328
329        // Embed tokens
330        let mut hidden = self.embed_tokens.forward(token_ids);
331
332        // Pass through transformer layers
333        for layer in &self.layers {
334            hidden = layer.forward(&hidden, seq_len);
335        }
336
337        // Final normalization
338        let result = self.norm.forward_batched(&hidden, seq_len, hidden_size);
339        contract_post_embedding_lookup!(result.data().as_slice().unwrap_or(&[]));
340        result
341    }
342
343    /// Forward pass returning hidden states with LoRA adjusts (KAIZEN-011)
344    ///
345    /// Like `forward_hidden` but applies LoRA adapters to Q/V projections in
346    /// each transformer layer's attention. Enables non-CUDA LoRA training by
347    /// putting LoRA parameters into the autograd graph.
348    ///
349    /// # Arguments
350    /// * `token_ids` - Input token IDs
351    /// * `lora_layers` - LoRA layers in [Q_0, V_0, Q_1, V_1, ...] order
352    ///
353    /// # Returns
354    /// Hidden states tensor (seq_len * hidden_size, flattened)
355    pub fn forward_hidden_with_lora(
356        &self,
357        token_ids: &[u32],
358        lora_layers: &[crate::lora::LoRALayer],
359    ) -> Tensor {
360        contract_pre_embedding_lookup!(token_ids);
361        let seq_len = token_ids.len();
362        let hidden_size = self.config.hidden_size;
363
364        let mut hidden = self.embed_tokens.forward(token_ids);
365
366        for (layer_idx, layer) in self.layers.iter().enumerate() {
367            let norm1 = layer.input_norm.forward_batched(&hidden, seq_len, hidden_size);
368
369            // KAIZEN-011: Apply LoRA to attention Q/V projections
370            let q_idx = layer_idx * 2;
371            let v_idx = layer_idx * 2 + 1;
372            let attn_out = if v_idx < lora_layers.len() {
373                layer.self_attn.forward_with_lora(
374                    &norm1,
375                    seq_len,
376                    lora_layers[q_idx].lora_a(),
377                    lora_layers[q_idx].lora_b(),
378                    lora_layers[v_idx].lora_a(),
379                    lora_layers[v_idx].lora_b(),
380                    lora_layers[q_idx].rank(),
381                    lora_layers[q_idx].scale(),
382                )
383            } else {
384                layer.self_attn.forward(&norm1, seq_len)
385            };
386
387            let residual = crate::autograd::add(&hidden, &attn_out);
388            let norm2 = layer.post_attn_norm.forward_batched(&residual, seq_len, hidden_size);
389            let ffn_out = layer.ffn.forward(&norm2, seq_len);
390            hidden = crate::autograd::add(&residual, &ffn_out);
391        }
392
393        let result = self.norm.forward_batched(&hidden, seq_len, hidden_size);
394        contract_post_embedding_lookup!(result.data().as_slice().unwrap_or(&[]));
395        result
396    }
397
398    /// Forward pass with LoRA adapters (ENT-LoRA-001)
399    ///
400    /// Like `forward` but applies LoRA adapters to Q/V projections.
401    /// Returns full logits (seq_len * vocab_size).
402    ///
403    /// # Arguments
404    /// * `token_ids` - Input token IDs
405    /// * `lora_layers` - LoRA layers in [Q_0, V_0, Q_1, V_1, ...] order
406    pub fn forward_with_lora(
407        &self,
408        token_ids: &[u32],
409        lora_layers: &[crate::lora::LoRALayer],
410    ) -> Tensor {
411        contract_pre_embedding_lookup!(token_ids);
412        let seq_len = token_ids.len();
413        let hidden_size = self.config.hidden_size;
414
415        let hidden = self.forward_hidden_with_lora(token_ids, lora_layers);
416        let lm_weight = self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight);
417        let result = matmul_nt(&hidden, lm_weight, seq_len, hidden_size, self.config.vocab_size);
418        contract_post_embedding_lookup!(result.data().as_slice().unwrap_or(&[]));
419        result
420    }
421
422    /// Get the last token's logits (for generation)
423    pub fn forward_last(&self, token_ids: &[u32]) -> Tensor {
424        contract_pre_embedding_lookup!(token_ids);
425        let logits = self.forward(token_ids);
426        let seq_len = token_ids.len();
427        let vocab_size = self.config.vocab_size;
428
429        // Extract last position
430        let start = (seq_len - 1) * vocab_size;
431        let end = start + vocab_size;
432        let last_logits: Vec<f32> =
433            logits.data().as_slice().expect("logits must be contiguous")[start..end].to_vec();
434
435        let result = Tensor::from_vec(last_logits, logits.requires_grad());
436        contract_post_embedding_lookup!(result.data().as_slice().unwrap_or(&[]));
437        result
438    }
439
440    /// Get all parameters as a vector
441    pub fn parameters(&self) -> Vec<&Tensor> {
442        let mut params = vec![&self.embed_tokens.weight, &self.norm.weight];
443        for layer in &self.layers {
444            params.extend(layer.parameters());
445        }
446        if let Some(lm_head) = &self.lm_head {
447            params.push(lm_head);
448        }
449        params
450    }
451
452    /// Get all parameters as mutable references for optimizer
453    pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
454        let mut params: Vec<&mut Tensor> = Vec::new();
455        params.push(&mut self.embed_tokens.weight);
456        params.push(&mut self.norm.weight);
457        for layer in &mut self.layers {
458            params.extend(layer.parameters_mut());
459        }
460        if let Some(lm_head) = &mut self.lm_head {
461            params.push(lm_head);
462        }
463        params
464    }
465
466    /// Get configuration
467    pub fn config(&self) -> &TransformerConfig {
468        &self.config
469    }
470
471    /// Embed a single token, returning hidden_size floats.
472    pub fn embed_token(&self, token_id: u32) -> Vec<f32> {
473        let w = self.embed_tokens.weight.data();
474        let data = w.as_slice().expect("contiguous embedding");
475        let h = self.config.hidden_size;
476        let offset = (token_id as usize) * h;
477        data[offset..offset + h].to_vec()
478    }
479
480    /// Get the output norm weight as a slice.
481    pub fn output_norm_weight_slice(&self) -> &[f32] {
482        self.norm.weight.data().as_slice().expect("contiguous norm weight")
483    }
484
485    /// Get the lm_head weight as a slice (vocab_size × hidden_size, row-major).
486    pub fn lm_head_weight_slice(&self) -> &[f32] {
487        let w = self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight);
488        w.data().as_slice().expect("contiguous lm_head")
489    }
490
491    /// Get the language model head weight tensor.
492    ///
493    /// Returns the dedicated `lm_head` weight if present, otherwise falls back
494    /// to tied embedding weights.
495    pub fn lm_head_weight(&self) -> &Tensor {
496        self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight)
497    }
498
499    /// Get named parameters for checkpoint serialization.
500    ///
501    /// Returns (name, tensor) pairs matching HuggingFace weight conventions.
502    /// This handles variable parameter counts (e.g., models with/without attention biases)
503    /// correctly, unlike the hardcoded 9-params-per-layer assumption.
504    pub fn named_parameters(&self) -> Vec<(String, &Tensor)> {
505        let mut params = vec![
506            ("model.embed_tokens.weight".to_string(), &self.embed_tokens.weight),
507            ("model.norm.weight".to_string(), &self.norm.weight),
508        ];
509        for layer in &self.layers {
510            params.extend(layer.named_parameters());
511        }
512        if let Some(ref lm_head) = self.lm_head {
513            params.push(("lm_head.weight".to_string(), lm_head));
514        }
515        params
516    }
517
518    /// ENT-282: Set a named parameter by name (for delta checkpoint overlay).
519    ///
520    /// Returns true if the parameter was found and set.
521    pub fn set_named_parameter(&mut self, name: &str, value: Tensor) -> bool {
522        if name == "model.embed_tokens.weight" {
523            self.embed_tokens.weight = value;
524            return true;
525        }
526        if name == "model.norm.weight" {
527            self.norm.weight = value;
528            return true;
529        }
530        if name == "lm_head.weight" {
531            self.lm_head = Some(value);
532            return true;
533        }
534        // Per-layer parameters: model.layers.{idx}.{suffix}
535        if let Some(rest) = name.strip_prefix("model.layers.") {
536            if let Some(dot_pos) = rest.find('.') {
537                if let Ok(idx) = rest[..dot_pos].parse::<usize>() {
538                    if idx < self.layers.len() {
539                        let suffix = &rest[dot_pos + 1..];
540                        return self.layers[idx].set_named_parameter(suffix, value);
541                    }
542                }
543            }
544        }
545        false
546    }
547}
548
549#[cfg(test)]
550mod tests {
551    use super::*;
552
553    #[test]
554    fn test_transformer_tiny_forward() {
555        let config = TransformerConfig::tiny();
556        let transformer = Transformer::new(&config);
557        let tokens = vec![1, 2, 3];
558        let logits = transformer.forward(&tokens);
559        assert_eq!(logits.len(), 3 * config.vocab_size);
560    }
561
562    /// FALSIFY-APR-PRETRAIN-ARCH-004 (smoke level) — GQA-7:1 forward pass
563    /// runs without panic and produces finite output of correct shape.
564    ///
565    /// Per `apr-pretrain-arch-polymorphic-v1` (PR #1473), the §49 fine-tune
566    /// path uses Qwen2.5-Coder-0.5B's GQA-7:1 ratio (kv_heads=2,
567    /// query_heads=14). The Llama370M codepath only exercised GQA-4:1.
568    /// This test pins that the existing attention kernel handles the new
569    /// ratio without per-ratio specialization.
570    ///
571    /// Tiny shape (hidden=112=14*8, head_dim=8) keeps the test under 1ms.
572    /// Full numerical-parity vs GQA-1:1 reference (cosine ≥ 0.9999) is a
573    /// FUNCTIONAL-level discharge, not algorithm-level. PARTIAL_ALGORITHM_LEVEL
574    /// only requires that the kernel COMPILES and PRODUCES finite output for
575    /// the new ratio — both proven here.
576    ///
577    /// Spec: SPEC-SHIP-TWO-001 §50.4 step 5e.
578    #[test]
579    fn falsify_apr_pretrain_arch_004_gqa_7_1_forward_pass_smoke() {
580        // Tiny GQA-7:1 config: kv_heads=2, num_attention_heads=14, head_dim=8.
581        // hidden = num_attention_heads * head_dim = 14 * 8 = 112.
582        let config = TransformerConfig {
583            hidden_size: 112,
584            num_attention_heads: 14,
585            num_kv_heads: 2,
586            intermediate_size: 64,
587            num_hidden_layers: 1,
588            vocab_size: 256,
589            max_position_embeddings: 512,
590            rms_norm_eps: 1e-6,
591            rope_theta: 1_000_000.0, // Qwen2 ROPE convention
592            use_bias: true,          // Qwen2 quirk
593            head_dim_override: None, // 112 / 14 = 8, no override needed
594            architecture: crate::transformer::config::ModelArchitecture::Decoder,
595            hf_architecture: None,
596            hf_model_type: None,
597            tie_word_embeddings: true, // Qwen2.5-0.5B convention
598        };
599
600        // Drift-prevention: verify GQA-7:1 ratio holds at construction time.
601        // If a future refactor flips num_attention_heads or num_kv_heads such
602        // that the ratio is no longer 7, this test catches it before the
603        // forward pass even runs.
604        assert_eq!(
605            config.num_attention_heads / config.num_kv_heads,
606            7,
607            "GQA-7:1 ratio must be 14/2=7 (Qwen2.5-0.5B canonical)"
608        );
609
610        let transformer = Transformer::new(&config);
611        let tokens = vec![1u32, 2, 3, 4]; // seq_len=4 short prefix
612        let logits = transformer.forward(&tokens);
613
614        // Shape invariant: forward returns seq_len * vocab_size logits.
615        assert_eq!(
616            logits.len(),
617            4 * config.vocab_size,
618            "GQA-7:1 forward must return seq_len * vocab_size logits"
619        );
620
621        // Numerical invariant: all logits finite (no NaN, no Inf).
622        // The §24 retrospective showed silent NaN propagation through GQA can
623        // produce loss=NaN that the divergence guard catches LATE (after
624        // multiple steps). FALSIFY-004's smoke level catches it at the
625        // first forward pass, before any optimizer state corrupts.
626        assert!(
627            logits.data().iter().all(|&v| v.is_finite()),
628            "GQA-7:1 forward must produce all-finite logits — silent NaN \
629             would corrupt the §49 fine-tune trajectory before FALSIFY-006 \
630             (init_loss < 6.0) could measure it"
631        );
632    }
633
634    #[test]
635    fn test_transformer_tiny_forward_last() {
636        let config = TransformerConfig::tiny();
637        let transformer = Transformer::new(&config);
638        let tokens = vec![1, 2, 3];
639        let logits = transformer.forward_last(&tokens);
640        assert_eq!(logits.len(), config.vocab_size);
641    }
642
643    #[test]
644    fn test_transformer_parameters() {
645        let config = TransformerConfig::tiny();
646        let transformer = Transformer::new(&config);
647        let params = transformer.parameters();
648        // embed_tokens + norm + (layers * (input_norm + post_attn_norm + 4 attn weights + 3 ffn weights))
649        // = 2 + 2 * (2 + 4 + 3) = 2 + 2 * 9 = 20
650        assert_eq!(params.len(), 20);
651    }
652
653    #[test]
654    fn test_transformer_config_accessor() {
655        let config = TransformerConfig::tiny();
656        let transformer = Transformer::new(&config);
657        assert_eq!(transformer.config().hidden_size, config.hidden_size);
658        assert_eq!(transformer.config().vocab_size, config.vocab_size);
659    }
660
661    #[test]
662    fn test_transformer_single_token() {
663        let config = TransformerConfig::tiny();
664        let transformer = Transformer::new(&config);
665        let tokens = vec![42];
666        let logits = transformer.forward(&tokens);
667        assert_eq!(logits.len(), config.vocab_size);
668    }
669
670    #[test]
671    fn test_output_finite_values() {
672        let config = TransformerConfig::tiny();
673        let transformer = Transformer::new(&config);
674        let tokens = vec![1, 2, 3, 4, 5];
675        let logits = transformer.forward(&tokens);
676        // All outputs should be finite (no NaN or Inf)
677        assert!(logits.data().iter().all(|&v| v.is_finite()));
678    }
679
680    #[test]
681    fn test_transformer_empty_lm_head_uses_tied_weights() {
682        let config = TransformerConfig::tiny();
683        let transformer = Transformer::new(&config);
684        // Default transformer should have no separate lm_head
685        assert!(transformer.lm_head.is_none());
686        // But should still produce valid logits
687        let tokens = vec![1, 2];
688        let logits = transformer.forward(&tokens);
689        assert_eq!(logits.len(), 2 * config.vocab_size);
690    }
691
692    #[test]
693    fn test_from_params_returns_none_on_missing() {
694        let config = TransformerConfig::tiny();
695        let params: HashMap<String, Tensor> = HashMap::new();
696        let result = Transformer::from_params(&config, &params);
697        assert!(result.is_none());
698    }
699
700    #[test]
701    fn test_transformer_from_params_with_lm_head() {
702        let config = TransformerConfig::tiny();
703        let hidden_size = config.hidden_size;
704        let vocab_size = config.vocab_size;
705        let kv_hidden_size = config.num_kv_heads * config.head_dim();
706        let intermediate_size = config.intermediate_size;
707
708        let mut params = HashMap::new();
709
710        // Embedding
711        params.insert(
712            "model.embed_tokens.weight".to_string(),
713            Tensor::from_vec(vec![0.1; vocab_size * hidden_size], true),
714        );
715
716        // All layers
717        for layer_idx in 0..config.num_hidden_layers {
718            let prefix = format!("model.layers.{layer_idx}");
719            params.insert(
720                format!("{prefix}.input_layernorm.weight"),
721                Tensor::from_vec(vec![1.0; hidden_size], true),
722            );
723            params.insert(
724                format!("{prefix}.self_attn.q_proj.weight"),
725                Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
726            );
727            params.insert(
728                format!("{prefix}.self_attn.k_proj.weight"),
729                Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
730            );
731            params.insert(
732                format!("{prefix}.self_attn.v_proj.weight"),
733                Tensor::from_vec(vec![0.1; hidden_size * kv_hidden_size], true),
734            );
735            params.insert(
736                format!("{prefix}.self_attn.o_proj.weight"),
737                Tensor::from_vec(vec![0.1; hidden_size * hidden_size], true),
738            );
739            params.insert(
740                format!("{prefix}.post_attention_layernorm.weight"),
741                Tensor::from_vec(vec![1.0; hidden_size], true),
742            );
743            params.insert(
744                format!("{prefix}.mlp.gate_proj.weight"),
745                Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
746            );
747            params.insert(
748                format!("{prefix}.mlp.up_proj.weight"),
749                Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
750            );
751            params.insert(
752                format!("{prefix}.mlp.down_proj.weight"),
753                Tensor::from_vec(vec![0.1; intermediate_size * hidden_size], true),
754            );
755        }
756
757        // Final norm
758        params.insert(
759            "model.norm.weight".to_string(),
760            Tensor::from_vec(vec![1.0; hidden_size], true),
761        );
762
763        // LM head (separate, not tied)
764        params.insert(
765            "lm_head.weight".to_string(),
766            Tensor::from_vec(vec![0.1; hidden_size * vocab_size], true),
767        );
768
769        let transformer = Transformer::from_params(&config, &params);
770        assert!(transformer.is_some());
771        let transformer = transformer.expect("operation should succeed");
772        assert!(transformer.lm_head.is_some());
773        assert_eq!(transformer.layers.len(), config.num_hidden_layers);
774    }
775
/// A parameter map without an `lm_head.weight` entry must still load, and the
/// resulting model must leave `lm_head` unset (tied-embedding fallback).
#[test]
fn test_transformer_from_params_without_lm_head() {
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;

    // Per-layer tensors described as (name suffix, fill value, element count).
    let per_layer: [(&str, f32, usize); 9] = [
        ("input_layernorm.weight", 1.0, hidden),
        ("self_attn.q_proj.weight", 0.1, hidden * hidden),
        ("self_attn.k_proj.weight", 0.1, hidden * kv_hidden),
        ("self_attn.v_proj.weight", 0.1, hidden * kv_hidden),
        ("self_attn.o_proj.weight", 0.1, hidden * hidden),
        ("post_attention_layernorm.weight", 1.0, hidden),
        ("mlp.gate_proj.weight", 0.1, hidden * intermediate),
        ("mlp.up_proj.weight", 0.1, hidden * intermediate),
        ("mlp.down_proj.weight", 0.1, intermediate * hidden),
    ];

    let mut params = HashMap::new();

    // Token embedding
    params.insert(
        "model.embed_tokens.weight".to_string(),
        Tensor::from_vec(vec![0.1; config.vocab_size * hidden], true),
    );

    // Every transformer layer gets the full set of nine tensors
    for layer_idx in 0..config.num_hidden_layers {
        for (suffix, fill, len) in per_layer {
            params.insert(
                format!("model.layers.{layer_idx}.{suffix}"),
                Tensor::from_vec(vec![fill; len], true),
            );
        }
    }

    // Final norm; `lm_head.weight` is deliberately left out
    params.insert("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; hidden], true));

    let loaded = Transformer::from_params(&config, &params);
    assert!(loaded.is_some());
    let model = loaded.expect("operation should succeed");
    assert!(model.lm_head.is_none()); // Should use tied embeddings
}
844
/// With a separate `lm_head` installed, `parameters()` must report it in
/// addition to the embedding, the final norm and the per-layer tensors.
#[test]
fn test_transformer_parameters_with_lm_head() {
    let config = TransformerConfig::tiny();
    let head_len = config.hidden_size * config.vocab_size;

    // Start from the default (tied) model and attach a separate lm_head
    let mut model = Transformer::new(&config);
    model.lm_head = Some(Tensor::from_vec(vec![0.1; head_len], true));

    // embed_tokens + norm + (layers * 9) + lm_head
    // = 1 + 1 + (2 * 9) + 1 = 21
    let param_count = model.parameters().len();
    assert_eq!(param_count, 21);
}
859
/// A forward pass through a model with a separate `lm_head` must yield one
/// finite logit per (token, vocabulary entry) pair.
#[test]
fn test_transformer_forward_with_lm_head() {
    let config = TransformerConfig::tiny();
    let mut model = Transformer::new(&config);

    // Replace the tied default with a separate lm_head
    let head = Tensor::from_vec(vec![0.1; config.hidden_size * config.vocab_size], true);
    model.lm_head = Some(head);

    let tokens = vec![1, 2, 3];
    let logits = model.forward(&tokens);

    assert_eq!(logits.len(), 3 * config.vocab_size);
    assert!(logits.data().iter().all(|&value| value.is_finite()));
}
874
875    // =========================================================================
876    // FALSIFY-L: §2.1.2 LM Head Contract — Five-Whys Gap Analysis (Refs PMAT-329)
877    //
878    // Contract: tensor-layout-v1.yaml §tensors.lm_head
879    //   critical: "true"
880    //   note: "GH-202 root cause - wrong shape caused [PAD] garbage output"
881    //
882    // Five-Whys:
883    //   Why 1: entrenar-trained model's lm_head could corrupt inference
884    //   Why 2: lm_head save/load has no shape validation
885    //   Why 3: from_params accepts ANY tensor for lm_head (like embedding)
886    //   Why 4: entrenar predates ValidatedWeight contract
887    //   Why 5: No cross-crate contract enforcement for trained models
888    //
889    // Popper (1959): "These tests attempt to falsify the claim that
890    // entrenar's lm_head handling prevents garbage output after training."
891    // =========================================================================
892
/// FALSIFY-L1e: from_params rejects wrong-shape lm_head (PMAT-329 fix)
///
/// Every other tensor in the map is well-formed; only `lm_head.weight` is
/// given 50 elements. The test asserts that `from_params` returns `None`
/// for such a map.
#[test]
fn falsify_l1e_from_params_rejects_wrong_shape_lm_head() {
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let vocab = config.vocab_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;

    // Constant-filled, gradient-tracking tensor with `len` elements.
    let filled = |value: f32, len: usize| Tensor::from_vec(vec![value; len], true);

    let mut params = HashMap::new();

    // Valid embedding + layers + norm
    params.insert("model.embed_tokens.weight".to_string(), filled(0.1, vocab * hidden));
    for layer_idx in 0..config.num_hidden_layers {
        let prefix = format!("model.layers.{layer_idx}");
        params.extend([
            (format!("{prefix}.input_layernorm.weight"), filled(1.0, hidden)),
            (format!("{prefix}.self_attn.q_proj.weight"), filled(0.1, hidden * hidden)),
            (format!("{prefix}.self_attn.k_proj.weight"), filled(0.1, hidden * kv_hidden)),
            (format!("{prefix}.self_attn.v_proj.weight"), filled(0.1, hidden * kv_hidden)),
            (format!("{prefix}.self_attn.o_proj.weight"), filled(0.1, hidden * hidden)),
            (format!("{prefix}.post_attention_layernorm.weight"), filled(1.0, hidden)),
            (format!("{prefix}.mlp.gate_proj.weight"), filled(0.1, hidden * intermediate)),
            (format!("{prefix}.mlp.up_proj.weight"), filled(0.1, hidden * intermediate)),
            (format!("{prefix}.mlp.down_proj.weight"), filled(0.1, intermediate * hidden)),
        ]);
    }
    params.insert("model.norm.weight".to_string(), filled(1.0, hidden));

    // WRONG-SHAPE lm_head: 50 elements where hidden*vocab are expected
    params.insert("lm_head.weight".to_string(), filled(0.1, 50));

    // FIXED (PMAT-329): the malformed map must now be rejected
    let loaded = Transformer::from_params(&config, &params);
    assert!(
        loaded.is_none(),
        "FALSIFY-L1e: PMAT-329 fix — from_params MUST reject wrong-shape lm_head"
    );
}
966
/// FALSIFY-L2e: Tied embeddings produce valid logit dimensions
///
/// With `lm_head` unset, `forward` on three tokens must return
/// `3 * vocab_size` logits, none of which is NaN or infinite.
#[test]
fn falsify_l2e_tied_embeddings_produce_correct_logit_dims() {
    let config = TransformerConfig::tiny();
    let model = Transformer::new(&config);
    assert!(model.lm_head.is_none(), "Default should use tied embeddings");

    let tokens = vec![1, 2, 3];
    let logits = model.forward(&tokens);
    assert_eq!(
        logits.len(),
        3 * config.vocab_size,
        "FALSIFY-L2e: Tied embedding logits must be seq_len * vocab_size"
    );

    // Count NaN and Inf entries in a single pass over the logits
    let data = logits.data();
    let (nan_count, inf_count) = data.iter().fold((0_usize, 0_usize), |(nans, infs), v| {
        (nans + usize::from(v.is_nan()), infs + usize::from(v.is_infinite()))
    });
    assert_eq!(nan_count, 0, "FALSIFY-L2e: Tied logits must not contain NaN");
    assert_eq!(inf_count, 0, "FALSIFY-L2e: Tied logits must not contain Inf");
}
992
/// FALSIFY-L3e: Separate lm_head produces valid logit dimensions
///
/// After installing a constant-filled `lm_head`, `forward` on three tokens
/// must return `3 * vocab_size` logits and every one must be finite.
#[test]
fn falsify_l3e_separate_lm_head_produces_correct_logit_dims() {
    let config = TransformerConfig::tiny();
    let head_len = config.hidden_size * config.vocab_size;

    let mut model = Transformer::new(&config);
    model.lm_head = Some(Tensor::from_vec(vec![0.1; head_len], true));

    let tokens = vec![1, 2, 3];
    let logits = model.forward(&tokens);
    assert_eq!(
        logits.len(),
        3 * config.vocab_size,
        "FALSIFY-L3e: Separate lm_head logits must be seq_len * vocab_size"
    );

    let values = logits.data();
    assert!(
        values.iter().all(|v| v.is_finite()),
        "FALSIFY-L3e: Separate lm_head logits must all be finite"
    );
}
1014
/// FALSIFY-L4e: lm_head is included in parameters() and parameters_mut()
///
/// Installing a separate `lm_head` must grow `parameters()` by exactly one
/// entry, and `parameters_mut()` must report the same count. A missing
/// entry would mean the optimizer never sees the lm_head.
#[test]
fn falsify_l4e_lm_head_in_parameter_list() {
    let config = TransformerConfig::tiny();
    let mut model = Transformer::new(&config);

    // Baseline count with tied embeddings (no lm_head)
    let tied_count = model.parameters().len();

    // Attach a separate lm_head and count again
    let head = Tensor::from_vec(vec![0.1; config.hidden_size * config.vocab_size], true);
    model.lm_head = Some(head);
    let untied_count = model.parameters().len();
    assert_eq!(
        untied_count,
        tied_count + 1,
        "FALSIFY-L4e: lm_head must be included in parameters() — optimizer needs it"
    );

    // The mutable view must agree with the immutable one
    let mutable_count = model.parameters_mut().len();
    assert_eq!(
        mutable_count, untied_count,
        "FALSIFY-L4e: parameters_mut() must include lm_head for gradient updates"
    );
}
1044
/// FALSIFY-L5e: forward_last returns exactly vocab_size logits
///
/// `forward_last` on a five-token prompt must return `vocab_size` values,
/// all finite. A wrong length here would point at a slicing error when the
/// last position is extracted.
#[test]
fn falsify_l5e_forward_last_correct_size() {
    let config = TransformerConfig::tiny();
    let model = Transformer::new(&config);

    let prompt = vec![1, 2, 3, 4, 5];
    let last_logits = model.forward_last(&prompt);
    assert_eq!(
        last_logits.len(),
        config.vocab_size,
        "FALSIFY-L5e: forward_last must return exactly vocab_size logits"
    );

    let values = last_logits.data();
    assert!(
        values.iter().all(|v| v.is_finite()),
        "FALSIFY-L5e: forward_last logits must all be finite"
    );
}
1067
/// The causal LM loss over synthetic logits must be positive and finite, and
/// running backward must leave a finite gradient on the logits tensor.
#[test]
fn test_causal_lm_loss_backward() {
    use crate::train::{CausalLMLoss, LossFn};

    let vocab_size = 100;
    let seq_len = 3;
    let loss_fn = CausalLMLoss::new(vocab_size);

    // Deterministic, non-constant logits: sin(0.01 * index)
    let logit_values: Vec<f32> =
        (0..seq_len * vocab_size).map(|idx| (idx as f32 * 0.01).sin()).collect();
    let logits = Tensor::from_vec(logit_values, true);

    // Target token IDs, stored as floats and excluded from gradient tracking
    let targets = Tensor::from_vec(vec![5.0, 10.0, 15.0], false);

    let mut loss = loss_fn.forward(&logits, &targets);

    // Propagate gradients back from the loss
    crate::autograd::backward(&mut loss, None);

    // The loss value must be strictly positive and finite
    assert!(loss.data()[0] > 0.0);
    assert!(loss.data()[0].is_finite());

    // The logits must have received a gradient, and it must be finite
    assert!(logits.grad().is_some());
    let logits_grad = logits.grad().expect("gradient should be available");
    assert!(logits_grad.iter().all(|&g| g.is_finite()));
}
1100
1101    // =========================================================================
1102    // FALSIFY-EMB-003 / FALSIFY-TE-001..004: Tied Embeddings Contract
1103    //
1104    // Five-Whys (PMAT-354):
1105    //   Why 1: entrenar had L-series lm_head tests but no EMB-003/TE-* tagged tests
1106    //   Why 2: L-series validates shape, not tied-weight CONTRACT claims
1107    //   Why 3: no mapping from tied-embeddings-v1.yaml to entrenar test names
1108    //   Why 4: entrenar predates the provable-contracts YAML
1109    //   Why 5: tied weights were assumed correct because code path is "just fallback"
1110    //
1111    // References:
1112    //   - provable-contracts/contracts/embedding-algebra-v1.yaml (EMB-003)
1113    //   - provable-contracts/contracts/tied-embeddings-v1.yaml (TE-001..004)
1114    //   - Press & Wolf (2017) "Using the Output Embedding to Improve Language Models"
1115    // =========================================================================
1116
/// FALSIFY-EMB-003: Tied weight sharing — lm_head uses embed_tokens.weight
///
/// For a default model (`lm_head` is `None`), selecting "lm_head if present,
/// otherwise the embedding weight" must yield a reference to the very same
/// `Tensor` object as `embed_tokens.weight`, not a separate copy.
#[test]
fn falsify_emb_003_tied_weight_sharing() {
    let config = TransformerConfig::tiny();
    let model = Transformer::new(&config);

    // Default: lm_head is None → tied
    assert!(model.lm_head.is_none());

    let embed_weight = &model.embed_tokens.weight;

    // Weight selected for the output projection: the separate head when one
    // exists, otherwise the embedding matrix itself
    let lm_weight = match &model.lm_head {
        Some(separate_head) => separate_head,
        None => &model.embed_tokens.weight,
    };

    // Address identity of the two references, not value equality
    assert!(
        std::ptr::eq(lm_weight, embed_weight),
        "FALSIFIED EMB-003: tied lm_head must be same object as embed_tokens.weight"
    );
}
1139
/// FALSIFY-TE-001: Output shape = (seq_len, vocab_size)
///
/// Checks the flattened logits length for sequence lengths 1, 3 and 10,
/// using the token ids `0..seq_len` as input.
#[test]
fn falsify_te_001_output_shape() {
    let config = TransformerConfig::tiny();
    let model = Transformer::new(&config);

    for seq_len in [1, 3, 10] {
        let tokens: Vec<u32> = (0..seq_len).collect();
        let expected_len = seq_len as usize * config.vocab_size;

        let logits = model.forward(&tokens);
        assert_eq!(
            logits.len(),
            expected_len,
            "FALSIFIED TE-001: output shape for seq_len={seq_len}"
        );
    }
}
1156
/// FALSIFY-TE-002: Tied equivalence — tied output == explicit matmul with cloned W
///
/// `forward()` with a tied lm_head is compared against a manual
/// `matmul_nt(hidden, W_embed_clone)` built from `forward_hidden()` and a
/// clone of the embedding weight. The two results must have equal length
/// and agree element-wise within an absolute tolerance of 1e-6.
#[test]
fn falsify_te_002_tied_equivalence() {
    let config = TransformerConfig::tiny();
    let model = Transformer::new(&config);
    let tokens = vec![0u32, 3, 7, 15, 42];

    // Tied path: forward() projects with embed_tokens.weight
    let tied_logits = model.forward(&tokens);

    // Explicit path: hidden states times a cloned embedding matrix
    let hidden = model.forward_hidden(&tokens);
    let embed_copy = model.embed_tokens.weight.clone();
    let explicit_logits =
        matmul_nt(&hidden, &embed_copy, tokens.len(), config.hidden_size, config.vocab_size);

    let tied_data = tied_logits.data();
    let explicit_data = explicit_logits.data();

    assert_eq!(
        tied_data.len(),
        explicit_data.len(),
        "FALSIFIED TE-002: output lengths differ: {} vs {}",
        tied_data.len(),
        explicit_data.len()
    );

    // Element-wise comparison; stops at the first mismatch
    tied_data.iter().zip(explicit_data.iter()).enumerate().for_each(|(i, (&t, &e))| {
        assert!(
            (t - e).abs() < 1e-6,
            "FALSIFIED TE-002: tied[{i}] = {t} != explicit[{i}] = {e}"
        );
    });
}
1196
/// FALSIFY-TE-003: No extra parameters for tied embeddings
///
/// A tied model reports N parameters; the same model with a separate
/// lm_head attached must report exactly N + 1.
#[test]
fn falsify_te_003_no_extra_params() {
    let config = TransformerConfig::tiny();

    // Tied: default construction, no lm_head
    let tied_count = Transformer::new(&config).parameters().len();

    // Untied: same config plus an explicit lm_head tensor
    let mut untied = Transformer::new(&config);
    let head = Tensor::from_vec(vec![0.1; config.hidden_size * config.vocab_size], true);
    untied.lm_head = Some(head);
    let untied_count = untied.parameters().len();

    assert_eq!(
        untied_count,
        tied_count + 1,
        "FALSIFIED TE-003: tied model must have exactly 1 fewer param than untied"
    );
}
1217
/// FALSIFY-TE-004: Finite output for tied embeddings
///
/// Runs a tied-embedding forward pass on a fixed token sequence and checks
/// that the logits contain neither NaN nor infinite values.
#[test]
fn falsify_te_004_finite_output() {
    let config = TransformerConfig::tiny();
    let model = Transformer::new(&config);

    let tokens = vec![0u32, 5, 10, 50, 99];
    let logits = model.forward(&tokens);
    let values = logits.data();

    // NaN and Inf are counted separately so a failure names the exact problem
    let nan_count = values.iter().filter(|v| v.is_nan()).count();
    assert_eq!(
        nan_count, 0,
        "FALSIFIED TE-004: tied embedding output contains {nan_count} NaN values"
    );

    let inf_count = values.iter().filter(|v| v.is_infinite()).count();
    assert_eq!(
        inf_count, 0,
        "FALSIFIED TE-004: tied embedding output contains {inf_count} Inf values"
    );
}
1239
1240    // =========================================================================
1241    // PROPTEST FALSIFY-TE: Tied embeddings property-based falsification
1242    //
1243    // Five-Whys (PMAT-354, Phase 9):
1244    //   Why 1: YAML tied-embeddings-v1.yaml calls for "proptest with seq_len in [1,128]"
1245    //   Why 2: All 4 TE tests use fixed token sequences
1246    //   Why 3: TE proptest had ZERO coverage across the entire stack
1247    //   Why 4: Transformer construction is expensive, discouraging property testing
1248    //   Why 5: Fixed tokens miss edge cases in arbitrary token→logit paths
1249    //
1250    // References:
1251    //   - tied-embeddings-v1.yaml FALSIFY-TE-001: "proptest with seq_len in [1,128]"
1252    //   - tied-embeddings-v1.yaml FALSIFY-TE-002: "clone W_embed, compare tied vs explicit"
1253    //   - tied-embeddings-v1.yaml FALSIFY-TE-004: "proptest with finite x, check is_finite()"
1254    // =========================================================================
1255
1256    mod te_proptest_falsify {
1257        use super::*;
1258        use proptest::prelude::*;
1259
1260        // TE-001-prop: Output shape for random seq_len
1261        // Construct transformer once per test run, vary only the token sequence
1262        proptest! {
1263            #![proptest_config(ProptestConfig::with_cases(50))]
1264            #[test]
1265            fn falsify_te_001_prop_output_shape(
1266                seq_len in 1_usize..32,
1267            ) {
1268                let config = TransformerConfig::tiny();
1269                let transformer = Transformer::new(&config);
1270                let tokens: Vec<u32> = (0..seq_len).map(|i| (i % config.vocab_size) as u32).collect();
1271                let logits = transformer.forward(&tokens);
1272                prop_assert_eq!(
1273                    logits.len(),
1274                    seq_len * config.vocab_size,
1275                    "FALSIFIED TE-001-prop: seq_len={}, got len={}", seq_len, logits.len()
1276                );
1277            }
1278        }
1279
1280        // TE-002-prop: Tied equivalence for random tokens
1281        proptest! {
1282            #![proptest_config(ProptestConfig::with_cases(20))]
1283            #[test]
1284            fn falsify_te_002_prop_tied_equivalence(
1285                token_ids in proptest::collection::vec(0_u32..999, 1..8),
1286            ) {
1287                let config = TransformerConfig::tiny();
1288                let transformer = Transformer::new(&config);
1289
1290                let tied_logits = transformer.forward(&token_ids);
1291                let hidden = transformer.forward_hidden(&token_ids);
1292                let w_clone = transformer.embed_tokens.weight.clone();
1293                let explicit_logits = matmul_nt(
1294                    &hidden, &w_clone,
1295                    token_ids.len(), config.hidden_size, config.vocab_size,
1296                );
1297
1298                let tied_data = tied_logits.data();
1299                let explicit_data = explicit_logits.data();
1300                prop_assert_eq!(tied_data.len(), explicit_data.len());
1301
1302                for (i, (&t, &e)) in tied_data.iter().zip(explicit_data.iter()).enumerate() {
1303                    prop_assert!(
1304                        (t - e).abs() < 1e-5,
1305                        "FALSIFIED TE-002-prop: tied[{}]={} != explicit[{}]={}",
1306                        i, t, i, e
1307                    );
1308                }
1309            }
1310        }
1311
1312        // TE-004-prop: All outputs finite for random tokens
1313        proptest! {
1314            #![proptest_config(ProptestConfig::with_cases(30))]
1315            #[test]
1316            fn falsify_te_004_prop_finite(
1317                token_ids in proptest::collection::vec(0_u32..999, 1..16),
1318            ) {
1319                let config = TransformerConfig::tiny();
1320                let transformer = Transformer::new(&config);
1321                let logits = transformer.forward(&token_ids);
1322                let data = logits.data();
1323
1324                for (i, &v) in data.iter().enumerate() {
1325                    prop_assert!(
1326                        v.is_finite(),
1327                        "FALSIFIED TE-004-prop: logits[{}]={} non-finite (n_tokens={})",
1328                        i, v, token_ids.len()
1329                    );
1330                }
1331            }
1332        }
1333    }
1334
1335    // =========================================================================
1336    // FALSIFY-PIPE-001: Cross-contract pipeline test
1337    //
1338    // Five-Whys (PMAT-354, Phase 8):
1339    //   Why 1: no test exercises the full §2.1.1 pipeline as a single chain
1340    //   Why 2: EM, TE, SM tests each validate one contract in isolation
1341    //   Why 3: bugs can hide at contract boundaries (shape mismatch between stages)
1342    //   Why 4: the embed→tied_lm_head→softmax chain is the critical inference path
1343    //   Why 5: cross-contract pipeline faults would only show in integration
1344    //
1345    // Pipeline: embed(token_ids) → transformer_layers → norm → tied_matmul → softmax
1346    // Claims verified:
1347    //   EM-001: embed output shape = (seq_len, d_model)
1348    //   TE-001: tied logits shape = (seq_len, vocab_size)
1349    //   SM-001: softmax(logits) sums to 1.0 per row
1350    //   SM-002: all probabilities positive
1351    //   SM-003: argmax preserved through softmax
1352    // =========================================================================
1353
/// FALSIFY-PIPE-001: Full embed → tied_lm_head → softmax pipeline
///
/// Runs one tied-embedding forward pass, checks the logits' length and
/// finiteness, then applies a numerically stabilised softmax to every row
/// and checks that the probabilities sum to 1 (±1e-4), are non-negative,
/// and keep the same argmax as the raw logits.
#[test]
fn falsify_pipe_001_embed_tied_softmax_pipeline() {
    /// Index of the largest element, as selected by `Iterator::max_by`.
    fn argmax(values: &[f32]) -> usize {
        values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
            .expect("operation should succeed")
            .0
    }

    let config = TransformerConfig::tiny();
    let vocab_size = config.vocab_size;
    let model = Transformer::new(&config);

    let tokens = vec![0u32, 3, 7, 15, 42];
    let seq_len = tokens.len();

    // Stage 1: Full forward pass (embed → layers → norm → tied matmul)
    let logits = model.forward(&tokens);
    let logits_data = logits.data();

    // TE-001: logits shape = (seq_len, vocab_size)
    assert_eq!(
        logits_data.len(),
        seq_len * vocab_size,
        "FALSIFIED PIPE-001/TE-001: logits len={} != seq_len({seq_len}) * vocab({vocab_size})",
        logits_data.len()
    );

    // TE-004: all logits finite
    logits_data.iter().enumerate().for_each(|(i, &l)| {
        assert!(l.is_finite(), "FALSIFIED PIPE-001/TE-004: logits[{i}] = {l} not finite");
    });

    // Stage 2: Apply softmax per row (the sampling step)
    let logits_slice = logits_data.as_slice().expect("operation should succeed");
    for row in 0..seq_len {
        let row_logits = &logits_slice[row * vocab_size..row * vocab_size + vocab_size];

        // Softmax with the row maximum subtracted before exponentiation
        let max_val = row_logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row_logits.iter().map(|&x| (x - max_val).exp()).collect();
        let sum: f32 = exps.iter().sum();
        let probs: Vec<f32> = exps.iter().map(|&e| e / sum).collect();

        // SM-001: sums to 1.0
        let prob_sum: f32 = probs.iter().sum();
        assert!(
            (prob_sum - 1.0).abs() < 1e-4,
            "FALSIFIED PIPE-001/SM-001: row {row} prob sum={prob_sum}"
        );

        // SM-002: no probability may be negative
        for (i, &p) in probs.iter().enumerate() {
            assert!(p >= 0.0, "FALSIFIED PIPE-001/SM-002: row {row} prob[{i}]={p} negative");
        }

        // SM-003: argmax preserved
        let logit_argmax = argmax(row_logits);
        let prob_argmax = argmax(&probs[..]);
        assert_eq!(
            logit_argmax, prob_argmax,
            "FALSIFIED PIPE-001/SM-003: row {row} argmax changed {logit_argmax} → {prob_argmax}"
        );
    }
}
1425
1426    // =========================================================================
1427    // SSC-024: Transformer::from_safetensors() tests
1428    //
1429    // Tests for loading pretrained weights from SafeTensors files.
1430    // Uses synthetic SafeTensors with the tiny config to avoid needing
1431    // real 500MB model files in CI.
1432    // =========================================================================
1433
1434    mod safetensors_tests {
1435        use super::*;
1436        use safetensors::serialize;
1437        use safetensors::tensor::{Dtype, TensorView};
1438        use tempfile::TempDir;
1439
1440        /// Helper: create a synthetic SafeTensors file with all weights
1441        /// matching the tiny config (hidden=64, 2 layers, vocab=1000).
1442        fn create_tiny_safetensors(dir: &std::path::Path) -> std::path::PathBuf {
1443            let config = TransformerConfig::tiny();
1444            let hidden = config.hidden_size;
1445            let kv_hidden = config.num_kv_heads * config.head_dim();
1446            let intermediate = config.intermediate_size;
1447            let vocab = config.vocab_size;
1448
1449            let mut tensors_data: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
1450
1451            // Helper to create f32 bytes
1452            let make_f32 = |n: usize, val: f32| -> Vec<u8> {
1453                std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
1454            };
1455
1456            // Embedding
1457            tensors_data.push((
1458                "model.embed_tokens.weight".to_string(),
1459                make_f32(vocab * hidden, 0.01),
1460                vec![vocab, hidden],
1461            ));
1462
1463            // Final norm
1464            tensors_data.push((
1465                "model.norm.weight".to_string(),
1466                make_f32(hidden, 1.0),
1467                vec![hidden],
1468            ));
1469
1470            // Per-layer weights
1471            for i in 0..config.num_hidden_layers {
1472                let p = format!("model.layers.{i}");
1473
1474                // Layer norms
1475                tensors_data.push((
1476                    format!("{p}.input_layernorm.weight"),
1477                    make_f32(hidden, 1.0),
1478                    vec![hidden],
1479                ));
1480                tensors_data.push((
1481                    format!("{p}.post_attention_layernorm.weight"),
1482                    make_f32(hidden, 1.0),
1483                    vec![hidden],
1484                ));
1485
1486                // Attention projections
1487                tensors_data.push((
1488                    format!("{p}.self_attn.q_proj.weight"),
1489                    make_f32(hidden * hidden, 0.01),
1490                    vec![hidden, hidden],
1491                ));
1492                tensors_data.push((
1493                    format!("{p}.self_attn.k_proj.weight"),
1494                    make_f32(hidden * kv_hidden, 0.01),
1495                    vec![kv_hidden, hidden],
1496                ));
1497                tensors_data.push((
1498                    format!("{p}.self_attn.v_proj.weight"),
1499                    make_f32(hidden * kv_hidden, 0.01),
1500                    vec![kv_hidden, hidden],
1501                ));
1502                tensors_data.push((
1503                    format!("{p}.self_attn.o_proj.weight"),
1504                    make_f32(hidden * hidden, 0.01),
1505                    vec![hidden, hidden],
1506                ));
1507
1508                // MLP projections
1509                tensors_data.push((
1510                    format!("{p}.mlp.gate_proj.weight"),
1511                    make_f32(hidden * intermediate, 0.01),
1512                    vec![intermediate, hidden],
1513                ));
1514                tensors_data.push((
1515                    format!("{p}.mlp.up_proj.weight"),
1516                    make_f32(hidden * intermediate, 0.01),
1517                    vec![intermediate, hidden],
1518                ));
1519                tensors_data.push((
1520                    format!("{p}.mlp.down_proj.weight"),
1521                    make_f32(intermediate * hidden, 0.01),
1522                    vec![hidden, intermediate],
1523                ));
1524            }
1525
1526            // Build TensorViews from owned data and serialize
1527            let views: Vec<TensorView<'_>> = tensors_data
1528                .iter()
1529                .map(|(_, bytes, shape)| {
1530                    TensorView::new(Dtype::F32, shape.clone(), bytes).expect("valid tensor view")
1531                })
1532                .collect();
1533
1534            let named_views: Vec<(&str, &TensorView<'_>)> = tensors_data
1535                .iter()
1536                .zip(views.iter())
1537                .map(|((name, _, _), view)| (name.as_str(), view))
1538                .collect();
1539
1540            let file_path = dir.join("model.safetensors");
1541            let serialized =
1542                serialize(named_views, None::<std::collections::HashMap<String, String>>)
1543                    .expect("serialize safetensors");
1544            std::fs::write(&file_path, serialized).expect("write safetensors file");
1545            file_path
1546        }
1547
1548        /// Helper: create a SafeTensors file with bf16 weights (like real HF models)
1549        fn create_tiny_bf16_safetensors(dir: &std::path::Path) -> std::path::PathBuf {
1550            let config = TransformerConfig::tiny();
1551            let hidden = config.hidden_size;
1552            let kv_hidden = config.num_kv_heads * config.head_dim();
1553            let intermediate = config.intermediate_size;
1554            let vocab = config.vocab_size;
1555
1556            let mut tensors_data: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
1557
1558            // Helper to create bf16 bytes
1559            let make_bf16 = |n: usize, val: f32| -> Vec<u8> {
1560                std::iter::repeat_n(half::bf16::from_f32(val), n)
1561                    .flat_map(half::bf16::to_le_bytes)
1562                    .collect()
1563            };
1564
1565            // Embedding
1566            tensors_data.push((
1567                "model.embed_tokens.weight".to_string(),
1568                make_bf16(vocab * hidden, 0.01),
1569                vec![vocab, hidden],
1570            ));
1571
1572            // Final norm
1573            tensors_data.push((
1574                "model.norm.weight".to_string(),
1575                make_bf16(hidden, 1.0),
1576                vec![hidden],
1577            ));
1578
1579            // Per-layer weights
1580            for i in 0..config.num_hidden_layers {
1581                let p = format!("model.layers.{i}");
1582
1583                tensors_data.push((
1584                    format!("{p}.input_layernorm.weight"),
1585                    make_bf16(hidden, 1.0),
1586                    vec![hidden],
1587                ));
1588                tensors_data.push((
1589                    format!("{p}.post_attention_layernorm.weight"),
1590                    make_bf16(hidden, 1.0),
1591                    vec![hidden],
1592                ));
1593                tensors_data.push((
1594                    format!("{p}.self_attn.q_proj.weight"),
1595                    make_bf16(hidden * hidden, 0.01),
1596                    vec![hidden, hidden],
1597                ));
1598                tensors_data.push((
1599                    format!("{p}.self_attn.k_proj.weight"),
1600                    make_bf16(hidden * kv_hidden, 0.01),
1601                    vec![kv_hidden, hidden],
1602                ));
1603                tensors_data.push((
1604                    format!("{p}.self_attn.v_proj.weight"),
1605                    make_bf16(hidden * kv_hidden, 0.01),
1606                    vec![kv_hidden, hidden],
1607                ));
1608                tensors_data.push((
1609                    format!("{p}.self_attn.o_proj.weight"),
1610                    make_bf16(hidden * hidden, 0.01),
1611                    vec![hidden, hidden],
1612                ));
1613                tensors_data.push((
1614                    format!("{p}.mlp.gate_proj.weight"),
1615                    make_bf16(hidden * intermediate, 0.01),
1616                    vec![intermediate, hidden],
1617                ));
1618                tensors_data.push((
1619                    format!("{p}.mlp.up_proj.weight"),
1620                    make_bf16(hidden * intermediate, 0.01),
1621                    vec![intermediate, hidden],
1622                ));
1623                tensors_data.push((
1624                    format!("{p}.mlp.down_proj.weight"),
1625                    make_bf16(intermediate * hidden, 0.01),
1626                    vec![hidden, intermediate],
1627                ));
1628            }
1629
1630            let views: Vec<TensorView<'_>> = tensors_data
1631                .iter()
1632                .map(|(_, bytes, shape)| {
1633                    TensorView::new(Dtype::BF16, shape.clone(), bytes).expect("valid tensor view")
1634                })
1635                .collect();
1636
1637            let named_views: Vec<(&str, &TensorView<'_>)> = tensors_data
1638                .iter()
1639                .zip(views.iter())
1640                .map(|((name, _, _), view)| (name.as_str(), view))
1641                .collect();
1642
1643            let file_path = dir.join("model.safetensors");
1644            let serialized =
1645                serialize(named_views, None::<std::collections::HashMap<String, String>>)
1646                    .expect("serialize safetensors");
1647            std::fs::write(&file_path, serialized).expect("write safetensors file");
1648            file_path
1649        }
1650
1651        // -----------------------------------------------------------------
1652        // Happy path tests
1653        // -----------------------------------------------------------------
1654
1655        #[test]
1656        fn test_ssc024_from_safetensors_f32_success() {
1657            let dir = TempDir::new().expect("create temp dir");
1658            create_tiny_safetensors(dir.path());
1659            let config = TransformerConfig::tiny();
1660
1661            let result = Transformer::from_safetensors(dir.path(), &config);
1662            assert!(
1663                result.is_ok(),
1664                "from_safetensors should succeed: {}",
1665                result.as_ref().err().map_or(String::new(), std::string::ToString::to_string)
1666            );
1667
1668            let transformer = result.expect("validated above");
1669            assert_eq!(transformer.layers.len(), config.num_hidden_layers);
1670            assert!(transformer.lm_head.is_none()); // tiny config has no lm_head
1671        }
1672
1673        #[test]
1674        fn test_ssc024_from_safetensors_bf16_conversion() {
1675            let dir = TempDir::new().expect("create temp dir");
1676            create_tiny_bf16_safetensors(dir.path());
1677            let config = TransformerConfig::tiny();
1678
1679            let result = Transformer::from_safetensors(dir.path(), &config);
1680            assert!(
1681                result.is_ok(),
1682                "BF16 loading should succeed: {}",
1683                result.as_ref().err().map_or(String::new(), std::string::ToString::to_string)
1684            );
1685
1686            let transformer = result.expect("validated above");
1687            assert_eq!(transformer.layers.len(), config.num_hidden_layers);
1688
1689            // Verify forward pass produces finite output
1690            let tokens = vec![1u32, 2, 3];
1691            let logits = transformer.forward(&tokens);
1692            assert_eq!(logits.len(), 3 * config.vocab_size);
1693            assert!(
1694                logits.data().iter().all(|v| v.is_finite()),
1695                "BF16-loaded model should produce finite outputs"
1696            );
1697        }
1698
1699        #[test]
1700        fn test_ssc024_from_safetensors_single_file_path() {
1701            let dir = TempDir::new().expect("create temp dir");
1702            let file_path = create_tiny_safetensors(dir.path());
1703            let config = TransformerConfig::tiny();
1704
1705            // Pass the file path directly, not the directory
1706            let result = Transformer::from_safetensors(&file_path, &config);
1707            assert!(
1708                result.is_ok(),
1709                "Direct file path should work: {}",
1710                result.as_ref().err().map_or(String::new(), std::string::ToString::to_string)
1711            );
1712        }
1713
1714        #[test]
1715        fn test_ssc024_loaded_model_forward_produces_finite() {
1716            let dir = TempDir::new().expect("create temp dir");
1717            create_tiny_safetensors(dir.path());
1718            let config = TransformerConfig::tiny();
1719
1720            let transformer =
1721                Transformer::from_safetensors(dir.path(), &config).expect("loading should succeed");
1722
1723            // Run forward pass
1724            let tokens = vec![0u32, 5, 42, 99];
1725            let logits = transformer.forward(&tokens);
1726
1727            assert_eq!(logits.len(), tokens.len() * config.vocab_size);
1728            let data = logits.data();
1729            let nan_count = data.iter().filter(|v| v.is_nan()).count();
1730            let inf_count = data.iter().filter(|v| v.is_infinite()).count();
1731            assert_eq!(nan_count, 0, "Loaded model output must not contain NaN");
1732            assert_eq!(inf_count, 0, "Loaded model output must not contain Inf");
1733        }
1734
1735        // -----------------------------------------------------------------
1736        // Error case: no SafeTensors files
1737        // -----------------------------------------------------------------
1738
1739        #[test]
1740        fn test_ssc024_from_safetensors_no_files() {
1741            let dir = TempDir::new().expect("create temp dir");
1742            let config = TransformerConfig::tiny();
1743
1744            let result = Transformer::from_safetensors(dir.path(), &config);
1745            assert!(result.is_err());
1746            let err_msg = match result {
1747                Err(e) => e.to_string(),
1748                Ok(_) => panic!("expected error"),
1749            };
1750            assert!(
1751                err_msg.contains("No SafeTensors files"),
1752                "Error should mention missing files: {err_msg}"
1753            );
1754        }
1755
1756        // -----------------------------------------------------------------
1757        // Error case: shape mismatch
1758        // -----------------------------------------------------------------
1759
1760        #[test]
1761        fn test_ssc024_from_safetensors_wrong_embedding_shape() {
1762            let dir = TempDir::new().expect("create temp dir");
1763            let config = TransformerConfig::tiny();
1764            let hidden = config.hidden_size;
1765
1766            // Create a file with wrong embedding shape
1767            let wrong_embed_bytes: Vec<u8> =
1768                std::iter::repeat_n(0.01_f32, 42).flat_map(f32::to_le_bytes).collect();
1769
1770            // We need at least embedding + norm + 2 layers to pass validate_weights.
1771            // But the embedding shape is wrong, so validate_weight_shapes should catch it.
1772            // Actually, we need ALL required keys for validate_weights to pass first.
1773            // Let's create a full set but with wrong embedding size.
1774            let kv_hidden = config.num_kv_heads * config.head_dim();
1775            let intermediate = config.intermediate_size;
1776
1777            let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
1778
1779            let make_f32 = |n: usize, val: f32| -> Vec<u8> {
1780                std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
1781            };
1782
1783            // WRONG: embedding has 42 elements instead of vocab * hidden
1784            td.push(("model.embed_tokens.weight".to_string(), wrong_embed_bytes, vec![42]));
1785            td.push(("model.norm.weight".to_string(), make_f32(hidden, 1.0), vec![hidden]));
1786
1787            for i in 0..config.num_hidden_layers {
1788                let p = format!("model.layers.{i}");
1789                td.push((
1790                    format!("{p}.input_layernorm.weight"),
1791                    make_f32(hidden, 1.0),
1792                    vec![hidden],
1793                ));
1794                td.push((
1795                    format!("{p}.post_attention_layernorm.weight"),
1796                    make_f32(hidden, 1.0),
1797                    vec![hidden],
1798                ));
1799                td.push((
1800                    format!("{p}.self_attn.q_proj.weight"),
1801                    make_f32(hidden * hidden, 0.01),
1802                    vec![hidden, hidden],
1803                ));
1804                td.push((
1805                    format!("{p}.self_attn.k_proj.weight"),
1806                    make_f32(hidden * kv_hidden, 0.01),
1807                    vec![kv_hidden, hidden],
1808                ));
1809                td.push((
1810                    format!("{p}.self_attn.v_proj.weight"),
1811                    make_f32(hidden * kv_hidden, 0.01),
1812                    vec![kv_hidden, hidden],
1813                ));
1814                td.push((
1815                    format!("{p}.self_attn.o_proj.weight"),
1816                    make_f32(hidden * hidden, 0.01),
1817                    vec![hidden, hidden],
1818                ));
1819                td.push((
1820                    format!("{p}.mlp.gate_proj.weight"),
1821                    make_f32(hidden * intermediate, 0.01),
1822                    vec![intermediate, hidden],
1823                ));
1824                td.push((
1825                    format!("{p}.mlp.up_proj.weight"),
1826                    make_f32(hidden * intermediate, 0.01),
1827                    vec![intermediate, hidden],
1828                ));
1829                td.push((
1830                    format!("{p}.mlp.down_proj.weight"),
1831                    make_f32(intermediate * hidden, 0.01),
1832                    vec![hidden, intermediate],
1833                ));
1834            }
1835
1836            let views: Vec<TensorView<'_>> = td
1837                .iter()
1838                .map(|(_, bytes, shape)| {
1839                    TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view")
1840                })
1841                .collect();
1842            let named: Vec<(&str, &TensorView<'_>)> =
1843                td.iter().zip(views.iter()).map(|((n, _, _), v)| (n.as_str(), v)).collect();
1844
1845            let file_path = dir.path().join("model.safetensors");
1846            let serialized =
1847                serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
1848            std::fs::write(&file_path, serialized).expect("write");
1849
1850            let result = Transformer::from_safetensors(dir.path(), &config);
1851            assert!(result.is_err(), "Wrong embedding shape should fail");
1852            let err_msg = match result {
1853                Err(e) => e.to_string(),
1854                Ok(_) => panic!("expected error"),
1855            };
1856            assert!(
1857                err_msg.contains("Shape mismatch") || err_msg.contains("embed_tokens"),
1858                "Error should indicate shape issue: {err_msg}"
1859            );
1860        }
1861
1862        // -----------------------------------------------------------------
1863        // Error case: NaN in weights
1864        // -----------------------------------------------------------------
1865
1866        #[test]
1867        fn test_ssc024_from_safetensors_nan_detection() {
1868            let dir = TempDir::new().expect("create temp dir");
1869            let config = TransformerConfig::tiny();
1870            let hidden = config.hidden_size;
1871            let kv_hidden = config.num_kv_heads * config.head_dim();
1872            let intermediate = config.intermediate_size;
1873            let vocab = config.vocab_size;
1874
1875            let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
1876
1877            let make_f32 = |n: usize, val: f32| -> Vec<u8> {
1878                std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
1879            };
1880
1881            // Embedding with NaN injected
1882            let mut embed_vals: Vec<f32> = vec![0.01; vocab * hidden];
1883            embed_vals[42] = f32::NAN;
1884            let embed_bytes: Vec<u8> = embed_vals.iter().flat_map(|v| v.to_le_bytes()).collect();
1885
1886            td.push(("model.embed_tokens.weight".to_string(), embed_bytes, vec![vocab, hidden]));
1887            td.push(("model.norm.weight".to_string(), make_f32(hidden, 1.0), vec![hidden]));
1888
1889            for i in 0..config.num_hidden_layers {
1890                let p = format!("model.layers.{i}");
1891                td.push((
1892                    format!("{p}.input_layernorm.weight"),
1893                    make_f32(hidden, 1.0),
1894                    vec![hidden],
1895                ));
1896                td.push((
1897                    format!("{p}.post_attention_layernorm.weight"),
1898                    make_f32(hidden, 1.0),
1899                    vec![hidden],
1900                ));
1901                td.push((
1902                    format!("{p}.self_attn.q_proj.weight"),
1903                    make_f32(hidden * hidden, 0.01),
1904                    vec![hidden, hidden],
1905                ));
1906                td.push((
1907                    format!("{p}.self_attn.k_proj.weight"),
1908                    make_f32(hidden * kv_hidden, 0.01),
1909                    vec![kv_hidden, hidden],
1910                ));
1911                td.push((
1912                    format!("{p}.self_attn.v_proj.weight"),
1913                    make_f32(hidden * kv_hidden, 0.01),
1914                    vec![kv_hidden, hidden],
1915                ));
1916                td.push((
1917                    format!("{p}.self_attn.o_proj.weight"),
1918                    make_f32(hidden * hidden, 0.01),
1919                    vec![hidden, hidden],
1920                ));
1921                td.push((
1922                    format!("{p}.mlp.gate_proj.weight"),
1923                    make_f32(hidden * intermediate, 0.01),
1924                    vec![intermediate, hidden],
1925                ));
1926                td.push((
1927                    format!("{p}.mlp.up_proj.weight"),
1928                    make_f32(hidden * intermediate, 0.01),
1929                    vec![intermediate, hidden],
1930                ));
1931                td.push((
1932                    format!("{p}.mlp.down_proj.weight"),
1933                    make_f32(intermediate * hidden, 0.01),
1934                    vec![hidden, intermediate],
1935                ));
1936            }
1937
1938            let views: Vec<TensorView<'_>> = td
1939                .iter()
1940                .map(|(_, bytes, shape)| {
1941                    TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view")
1942                })
1943                .collect();
1944            let named: Vec<(&str, &TensorView<'_>)> =
1945                td.iter().zip(views.iter()).map(|((n, _, _), v)| (n.as_str(), v)).collect();
1946
1947            let file_path = dir.path().join("model.safetensors");
1948            let serialized =
1949                serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
1950            std::fs::write(&file_path, serialized).expect("write");
1951
1952            let result = Transformer::from_safetensors(dir.path(), &config);
1953            assert!(result.is_err(), "NaN in weights should fail");
1954            let err_msg = match result {
1955                Err(e) => e.to_string(),
1956                Ok(_) => panic!("expected error"),
1957            };
1958            assert!(err_msg.contains("NaN"), "Error should mention NaN: {err_msg}");
1959        }
1960
1961        // -----------------------------------------------------------------
1962        // Error case: Inf in weights
1963        // -----------------------------------------------------------------
1964
1965        #[test]
1966        fn test_ssc024_from_safetensors_inf_detection() {
1967            let dir = TempDir::new().expect("create temp dir");
1968            let config = TransformerConfig::tiny();
1969            let hidden = config.hidden_size;
1970            let kv_hidden = config.num_kv_heads * config.head_dim();
1971            let intermediate = config.intermediate_size;
1972            let vocab = config.vocab_size;
1973
1974            let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
1975
1976            let make_f32 = |n: usize, val: f32| -> Vec<u8> {
1977                std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
1978            };
1979
1980            // norm with Inf injected
1981            let mut norm_vals: Vec<f32> = vec![1.0; hidden];
1982            norm_vals[0] = f32::INFINITY;
1983            let norm_bytes: Vec<u8> = norm_vals.iter().flat_map(|v| v.to_le_bytes()).collect();
1984
1985            td.push((
1986                "model.embed_tokens.weight".to_string(),
1987                make_f32(vocab * hidden, 0.01),
1988                vec![vocab, hidden],
1989            ));
1990            td.push(("model.norm.weight".to_string(), norm_bytes, vec![hidden]));
1991
1992            for i in 0..config.num_hidden_layers {
1993                let p = format!("model.layers.{i}");
1994                td.push((
1995                    format!("{p}.input_layernorm.weight"),
1996                    make_f32(hidden, 1.0),
1997                    vec![hidden],
1998                ));
1999                td.push((
2000                    format!("{p}.post_attention_layernorm.weight"),
2001                    make_f32(hidden, 1.0),
2002                    vec![hidden],
2003                ));
2004                td.push((
2005                    format!("{p}.self_attn.q_proj.weight"),
2006                    make_f32(hidden * hidden, 0.01),
2007                    vec![hidden, hidden],
2008                ));
2009                td.push((
2010                    format!("{p}.self_attn.k_proj.weight"),
2011                    make_f32(hidden * kv_hidden, 0.01),
2012                    vec![kv_hidden, hidden],
2013                ));
2014                td.push((
2015                    format!("{p}.self_attn.v_proj.weight"),
2016                    make_f32(hidden * kv_hidden, 0.01),
2017                    vec![kv_hidden, hidden],
2018                ));
2019                td.push((
2020                    format!("{p}.self_attn.o_proj.weight"),
2021                    make_f32(hidden * hidden, 0.01),
2022                    vec![hidden, hidden],
2023                ));
2024                td.push((
2025                    format!("{p}.mlp.gate_proj.weight"),
2026                    make_f32(hidden * intermediate, 0.01),
2027                    vec![intermediate, hidden],
2028                ));
2029                td.push((
2030                    format!("{p}.mlp.up_proj.weight"),
2031                    make_f32(hidden * intermediate, 0.01),
2032                    vec![intermediate, hidden],
2033                ));
2034                td.push((
2035                    format!("{p}.mlp.down_proj.weight"),
2036                    make_f32(intermediate * hidden, 0.01),
2037                    vec![hidden, intermediate],
2038                ));
2039            }
2040
2041            let views: Vec<TensorView<'_>> = td
2042                .iter()
2043                .map(|(_, bytes, shape)| {
2044                    TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view")
2045                })
2046                .collect();
2047            let named: Vec<(&str, &TensorView<'_>)> =
2048                td.iter().zip(views.iter()).map(|((n, _, _), v)| (n.as_str(), v)).collect();
2049
2050            let file_path = dir.path().join("model.safetensors");
2051            let serialized =
2052                serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
2053            std::fs::write(&file_path, serialized).expect("write");
2054
2055            let result = Transformer::from_safetensors(dir.path(), &config);
2056            assert!(result.is_err(), "Inf in weights should fail");
2057            let err_msg = match result {
2058                Err(e) => e.to_string(),
2059                Ok(_) => panic!("expected error"),
2060            };
2061            assert!(err_msg.contains("Inf"), "Error should mention Inf: {err_msg}");
2062        }
2063
2064        // -----------------------------------------------------------------
2065        // Error case: missing layer weights (wrong layer count)
2066        // -----------------------------------------------------------------
2067
2068        #[test]
2069        fn test_ssc024_from_safetensors_missing_layer() {
2070            let dir = TempDir::new().expect("create temp dir");
2071            // Create a file with 2 layers of weights
2072            create_tiny_safetensors(dir.path());
2073
2074            // But try to load with config expecting 3 layers
2075            let mut config = TransformerConfig::tiny();
2076            config.num_hidden_layers = 3;
2077
2078            let result = Transformer::from_safetensors(dir.path(), &config);
2079            assert!(result.is_err(), "Missing layer 2 should fail");
2080            let err_msg = match result {
2081                Err(e) => e.to_string(),
2082                Ok(_) => panic!("expected error"),
2083            };
2084            assert!(
2085                err_msg.contains("Missing") || err_msg.contains("layers.2"),
2086                "Error should mention missing layer: {err_msg}"
2087            );
2088        }
2089
2090        // -----------------------------------------------------------------
2091        // Error case: wrong attention projection shape
2092        // -----------------------------------------------------------------
2093
2094        #[test]
2095        fn test_ssc024_from_safetensors_wrong_q_proj_shape() {
2096            let dir = TempDir::new().expect("create temp dir");
2097            let config = TransformerConfig::tiny();
2098            let hidden = config.hidden_size;
2099            let q_dim = config.q_dim();
2100            let kv_hidden = config.num_kv_heads * config.head_dim();
2101            let intermediate = config.intermediate_size;
2102            let vocab = config.vocab_size;
2103
2104            let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
2105
2106            let make_f32 = |n: usize, val: f32| -> Vec<u8> {
2107                std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
2108            };
2109
2110            td.push((
2111                "model.embed_tokens.weight".to_string(),
2112                make_f32(vocab * hidden, 0.01),
2113                vec![vocab, hidden],
2114            ));
2115            td.push(("model.norm.weight".to_string(), make_f32(hidden, 1.0), vec![hidden]));
2116
2117            for i in 0..config.num_hidden_layers {
2118                let p = format!("model.layers.{i}");
2119                td.push((
2120                    format!("{p}.input_layernorm.weight"),
2121                    make_f32(hidden, 1.0),
2122                    vec![hidden],
2123                ));
2124                td.push((
2125                    format!("{p}.post_attention_layernorm.weight"),
2126                    make_f32(hidden, 1.0),
2127                    vec![hidden],
2128                ));
2129
2130                // WRONG: q_proj has 7 elements instead of q_dim*hidden
2131                if i == 0 {
2132                    td.push((format!("{p}.self_attn.q_proj.weight"), make_f32(7, 0.01), vec![7]));
2133                } else {
2134                    td.push((
2135                        format!("{p}.self_attn.q_proj.weight"),
2136                        make_f32(q_dim * hidden, 0.01),
2137                        vec![q_dim, hidden],
2138                    ));
2139                }
2140                td.push((
2141                    format!("{p}.self_attn.k_proj.weight"),
2142                    make_f32(kv_hidden * hidden, 0.01),
2143                    vec![kv_hidden, hidden],
2144                ));
2145                td.push((
2146                    format!("{p}.self_attn.v_proj.weight"),
2147                    make_f32(kv_hidden * hidden, 0.01),
2148                    vec![kv_hidden, hidden],
2149                ));
2150                td.push((
2151                    format!("{p}.self_attn.o_proj.weight"),
2152                    make_f32(hidden * q_dim, 0.01),
2153                    vec![hidden, q_dim],
2154                ));
2155                td.push((
2156                    format!("{p}.mlp.gate_proj.weight"),
2157                    make_f32(hidden * intermediate, 0.01),
2158                    vec![intermediate, hidden],
2159                ));
2160                td.push((
2161                    format!("{p}.mlp.up_proj.weight"),
2162                    make_f32(hidden * intermediate, 0.01),
2163                    vec![intermediate, hidden],
2164                ));
2165                td.push((
2166                    format!("{p}.mlp.down_proj.weight"),
2167                    make_f32(intermediate * hidden, 0.01),
2168                    vec![hidden, intermediate],
2169                ));
2170            }
2171
2172            let views: Vec<TensorView<'_>> = td
2173                .iter()
2174                .map(|(_, bytes, shape)| {
2175                    TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view")
2176                })
2177                .collect();
2178            let named: Vec<(&str, &TensorView<'_>)> =
2179                td.iter().zip(views.iter()).map(|((n, _, _), v)| (n.as_str(), v)).collect();
2180
2181            let file_path = dir.path().join("model.safetensors");
2182            let serialized =
2183                serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
2184            std::fs::write(&file_path, serialized).expect("write");
2185
2186            let result = Transformer::from_safetensors(dir.path(), &config);
2187            assert!(result.is_err(), "Wrong q_proj shape should fail");
2188            let err_msg = match result {
2189                Err(e) => e.to_string(),
2190                Ok(_) => panic!("expected error"),
2191            };
2192            assert!(
2193                err_msg.contains("Shape mismatch") && err_msg.contains("q_proj"),
2194                "Error should mention q_proj shape mismatch: {err_msg}"
2195            );
2196        }
2197
2198        // -----------------------------------------------------------------
        // validate_weight_shapes helper, exercised independently
2200        // -----------------------------------------------------------------
2201
2202        #[test]
2203        fn test_ssc024_validate_weight_shapes_success() {
2204            let config = TransformerConfig::tiny();
2205            let hidden = config.hidden_size;
2206            let kv_hidden = config.num_kv_heads * config.head_dim();
2207            let intermediate = config.intermediate_size;
2208            let vocab = config.vocab_size;
2209
2210            let mut weights = HashMap::new();
2211            weights.insert(
2212                "model.embed_tokens.weight".to_string(),
2213                Tensor::from_vec(vec![0.1; vocab * hidden], true),
2214            );
2215            weights
2216                .insert("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; hidden], true));
2217
2218            for i in 0..config.num_hidden_layers {
2219                let p = format!("model.layers.{i}");
2220                weights.insert(
2221                    format!("{p}.input_layernorm.weight"),
2222                    Tensor::from_vec(vec![1.0; hidden], true),
2223                );
2224                weights.insert(
2225                    format!("{p}.post_attention_layernorm.weight"),
2226                    Tensor::from_vec(vec![1.0; hidden], true),
2227                );
2228                weights.insert(
2229                    format!("{p}.self_attn.q_proj.weight"),
2230                    Tensor::from_vec(vec![0.1; hidden * hidden], true),
2231                );
2232                weights.insert(
2233                    format!("{p}.self_attn.k_proj.weight"),
2234                    Tensor::from_vec(vec![0.1; hidden * kv_hidden], true),
2235                );
2236                weights.insert(
2237                    format!("{p}.self_attn.v_proj.weight"),
2238                    Tensor::from_vec(vec![0.1; hidden * kv_hidden], true),
2239                );
2240                weights.insert(
2241                    format!("{p}.self_attn.o_proj.weight"),
2242                    Tensor::from_vec(vec![0.1; hidden * hidden], true),
2243                );
2244                weights.insert(
2245                    format!("{p}.mlp.gate_proj.weight"),
2246                    Tensor::from_vec(vec![0.1; hidden * intermediate], true),
2247                );
2248                weights.insert(
2249                    format!("{p}.mlp.up_proj.weight"),
2250                    Tensor::from_vec(vec![0.1; hidden * intermediate], true),
2251                );
2252                weights.insert(
2253                    format!("{p}.mlp.down_proj.weight"),
2254                    Tensor::from_vec(vec![0.1; intermediate * hidden], true),
2255                );
2256            }
2257
2258            let result = Transformer::validate_weight_shapes(&weights, &config);
2259            assert!(
2260                result.is_ok(),
2261                "Valid shapes should pass: {}",
2262                result.as_ref().err().map_or(String::new(), std::string::ToString::to_string)
2263            );
2264        }
2265
2266        #[test]
2267        fn test_ssc024_validate_weight_shapes_wrong_norm() {
2268            let config = TransformerConfig::tiny();
2269            let hidden = config.hidden_size;
2270            let vocab = config.vocab_size;
2271
2272            let mut weights = HashMap::new();
2273            weights.insert(
2274                "model.embed_tokens.weight".to_string(),
2275                Tensor::from_vec(vec![0.1; vocab * hidden], true),
2276            );
2277            // Wrong norm size: 3 instead of hidden
2278            weights.insert("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; 3], true));
2279
2280            let result = Transformer::validate_weight_shapes(&weights, &config);
2281            assert!(result.is_err());
2282            let err_msg = match result {
2283                Err(e) => e.to_string(),
2284                Ok(()) => panic!("expected error"),
2285            };
2286            assert!(err_msg.contains("model.norm.weight"));
2287        }
2288
2289        // -----------------------------------------------------------------
        // validate_weight_values helper, exercised independently
2291        // -----------------------------------------------------------------
2292
2293        #[test]
2294        fn test_ssc024_validate_weight_values_clean() {
2295            let mut weights = HashMap::new();
2296            weights.insert("a".to_string(), Tensor::from_vec(vec![0.1, 0.2, 0.3], true));
2297            weights.insert("b".to_string(), Tensor::from_vec(vec![1.0, -1.0, 0.0], true));
2298
2299            let result = Transformer::validate_weight_values(&weights);
2300            assert!(result.is_ok());
2301        }
2302
2303        #[test]
2304        fn test_ssc024_validate_weight_values_nan() {
2305            let mut weights = HashMap::new();
2306            weights.insert("clean".to_string(), Tensor::from_vec(vec![0.1, 0.2], true));
2307            weights
2308                .insert("poisoned".to_string(), Tensor::from_vec(vec![0.1, f32::NAN, 0.3], true));
2309
2310            let result = Transformer::validate_weight_values(&weights);
2311            assert!(result.is_err());
2312            let err_msg = match result {
2313                Err(e) => e.to_string(),
2314                Ok(()) => panic!("expected error"),
2315            };
2316            assert!(err_msg.contains("NaN"));
2317            assert!(err_msg.contains("poisoned"));
2318        }
2319
2320        #[test]
2321        fn test_ssc024_validate_weight_values_inf() {
2322            let mut weights = HashMap::new();
2323            weights.insert("w".to_string(), Tensor::from_vec(vec![f32::NEG_INFINITY, 0.2], true));
2324
2325            let result = Transformer::validate_weight_values(&weights);
2326            assert!(result.is_err());
2327            let err_msg = match result {
2328                Err(e) => e.to_string(),
2329                Ok(()) => panic!("expected error"),
2330            };
2331            assert!(err_msg.contains("Inf"));
2332        }
2333
2334        /// GH-262: Qwen3-4B shape validation with q_dim != hidden_size.
2335        ///
2336        /// Uses a 1-layer config mimicking Qwen3-4B dimensions to verify
2337        /// validate_weight_shapes accepts the correct q_dim-based shapes.
2338        #[test]
2339        fn test_gh262_qwen3_4b_weight_shapes_q_dim_ne_hidden() {
2340            // Minimal Qwen3-like config: q_dim (128) != hidden_size (80)
2341            let config = TransformerConfig {
2342                hidden_size: 80,
2343                num_attention_heads: 4,
2344                num_kv_heads: 2,
2345                intermediate_size: 128,
2346                num_hidden_layers: 1,
2347                vocab_size: 256,
2348                max_position_embeddings: 512,
2349                rms_norm_eps: 1e-6,
2350                rope_theta: 10000.0,
2351                use_bias: false,
2352                head_dim_override: Some(32), // head_dim=32, so q_dim = 4*32 = 128, hidden=80
2353                architecture: crate::transformer::config::ModelArchitecture::Decoder,
2354                hf_architecture: None,
2355                hf_model_type: None,
2356                tie_word_embeddings: false,
2357            };
2358
2359            let hidden = config.hidden_size; // 80
2360            let q_dim = config.q_dim(); // 4 * 32 = 128
2361            let kv_hidden = config.num_kv_heads * config.head_dim(); // 2 * 32 = 64
2362            let intermediate = config.intermediate_size; // 128
2363            let vocab = config.vocab_size; // 256
2364
2365            // Verify q_dim != hidden (the Qwen3-4B characteristic)
2366            assert_ne!(q_dim, hidden, "test requires q_dim != hidden_size");
2367
2368            let mut weights = HashMap::new();
2369            weights.insert(
2370                "model.embed_tokens.weight".to_string(),
2371                Tensor::from_vec(vec![0.1; vocab * hidden], true),
2372            );
2373            weights
2374                .insert("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; hidden], true));
2375
2376            let p = "model.layers.0";
2377            weights.insert(
2378                format!("{p}.input_layernorm.weight"),
2379                Tensor::from_vec(vec![1.0; hidden], true),
2380            );
2381            weights.insert(
2382                format!("{p}.post_attention_layernorm.weight"),
2383                Tensor::from_vec(vec![1.0; hidden], true),
2384            );
2385            // Q: [q_dim, hidden] — NOT [hidden, hidden]
2386            weights.insert(
2387                format!("{p}.self_attn.q_proj.weight"),
2388                Tensor::from_vec(vec![0.1; q_dim * hidden], true),
2389            );
2390            // K: [kv_hidden, hidden]
2391            weights.insert(
2392                format!("{p}.self_attn.k_proj.weight"),
2393                Tensor::from_vec(vec![0.1; kv_hidden * hidden], true),
2394            );
2395            // V: [kv_hidden, hidden]
2396            weights.insert(
2397                format!("{p}.self_attn.v_proj.weight"),
2398                Tensor::from_vec(vec![0.1; kv_hidden * hidden], true),
2399            );
2400            // O: [hidden, q_dim] — NOT [hidden, hidden]
2401            weights.insert(
2402                format!("{p}.self_attn.o_proj.weight"),
2403                Tensor::from_vec(vec![0.1; hidden * q_dim], true),
2404            );
2405            weights.insert(
2406                format!("{p}.mlp.gate_proj.weight"),
2407                Tensor::from_vec(vec![0.1; hidden * intermediate], true),
2408            );
2409            weights.insert(
2410                format!("{p}.mlp.up_proj.weight"),
2411                Tensor::from_vec(vec![0.1; hidden * intermediate], true),
2412            );
2413            weights.insert(
2414                format!("{p}.mlp.down_proj.weight"),
2415                Tensor::from_vec(vec![0.1; intermediate * hidden], true),
2416            );
2417
2418            // Should pass: shapes use q_dim for Q/O projections
2419            let result = Transformer::validate_weight_shapes(&weights, &config);
2420            assert!(
2421                result.is_ok(),
2422                "Qwen3-like shapes (q_dim={q_dim} != hidden={hidden}) should validate: {:?}",
2423                result.err()
2424            );
2425
2426            // Should also construct successfully via from_params
2427            let model = Transformer::from_params(&config, &weights);
2428            assert!(model.is_some(), "Qwen3-like model with q_dim != hidden should construct");
2429        }
2430
2431        /// GH-262: Using hidden_size instead of q_dim for q_proj must fail.
2432        #[test]
2433        fn test_gh262_wrong_q_proj_size_hidden_instead_of_q_dim() {
2434            let config = TransformerConfig {
2435                hidden_size: 80,
2436                num_attention_heads: 4,
2437                num_kv_heads: 2,
2438                intermediate_size: 128,
2439                num_hidden_layers: 1,
2440                vocab_size: 256,
2441                max_position_embeddings: 512,
2442                rms_norm_eps: 1e-6,
2443                rope_theta: 10000.0,
2444                use_bias: false,
2445                head_dim_override: Some(32), // q_dim=128, hidden=80
2446                architecture: crate::transformer::config::ModelArchitecture::Decoder,
2447                hf_architecture: None,
2448                hf_model_type: None,
2449                tie_word_embeddings: false,
2450            };
2451
2452            let hidden = config.hidden_size; // 80
2453            let kv_hidden = config.num_kv_heads * config.head_dim(); // 64
2454            let intermediate = config.intermediate_size;
2455            let vocab = config.vocab_size;
2456
2457            let mut weights = HashMap::new();
2458            weights.insert(
2459                "model.embed_tokens.weight".to_string(),
2460                Tensor::from_vec(vec![0.1; vocab * hidden], true),
2461            );
2462            weights
2463                .insert("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; hidden], true));
2464
2465            let p = "model.layers.0";
2466            weights.insert(
2467                format!("{p}.input_layernorm.weight"),
2468                Tensor::from_vec(vec![1.0; hidden], true),
2469            );
2470            weights.insert(
2471                format!("{p}.post_attention_layernorm.weight"),
2472                Tensor::from_vec(vec![1.0; hidden], true),
2473            );
2474            // BUG: q_proj uses hidden*hidden (6400) instead of q_dim*hidden (10240)
2475            weights.insert(
2476                format!("{p}.self_attn.q_proj.weight"),
2477                Tensor::from_vec(vec![0.1; hidden * hidden], true),
2478            );
2479            weights.insert(
2480                format!("{p}.self_attn.k_proj.weight"),
2481                Tensor::from_vec(vec![0.1; kv_hidden * hidden], true),
2482            );
2483            weights.insert(
2484                format!("{p}.self_attn.v_proj.weight"),
2485                Tensor::from_vec(vec![0.1; kv_hidden * hidden], true),
2486            );
2487            weights.insert(
2488                format!("{p}.self_attn.o_proj.weight"),
2489                Tensor::from_vec(vec![0.1; hidden * hidden], true),
2490            );
2491            weights.insert(
2492                format!("{p}.mlp.gate_proj.weight"),
2493                Tensor::from_vec(vec![0.1; hidden * intermediate], true),
2494            );
2495            weights.insert(
2496                format!("{p}.mlp.up_proj.weight"),
2497                Tensor::from_vec(vec![0.1; hidden * intermediate], true),
2498            );
2499            weights.insert(
2500                format!("{p}.mlp.down_proj.weight"),
2501                Tensor::from_vec(vec![0.1; intermediate * hidden], true),
2502            );
2503
2504            // Must fail: q_proj has wrong size
2505            let result = Transformer::validate_weight_shapes(&weights, &config);
2506            assert!(result.is_err(), "hidden*hidden q_proj should fail when q_dim != hidden");
2507            let err_msg = result.err().map(|e| e.to_string()).unwrap_or_default();
2508            assert!(
2509                err_msg.contains("q_proj") && err_msg.contains("Shape mismatch"),
2510                "Error should mention q_proj shape mismatch, got: {err_msg}"
2511            );
2512        }
2513
2514        // -----------------------------------------------------------------
2515        // Name mapping integration: Qwen2 bias tensors are preserved
2516        // -----------------------------------------------------------------
2517
2518        #[test]
2519        fn test_ssc024_from_safetensors_with_extra_bias_tensors() {
2520            // Qwen2 models have bias tensors that are loaded alongside weights.
2521            // from_params ignores them (doesn't look for bias keys), but they
2522            // should not cause errors.
2523            let dir = TempDir::new().expect("create temp dir");
2524            let config = TransformerConfig::tiny();
2525            let hidden = config.hidden_size;
2526            let kv_hidden = config.num_kv_heads * config.head_dim();
2527            let intermediate = config.intermediate_size;
2528            let vocab = config.vocab_size;
2529
2530            let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
2531
2532            let make_f32 = |n: usize, val: f32| -> Vec<u8> {
2533                std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
2534            };
2535
2536            td.push((
2537                "model.embed_tokens.weight".to_string(),
2538                make_f32(vocab * hidden, 0.01),
2539                vec![vocab, hidden],
2540            ));
2541            td.push(("model.norm.weight".to_string(), make_f32(hidden, 1.0), vec![hidden]));
2542
2543            for i in 0..config.num_hidden_layers {
2544                let p = format!("model.layers.{i}");
2545                td.push((
2546                    format!("{p}.input_layernorm.weight"),
2547                    make_f32(hidden, 1.0),
2548                    vec![hidden],
2549                ));
2550                td.push((
2551                    format!("{p}.post_attention_layernorm.weight"),
2552                    make_f32(hidden, 1.0),
2553                    vec![hidden],
2554                ));
2555                td.push((
2556                    format!("{p}.self_attn.q_proj.weight"),
2557                    make_f32(hidden * hidden, 0.01),
2558                    vec![hidden, hidden],
2559                ));
2560                td.push((
2561                    format!("{p}.self_attn.k_proj.weight"),
2562                    make_f32(hidden * kv_hidden, 0.01),
2563                    vec![kv_hidden, hidden],
2564                ));
2565                td.push((
2566                    format!("{p}.self_attn.v_proj.weight"),
2567                    make_f32(hidden * kv_hidden, 0.01),
2568                    vec![kv_hidden, hidden],
2569                ));
2570                td.push((
2571                    format!("{p}.self_attn.o_proj.weight"),
2572                    make_f32(hidden * hidden, 0.01),
2573                    vec![hidden, hidden],
2574                ));
2575                td.push((
2576                    format!("{p}.mlp.gate_proj.weight"),
2577                    make_f32(hidden * intermediate, 0.01),
2578                    vec![intermediate, hidden],
2579                ));
2580                td.push((
2581                    format!("{p}.mlp.up_proj.weight"),
2582                    make_f32(hidden * intermediate, 0.01),
2583                    vec![intermediate, hidden],
2584                ));
2585                td.push((
2586                    format!("{p}.mlp.down_proj.weight"),
2587                    make_f32(intermediate * hidden, 0.01),
2588                    vec![hidden, intermediate],
2589                ));
2590
2591                // Qwen2-style bias tensors
2592                td.push((
2593                    format!("{p}.self_attn.q_proj.bias"),
2594                    make_f32(hidden, 0.0),
2595                    vec![hidden],
2596                ));
2597                td.push((
2598                    format!("{p}.self_attn.k_proj.bias"),
2599                    make_f32(kv_hidden, 0.0),
2600                    vec![kv_hidden],
2601                ));
2602                td.push((
2603                    format!("{p}.self_attn.v_proj.bias"),
2604                    make_f32(kv_hidden, 0.0),
2605                    vec![kv_hidden],
2606                ));
2607            }
2608
2609            let views: Vec<TensorView<'_>> = td
2610                .iter()
2611                .map(|(_, bytes, shape)| {
2612                    TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view")
2613                })
2614                .collect();
2615            let named: Vec<(&str, &TensorView<'_>)> =
2616                td.iter().zip(views.iter()).map(|((n, _, _), v)| (n.as_str(), v)).collect();
2617
2618            let file_path = dir.path().join("model.safetensors");
2619            let serialized =
2620                serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
2621            std::fs::write(&file_path, serialized).expect("write");
2622
2623            // Should succeed even with extra bias tensors
2624            let result = Transformer::from_safetensors(dir.path(), &config);
2625            assert!(
2626                result.is_ok(),
2627                "Extra bias tensors should not cause failure: {}",
2628                result.as_ref().err().map_or(String::new(), std::string::ToString::to_string)
2629            );
2630        }
2631    }
2632}