use crate::autograd::matmul_nt;
use crate::error::{Error, Result};
use crate::Tensor;
use provable_contracts_macros::{ensures, requires};
use std::collections::HashMap;
use std::path::Path;
use super::block::TransformerBlock;
use super::config::TransformerConfig;
use super::embedding::Embedding;
use super::norm::RMSNorm;
use super::weights::{load_safetensors_weights, validate_weights, Architecture};
/// Decoder-only transformer language model: token embedding, a stack of
/// transformer blocks, a final RMSNorm, and an optional output projection.
pub struct Transformer {
    /// Model hyperparameters (layer count, hidden size, vocab size, ...).
    pub config: TransformerConfig,
    /// Token-id -> hidden-state embedding table.
    pub embed_tokens: Embedding,
    /// Decoder layers, applied in order during the forward pass.
    pub layers: Vec<TransformerBlock>,
    /// Final normalization applied after the last layer.
    pub norm: RMSNorm,
    /// Output projection (`hidden_size * vocab_size` elements). When `None`,
    /// the embedding weights are reused as the projection (tied weights).
    pub lm_head: Option<Tensor>,
}
impl Transformer {
/// Builds a transformer with freshly initialized weights and tied embeddings
/// (`lm_head` starts as `None`, so `forward` projects through the embedding
/// matrix).
pub fn new(config: &TransformerConfig) -> Self {
    let mut blocks = Vec::with_capacity(config.num_hidden_layers);
    for idx in 0..config.num_hidden_layers {
        blocks.push(TransformerBlock::new(config, idx));
    }
    Self {
        embed_tokens: Embedding::new(config.vocab_size, config.hidden_size),
        norm: RMSNorm::new(config.hidden_size, config.rms_norm_eps),
        layers: blocks,
        lm_head: None,
        config: config.clone(),
    }
}
/// Assembles a transformer from a flat name -> tensor map (HF naming).
///
/// Returns `None` when any required tensor is missing, or when an optional
/// `lm_head.weight` is present but has the wrong element count (PMAT-329).
pub fn from_params(
    config: &TransformerConfig,
    params: &HashMap<String, Tensor>,
) -> Option<Self> {
    let embed_tokens = Embedding::from_params(
        params,
        "model.embed_tokens.weight",
        config.vocab_size,
        config.hidden_size,
    )?;
    // Any layer failing to load aborts the whole construction.
    let layers = (0..config.num_hidden_layers)
        .map(|i| TransformerBlock::from_params(config, params, i))
        .collect::<Option<Vec<TransformerBlock>>>()?;
    let norm =
        RMSNorm::from_params(params, "model.norm", config.rms_norm_eps, config.hidden_size)?;
    // lm_head is optional (tied embeddings when absent), but when present its
    // element count must match hidden_size * vocab_size.
    let lm_head = match params.get("lm_head.weight") {
        None => None,
        Some(tensor) => {
            let expected = config.hidden_size * config.vocab_size;
            if tensor.len() != expected {
                eprintln!(
                    "[PMAT-329] lm_head.weight: shape mismatch — got {} elements, expected {expected} ({hidden}x{vocab})",
                    tensor.len(),
                    hidden = config.hidden_size,
                    vocab = config.vocab_size,
                );
                return None;
            }
            Some(tensor.clone())
        }
    };
    Some(Self { config: config.clone(), embed_tokens, layers, norm, lm_head })
}
/// Loads a transformer from a `.safetensors` file on disk.
///
/// Weight names are auto-detected (`Architecture::Auto`); tensors are
/// validated for presence, shape, and finiteness before construction.
///
/// # Errors
/// Returns an error when the file cannot be read or when any tensor is
/// missing, mis-shaped, or non-finite; also if final assembly fails.
pub fn from_safetensors(
    model_path: impl AsRef<Path>,
    config: &TransformerConfig,
) -> Result<Self> {
    let model_path = model_path.as_ref();
    let weights = load_safetensors_weights(model_path, Architecture::Auto)?;
    validate_weights(&weights, config.num_hidden_layers)?;
    Self::validate_weight_shapes(&weights, config)?;
    Self::validate_weight_values(&weights)?;
    // After the validations above, from_params should be infallible; a None
    // here indicates an internal inconsistency.
    Self::from_params(config, &weights).ok_or_else(|| {
        Error::ConfigError(
            "Failed to construct Transformer from loaded weights \
            (internal from_params returned None after validation passed)"
                .into(),
        )
    })
}
/// Loads a transformer from an APR archive.
///
/// APR files may carry GGUF-convention tensor names; these are detected via
/// the `token_embd.weight` marker and remapped to the HF naming that
/// `from_params` expects (PMAT-489). All tensors are read as f32.
///
/// # Errors
/// Returns an error when the file cannot be opened, a tensor cannot be read,
/// or the collected weights fail validation/assembly.
pub fn from_apr(apr_path: impl AsRef<Path>, config: &TransformerConfig) -> Result<Self> {
    use aprender::serialization::apr::AprReader;
    let apr_path = apr_path.as_ref();
    let reader = AprReader::open(apr_path).map_err(|e| {
        Error::ConfigError(format!("Failed to open APR file '{}': {e}", apr_path.display()))
    })?;
    // Heuristic: GGUF-exported archives name the embedding "token_embd.weight".
    let is_gguf_names = reader.tensors.iter().any(|t| t.name == "token_embd.weight");
    if is_gguf_names {
        eprintln!(
            "[PMAT-489] Detected GGUF tensor names in APR file, mapping to HF convention"
        );
    }
    let mut weights = HashMap::new();
    for desc in &reader.tensors {
        let data = reader.read_tensor_as_f32(&desc.name).map_err(|e| {
            Error::ConfigError(format!("Failed to read tensor '{}': {e}", desc.name))
        })?;
        // Remap GGUF names to the HF convention; HF names pass through as-is.
        let mapped_name = if is_gguf_names {
            super::weights::mapping::map_weight_name(
                &desc.name,
                super::weights::Architecture::Gguf,
            )
        } else {
            desc.name.clone()
        };
        weights.insert(mapped_name, Tensor::from_vec(data, false));
    }
    validate_weights(&weights, config.num_hidden_layers)?;
    Self::validate_weight_shapes(&weights, config)?;
    Self::validate_weight_values(&weights)?;
    // Validation passed, so a None here indicates an internal inconsistency.
    Self::from_params(config, &weights).ok_or_else(|| {
        Error::ConfigError(
            "Failed to construct Transformer from APR weights \
            (from_params returned None after validation passed)"
                .into(),
        )
    })
}
fn validate_weight_shapes(
weights: &HashMap<String, Tensor>,
config: &TransformerConfig,
) -> Result<()> {
let hidden = config.hidden_size;
let q_dim = config.q_dim();
let kv_hidden = config.num_kv_heads * config.head_dim();
let intermediate = config.intermediate_size;
let vocab = config.vocab_size;
let check = |name: &str, expected: usize| -> Result<()> {
if let Some(tensor) = weights.get(name) {
if tensor.len() != expected {
return Err(Error::ConfigError(format!(
"Shape mismatch for '{name}': expected {expected} elements, got {}",
tensor.len()
)));
}
}
Ok(())
};
check("model.embed_tokens.weight", vocab * hidden)?;
check("model.norm.weight", hidden)?;
if weights.contains_key("lm_head.weight") {
check("lm_head.weight", vocab * hidden)?;
}
for i in 0..config.num_hidden_layers {
let p = format!("model.layers.{i}");
check(&format!("{p}.input_layernorm.weight"), hidden)?;
check(&format!("{p}.post_attention_layernorm.weight"), hidden)?;
check(&format!("{p}.self_attn.q_proj.weight"), q_dim * hidden)?;
check(&format!("{p}.self_attn.k_proj.weight"), kv_hidden * hidden)?;
check(&format!("{p}.self_attn.v_proj.weight"), kv_hidden * hidden)?;
check(&format!("{p}.self_attn.o_proj.weight"), hidden * q_dim)?;
check(&format!("{p}.self_attn.q_proj.bias"), q_dim)?;
check(&format!("{p}.self_attn.k_proj.bias"), kv_hidden)?;
check(&format!("{p}.self_attn.v_proj.bias"), kv_hidden)?;
check(&format!("{p}.mlp.gate_proj.weight"), hidden * intermediate)?;
check(&format!("{p}.mlp.up_proj.weight"), hidden * intermediate)?;
check(&format!("{p}.mlp.down_proj.weight"), intermediate * hidden)?;
}
Ok(())
}
/// Rejects any weight tensor containing NaN or infinite values, reporting
/// the first offending element's name and index.
fn validate_weight_values(weights: &HashMap<String, Tensor>) -> Result<()> {
    for (name, tensor) in weights {
        let data = tensor.data();
        // Scan for the first non-finite element; the message distinguishes
        // NaN from Inf exactly as the element-by-element check would.
        if let Some((i, &val)) = data.iter().enumerate().find(|(_, v)| !v.is_finite()) {
            if val.is_nan() {
                return Err(Error::ConfigError(format!(
                    "NaN detected in weight '{name}' at index {i}"
                )));
            }
            return Err(Error::ConfigError(format!(
                "Inf detected in weight '{name}' at index {i}"
            )));
        }
    }
    Ok(())
}
/// Full forward pass: token ids -> logits.
///
/// Returns `token_ids.len() * vocab_size` logits (row-major, one row per
/// input position). Uses `lm_head` when present; otherwise the output
/// projection is tied to the embedding matrix.
#[requires(!token_ids.is_empty())]
#[ensures(ret.len() == token_ids.len() * self.config.vocab_size)]
pub fn forward(&self, token_ids: &[u32]) -> Tensor {
    // Runtime contract check before the embedding lookup.
    contract_pre_embedding_lookup!(token_ids);
    let seq_len = token_ids.len();
    let hidden_size = self.config.hidden_size;
    let mut hidden = self.embed_tokens.forward(token_ids);
    for layer in &self.layers {
        hidden = layer.forward(&hidden, seq_len);
    }
    let normalized = self.norm.forward_batched(&hidden, seq_len, hidden_size);
    // Tied weights: fall back to the embedding matrix when lm_head is None.
    let lm_weight = self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight);
    // matmul_nt: [seq, hidden] x [vocab, hidden]^T -> [seq, vocab].
    matmul_nt(&normalized, lm_weight, seq_len, hidden_size, self.config.vocab_size)
}
/// Forward pass that stops before the output projection, returning the final
/// normalized hidden states (`seq_len * hidden_size` elements).
#[requires(!token_ids.is_empty())]
#[ensures(ret.len() == token_ids.len() * self.config.hidden_size)]
pub fn forward_hidden(&self, token_ids: &[u32]) -> Tensor {
    contract_pre_embedding_lookup!(token_ids);
    let seq_len = token_ids.len();
    // Thread the hidden state through every decoder block in order.
    let states = self
        .layers
        .iter()
        .fold(self.embed_tokens.forward(token_ids), |h, layer| layer.forward(&h, seq_len));
    self.norm.forward_batched(&states, seq_len, self.config.hidden_size)
}
/// Forward pass to final hidden states with LoRA adapters applied to the
/// attention q/v projections.
///
/// `lora_layers` is expected to hold two adapters per transformer block,
/// interleaved as `[q0, v0, q1, v1, ...]`; blocks whose adapter pair is not
/// present fall back to the plain (non-LoRA) attention path.
///
/// NOTE(review): only the q adapter's rank/scale are forwarded — assumes the
/// q and v adapters share rank and scale; confirm against LoRALayer usage.
pub fn forward_hidden_with_lora(
    &self,
    token_ids: &[u32],
    lora_layers: &[crate::lora::LoRALayer],
) -> Tensor {
    contract_pre_embedding_lookup!(token_ids);
    let seq_len = token_ids.len();
    let hidden_size = self.config.hidden_size;
    let mut hidden = self.embed_tokens.forward(token_ids);
    for (layer_idx, layer) in self.layers.iter().enumerate() {
        let norm1 = layer.input_norm.forward_batched(&hidden, seq_len, hidden_size);
        // Adapter pair for this block: q at 2*idx, v at 2*idx + 1.
        let q_idx = layer_idx * 2;
        let v_idx = layer_idx * 2 + 1;
        let attn_out = if v_idx < lora_layers.len() {
            layer.self_attn.forward_with_lora(
                &norm1,
                seq_len,
                lora_layers[q_idx].lora_a(),
                lora_layers[q_idx].lora_b(),
                lora_layers[v_idx].lora_a(),
                lora_layers[v_idx].lora_b(),
                lora_layers[q_idx].rank(),
                lora_layers[q_idx].scale(),
            )
        } else {
            layer.self_attn.forward(&norm1, seq_len)
        };
        // Pre-norm residual structure: x + Attn(Norm(x)), then x + FFN(Norm(x)).
        let residual = crate::autograd::add(&hidden, &attn_out);
        let norm2 = layer.post_attn_norm.forward_batched(&residual, seq_len, hidden_size);
        let ffn_out = layer.ffn.forward(&norm2, seq_len);
        hidden = crate::autograd::add(&residual, &ffn_out);
    }
    self.norm.forward_batched(&hidden, seq_len, hidden_size)
}
pub fn forward_with_lora(
&self,
token_ids: &[u32],
lora_layers: &[crate::lora::LoRALayer],
) -> Tensor {
contract_pre_embedding_lookup!(token_ids);
let seq_len = token_ids.len();
let hidden_size = self.config.hidden_size;
let hidden = self.forward_hidden_with_lora(token_ids, lora_layers);
let lm_weight = self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight);
matmul_nt(&hidden, lm_weight, seq_len, hidden_size, self.config.vocab_size)
}
/// Convenience wrapper: runs `forward` and returns only the logits of the
/// final position (`vocab_size` elements) — the row used for sampling.
pub fn forward_last(&self, token_ids: &[u32]) -> Tensor {
    contract_pre_embedding_lookup!(token_ids);
    let logits = self.forward(token_ids);
    let vocab_size = self.config.vocab_size;
    // Slice the last row out of the [seq_len, vocab_size] logit matrix.
    let start = (token_ids.len() - 1) * vocab_size;
    let all = logits.data();
    let tail = &all.as_slice().expect("logits must be contiguous")[start..start + vocab_size];
    Tensor::from_vec(tail.to_vec(), logits.requires_grad())
}
pub fn parameters(&self) -> Vec<&Tensor> {
let mut params = vec![&self.embed_tokens.weight, &self.norm.weight];
for layer in &self.layers {
params.extend(layer.parameters());
}
if let Some(lm_head) = &self.lm_head {
params.push(lm_head);
}
params
}
/// Mutable counterpart of [`Transformer::parameters`], in the same order;
/// used by optimizers to apply gradient updates in place.
pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
    let mut params: Vec<&mut Tensor> =
        vec![&mut self.embed_tokens.weight, &mut self.norm.weight];
    for block in &mut self.layers {
        params.extend(block.parameters_mut());
    }
    // Option<&mut Tensor> iterates zero-or-one items.
    params.extend(self.lm_head.as_mut());
    params
}
/// Read-only access to the model's configuration.
pub fn config(&self) -> &TransformerConfig {
    &self.config
}
/// Returns a copy of the embedding row for `token_id` (`hidden_size` floats).
///
/// # Panics
/// Panics if `token_id` is out of vocabulary range (slice out of bounds) or
/// if the embedding storage is not contiguous.
pub fn embed_token(&self, token_id: u32) -> Vec<f32> {
    let w = self.embed_tokens.weight.data();
    let data = w.as_slice().expect("contiguous embedding");
    let h = self.config.hidden_size;
    // Row-major table: row `token_id` starts at token_id * hidden_size.
    let offset = (token_id as usize) * h;
    data[offset..offset + h].to_vec()
}
/// Final RMSNorm gain weights as a flat slice.
///
/// # Panics
/// Panics if the norm weight storage is not contiguous.
pub fn output_norm_weight_slice(&self) -> &[f32] {
    self.norm.weight.data().as_slice().expect("contiguous norm weight")
}
/// Output projection weights as a flat slice (tied embedding weights when no
/// separate lm_head was loaded).
///
/// # Panics
/// Panics if the weight storage is not contiguous.
pub fn lm_head_weight_slice(&self) -> &[f32] {
    self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight)
        .data()
        .as_slice()
        .expect("contiguous lm_head")
}
/// The tensor used as the output projection: `lm_head` when loaded,
/// otherwise the embedding matrix (tied weights).
pub fn lm_head_weight(&self) -> &Tensor {
    self.lm_head.as_ref().unwrap_or(&self.embed_tokens.weight)
}
pub fn named_parameters(&self) -> Vec<(String, &Tensor)> {
let mut params = vec![
("model.embed_tokens.weight".to_string(), &self.embed_tokens.weight),
("model.norm.weight".to_string(), &self.norm.weight),
];
for layer in &self.layers {
params.extend(layer.named_parameters());
}
if let Some(ref lm_head) = self.lm_head {
params.push(("lm_head.weight".to_string(), lm_head));
}
params
}
/// Replaces a single tensor addressed by its HF-style name.
///
/// Returns `true` when the name was recognized and the value stored; `false`
/// leaves the model unchanged. Installing "lm_head.weight" on a tied model
/// unties it.
pub fn set_named_parameter(&mut self, name: &str, value: Tensor) -> bool {
    match name {
        "model.embed_tokens.weight" => {
            self.embed_tokens.weight = value;
            true
        }
        "model.norm.weight" => {
            self.norm.weight = value;
            true
        }
        "lm_head.weight" => {
            self.lm_head = Some(value);
            true
        }
        // "model.layers.<idx>.<suffix>" is delegated to the matching block.
        other => {
            let rest = match other.strip_prefix("model.layers.") {
                Some(rest) => rest,
                None => return false,
            };
            let (idx, suffix) = match rest.split_once('.') {
                Some(parts) => parts,
                None => return false,
            };
            match idx.parse::<usize>().ok().and_then(|i| self.layers.get_mut(i)) {
                Some(layer) => layer.set_named_parameter(suffix, value),
                None => false,
            }
        }
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
// Forward on a tiny model: one logit row per input position.
#[test]
fn test_transformer_tiny_forward() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    let input: Vec<u32> = vec![1, 2, 3];
    assert_eq!(model.forward(&input).len(), 3 * cfg.vocab_size);
}
// forward_last keeps only the final position's logits.
#[test]
fn test_transformer_tiny_forward_last() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    let last = model.forward_last(&[1, 2, 3]);
    assert_eq!(last.len(), cfg.vocab_size);
}
// Tied model: 2 global tensors (embeddings + final norm) plus per-layer
// weights must total exactly 20 for the tiny config.
#[test]
fn test_transformer_parameters() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    let count = model.parameters().len();
    assert_eq!(count, 20);
}
// config() must round-trip the configuration the model was built with.
#[test]
fn test_transformer_config_accessor() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    let got = model.config();
    assert_eq!(got.hidden_size, cfg.hidden_size);
    assert_eq!(got.vocab_size, cfg.vocab_size);
}
// A single-token sequence yields exactly one logit row.
#[test]
fn test_transformer_single_token() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    let logits = model.forward(&[42]);
    assert_eq!(logits.len(), cfg.vocab_size);
}
// Freshly initialized weights must never produce NaN/Inf logits.
#[test]
fn test_output_finite_values() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    let logits = model.forward(&[1, 2, 3, 4, 5]);
    for &v in logits.data().iter() {
        assert!(v.is_finite());
    }
}
// Fresh models have no separate lm_head; forward must still work by
// projecting through the tied embedding matrix.
#[test]
fn test_transformer_empty_lm_head_uses_tied_weights() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    assert!(model.lm_head.is_none());
    assert_eq!(model.forward(&[1, 2]).len(), 2 * cfg.vocab_size);
}
// An empty weight map cannot produce a model.
#[test]
fn test_from_params_returns_none_on_missing() {
    let cfg = TransformerConfig::tiny();
    let empty: HashMap<String, Tensor> = HashMap::new();
    assert!(Transformer::from_params(&cfg, &empty).is_none());
}
// A complete parameter map including an explicit lm_head must build a model
// whose lm_head is populated.
#[test]
fn test_transformer_from_params_with_lm_head() {
    let cfg = TransformerConfig::tiny();
    let h = cfg.hidden_size;
    let v = cfg.vocab_size;
    let kv = cfg.num_kv_heads * cfg.head_dim();
    let inter = cfg.intermediate_size;
    // Shorthand: constant-filled trainable tensor of n elements.
    let t = |n: usize, val: f32| Tensor::from_vec(vec![val; n], true);
    let mut params = HashMap::new();
    params.insert("model.embed_tokens.weight".to_string(), t(v * h, 0.1));
    for i in 0..cfg.num_hidden_layers {
        let p = format!("model.layers.{i}");
        params.insert(format!("{p}.input_layernorm.weight"), t(h, 1.0));
        params.insert(format!("{p}.self_attn.q_proj.weight"), t(h * h, 0.1));
        params.insert(format!("{p}.self_attn.k_proj.weight"), t(h * kv, 0.1));
        params.insert(format!("{p}.self_attn.v_proj.weight"), t(h * kv, 0.1));
        params.insert(format!("{p}.self_attn.o_proj.weight"), t(h * h, 0.1));
        params.insert(format!("{p}.post_attention_layernorm.weight"), t(h, 1.0));
        params.insert(format!("{p}.mlp.gate_proj.weight"), t(h * inter, 0.1));
        params.insert(format!("{p}.mlp.up_proj.weight"), t(h * inter, 0.1));
        params.insert(format!("{p}.mlp.down_proj.weight"), t(inter * h, 0.1));
    }
    params.insert("model.norm.weight".to_string(), t(h, 1.0));
    params.insert("lm_head.weight".to_string(), t(h * v, 0.1));
    let model = Transformer::from_params(&cfg, &params).expect("operation should succeed");
    assert!(model.lm_head.is_some());
    assert_eq!(model.layers.len(), cfg.num_hidden_layers);
}
// A complete parameter map WITHOUT lm_head must still build (tied weights).
#[test]
fn test_transformer_from_params_without_lm_head() {
    let cfg = TransformerConfig::tiny();
    let h = cfg.hidden_size;
    let v = cfg.vocab_size;
    let kv = cfg.num_kv_heads * cfg.head_dim();
    let inter = cfg.intermediate_size;
    // Shorthand: constant-filled trainable tensor of n elements.
    let t = |n: usize, val: f32| Tensor::from_vec(vec![val; n], true);
    let mut params = HashMap::new();
    params.insert("model.embed_tokens.weight".to_string(), t(v * h, 0.1));
    for i in 0..cfg.num_hidden_layers {
        let p = format!("model.layers.{i}");
        params.insert(format!("{p}.input_layernorm.weight"), t(h, 1.0));
        params.insert(format!("{p}.self_attn.q_proj.weight"), t(h * h, 0.1));
        params.insert(format!("{p}.self_attn.k_proj.weight"), t(h * kv, 0.1));
        params.insert(format!("{p}.self_attn.v_proj.weight"), t(h * kv, 0.1));
        params.insert(format!("{p}.self_attn.o_proj.weight"), t(h * h, 0.1));
        params.insert(format!("{p}.post_attention_layernorm.weight"), t(h, 1.0));
        params.insert(format!("{p}.mlp.gate_proj.weight"), t(h * inter, 0.1));
        params.insert(format!("{p}.mlp.up_proj.weight"), t(h * inter, 0.1));
        params.insert(format!("{p}.mlp.down_proj.weight"), t(inter * h, 0.1));
    }
    params.insert("model.norm.weight".to_string(), t(h, 1.0));
    let model = Transformer::from_params(&cfg, &params).expect("operation should succeed");
    assert!(model.lm_head.is_none());
}
// Attaching a separate lm_head adds exactly one parameter tensor (20 -> 21).
#[test]
fn test_transformer_parameters_with_lm_head() {
    let cfg = TransformerConfig::tiny();
    let mut model = Transformer::new(&cfg);
    let w = vec![0.1; cfg.hidden_size * cfg.vocab_size];
    model.lm_head = Some(Tensor::from_vec(w, true));
    assert_eq!(model.parameters().len(), 21);
}
// With an explicit lm_head installed, forward must still emit finite logits
// of the expected shape.
#[test]
fn test_transformer_forward_with_lm_head() {
    let cfg = TransformerConfig::tiny();
    let mut model = Transformer::new(&cfg);
    let w = vec![0.1; cfg.hidden_size * cfg.vocab_size];
    model.lm_head = Some(Tensor::from_vec(w, true));
    let logits = model.forward(&[1, 2, 3]);
    assert_eq!(logits.len(), 3 * cfg.vocab_size);
    assert!(logits.data().iter().all(|&x| x.is_finite()));
}
// PMAT-329 regression: an lm_head with the wrong element count must cause
// from_params to reject the whole parameter map.
#[test]
fn falsify_l1e_from_params_rejects_wrong_shape_lm_head() {
    let cfg = TransformerConfig::tiny();
    let h = cfg.hidden_size;
    let v = cfg.vocab_size;
    let kv = cfg.num_kv_heads * cfg.head_dim();
    let inter = cfg.intermediate_size;
    // Shorthand: constant-filled trainable tensor of n elements.
    let t = |n: usize, val: f32| Tensor::from_vec(vec![val; n], true);
    let mut params = HashMap::new();
    params.insert("model.embed_tokens.weight".to_string(), t(v * h, 0.1));
    for i in 0..cfg.num_hidden_layers {
        let p = format!("model.layers.{i}");
        params.insert(format!("{p}.input_layernorm.weight"), t(h, 1.0));
        params.insert(format!("{p}.self_attn.q_proj.weight"), t(h * h, 0.1));
        params.insert(format!("{p}.self_attn.k_proj.weight"), t(h * kv, 0.1));
        params.insert(format!("{p}.self_attn.v_proj.weight"), t(h * kv, 0.1));
        params.insert(format!("{p}.self_attn.o_proj.weight"), t(h * h, 0.1));
        params.insert(format!("{p}.post_attention_layernorm.weight"), t(h, 1.0));
        params.insert(format!("{p}.mlp.gate_proj.weight"), t(h * inter, 0.1));
        params.insert(format!("{p}.mlp.up_proj.weight"), t(h * inter, 0.1));
        params.insert(format!("{p}.mlp.down_proj.weight"), t(inter * h, 0.1));
    }
    params.insert("model.norm.weight".to_string(), t(h, 1.0));
    // Deliberately mis-shaped output head: 50 elements instead of h * v.
    params.insert("lm_head.weight".to_string(), t(50, 0.1));
    let model = Transformer::from_params(&cfg, &params);
    assert!(
        model.is_none(),
        "FALSIFY-L1e: PMAT-329 fix — from_params MUST reject wrong-shape lm_head"
    );
}
// Tied-embedding path: correct logit shape, no NaN/Inf anywhere.
#[test]
fn falsify_l2e_tied_embeddings_produce_correct_logit_dims() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    assert!(model.lm_head.is_none(), "Default should use tied embeddings");
    let logits = model.forward(&[1, 2, 3]);
    assert_eq!(
        logits.len(),
        3 * cfg.vocab_size,
        "FALSIFY-L2e: Tied embedding logits must be seq_len * vocab_size"
    );
    let data = logits.data();
    let nans = data.iter().filter(|x| x.is_nan()).count();
    let infs = data.iter().filter(|x| x.is_infinite()).count();
    assert_eq!(nans, 0, "FALSIFY-L2e: Tied logits must not contain NaN");
    assert_eq!(infs, 0, "FALSIFY-L2e: Tied logits must not contain Inf");
}
// Separate-lm_head path: correct logit shape and all-finite output.
#[test]
fn falsify_l3e_separate_lm_head_produces_correct_logit_dims() {
    let cfg = TransformerConfig::tiny();
    let mut model = Transformer::new(&cfg);
    let w = vec![0.1; cfg.hidden_size * cfg.vocab_size];
    model.lm_head = Some(Tensor::from_vec(w, true));
    let logits = model.forward(&[1, 2, 3]);
    assert_eq!(
        logits.len(),
        3 * cfg.vocab_size,
        "FALSIFY-L3e: Separate lm_head logits must be seq_len * vocab_size"
    );
    assert!(
        logits.data().iter().all(|x| x.is_finite()),
        "FALSIFY-L3e: Separate lm_head logits must all be finite"
    );
}
// Installing an lm_head must grow both parameters() and parameters_mut() by
// exactly one entry, or the optimizer would silently skip it.
#[test]
fn falsify_l4e_lm_head_in_parameter_list() {
    let cfg = TransformerConfig::tiny();
    let mut model = Transformer::new(&cfg);
    let before = model.parameters().len();
    let w = vec![0.1; cfg.hidden_size * cfg.vocab_size];
    model.lm_head = Some(Tensor::from_vec(w, true));
    let after = model.parameters().len();
    assert_eq!(
        after,
        before + 1,
        "FALSIFY-L4e: lm_head must be included in parameters() — optimizer needs it"
    );
    assert_eq!(
        model.parameters_mut().len(),
        after,
        "FALSIFY-L4e: parameters_mut() must include lm_head for gradient updates"
    );
}
// forward_last must return exactly one finite logit row.
#[test]
fn falsify_l5e_forward_last_correct_size() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    let last = model.forward_last(&[1, 2, 3, 4, 5]);
    assert_eq!(
        last.len(),
        cfg.vocab_size,
        "FALSIFY-L5e: forward_last must return exactly vocab_size logits"
    );
    assert!(
        last.data().iter().all(|x| x.is_finite()),
        "FALSIFY-L5e: forward_last logits must all be finite"
    );
}
/// Sanity check that CausalLMLoss produces a positive, finite loss and
/// propagates finite gradients back to the logits.
#[test]
fn test_causal_lm_loss_backward() {
    use crate::train::CausalLMLoss;
    use crate::train::LossFn;
    let vocab_size = 100;
    let seq_len = 3;
    let loss_fn = CausalLMLoss::new(vocab_size);
    // Deterministic, non-constant logits in [-1, 1].
    let logits = Tensor::from_vec(
        (0..seq_len * vocab_size).map(|i| (i as f32 * 0.01).sin()).collect(),
        true,
    );
    // Target token ids, float-encoded — presumably read as class indices by
    // CausalLMLoss; confirm against LossFn::forward.
    let targets = Tensor::from_vec(vec![5.0, 10.0, 15.0], false);
    let mut loss = loss_fn.forward(&logits, &targets);
    crate::autograd::backward(&mut loss, None);
    assert!(loss.data()[0] > 0.0);
    assert!(loss.data()[0].is_finite());
    assert!(logits.grad().is_some());
    let grad = logits.grad().expect("gradient should be available");
    assert!(grad.iter().all(|&v| v.is_finite()));
}
/// EMB-003: with tied weights, the projection tensor resolved by the
/// fallback must be the very same object as the embedding matrix
/// (pointer equality, not just equal values).
#[test]
fn falsify_emb_003_tied_weight_sharing() {
    let config = TransformerConfig::tiny();
    let transformer = Transformer::new(&config);
    assert!(transformer.lm_head.is_none());
    // Mirrors the fallback expression used inside forward().
    let lm_weight = transformer.lm_head.as_ref().unwrap_or(&transformer.embed_tokens.weight);
    let embed_weight = &transformer.embed_tokens.weight;
    assert!(
        std::ptr::eq(lm_weight, embed_weight),
        "FALSIFIED EMB-003: tied lm_head must be same object as embed_tokens.weight"
    );
}
// TE-001: output length must scale linearly with sequence length.
#[test]
fn falsify_te_001_output_shape() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    for seq_len in [1u32, 3, 10] {
        let tokens: Vec<u32> = (0..seq_len).collect();
        assert_eq!(
            model.forward(&tokens).len(),
            seq_len as usize * cfg.vocab_size,
            "FALSIFIED TE-001: output shape for seq_len={seq_len}"
        );
    }
}
/// TE-002: with tied weights, `forward` must equal `forward_hidden` followed
/// by an explicit matmul against (a clone of) the embedding matrix.
#[test]
fn falsify_te_002_tied_equivalence() {
    let config = TransformerConfig::tiny();
    let transformer = Transformer::new(&config);
    let tokens = vec![0u32, 3, 7, 15, 42];
    let tied_logits = transformer.forward(&tokens);
    let hidden = transformer.forward_hidden(&tokens);
    // Clone the embedding weights and project manually.
    let w_clone = transformer.embed_tokens.weight.clone();
    let explicit_logits =
        matmul_nt(&hidden, &w_clone, tokens.len(), config.hidden_size, config.vocab_size);
    let tied_data = tied_logits.data();
    let explicit_data = explicit_logits.data();
    assert_eq!(
        tied_data.len(),
        explicit_data.len(),
        "FALSIFIED TE-002: output lengths differ: {} vs {}",
        tied_data.len(),
        explicit_data.len()
    );
    // Element-wise agreement within float tolerance.
    for (i, (&t, &e)) in tied_data.iter().zip(explicit_data.iter()).enumerate() {
        assert!(
            (t - e).abs() < 1e-6,
            "FALSIFIED TE-002: tied[{i}] = {t} != explicit[{i}] = {e}"
        );
    }
}
// TE-003: a tied model carries exactly one fewer parameter tensor than an
// otherwise-identical model with an explicit lm_head.
#[test]
fn falsify_te_003_no_extra_params() {
    let cfg = TransformerConfig::tiny();
    let tied = Transformer::new(&cfg);
    let tied_count = tied.parameters().len();
    let mut untied = Transformer::new(&cfg);
    let w = vec![0.1; cfg.hidden_size * cfg.vocab_size];
    untied.lm_head = Some(Tensor::from_vec(w, true));
    assert_eq!(
        untied.parameters().len(),
        tied_count + 1,
        "FALSIFIED TE-003: tied model must have exactly 1 fewer param than untied"
    );
}
// TE-004: tied-embedding forward must produce no NaN or Inf logits.
#[test]
fn falsify_te_004_finite_output() {
    let cfg = TransformerConfig::tiny();
    let model = Transformer::new(&cfg);
    let logits = model.forward(&[0u32, 5, 10, 50, 99]);
    let data = logits.data();
    let nan_count = data.iter().filter(|x| x.is_nan()).count();
    let inf_count = data.iter().filter(|x| x.is_infinite()).count();
    assert_eq!(
        nan_count, 0,
        "FALSIFIED TE-004: tied embedding output contains {nan_count} NaN values"
    );
    assert_eq!(
        inf_count, 0,
        "FALSIFIED TE-004: tied embedding output contains {inf_count} Inf values"
    );
}
// Property-based falsification tests for the tied-embedding (TE) claims.
// Plain `//` comments are used inside `proptest!` bodies because the macro
// re-parses its contents.
mod te_proptest_falsify {
    use super::*;
    use proptest::prelude::*;
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]
        #[test]
        fn falsify_te_001_prop_output_shape(
            seq_len in 1_usize..32,
        ) {
            // TE-001 over arbitrary lengths: logits len == seq_len * vocab.
            let config = TransformerConfig::tiny();
            let transformer = Transformer::new(&config);
            let tokens: Vec<u32> = (0..seq_len).map(|i| (i % config.vocab_size) as u32).collect();
            let logits = transformer.forward(&tokens);
            prop_assert_eq!(
                logits.len(),
                seq_len * config.vocab_size,
                "FALSIFIED TE-001-prop: seq_len={}, got len={}", seq_len, logits.len()
            );
        }
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn falsify_te_002_prop_tied_equivalence(
            token_ids in proptest::collection::vec(0_u32..999, 1..8),
        ) {
            // TE-002 over random token sequences: tied forward must match an
            // explicit hidden-states x embedding-matrix projection.
            let config = TransformerConfig::tiny();
            let transformer = Transformer::new(&config);
            let tied_logits = transformer.forward(&token_ids);
            let hidden = transformer.forward_hidden(&token_ids);
            let w_clone = transformer.embed_tokens.weight.clone();
            let explicit_logits = matmul_nt(
                &hidden, &w_clone,
                token_ids.len(), config.hidden_size, config.vocab_size,
            );
            let tied_data = tied_logits.data();
            let explicit_data = explicit_logits.data();
            prop_assert_eq!(tied_data.len(), explicit_data.len());
            // Looser tolerance (1e-5) than the unit test: random inputs.
            for (i, (&t, &e)) in tied_data.iter().zip(explicit_data.iter()).enumerate() {
                prop_assert!(
                    (t - e).abs() < 1e-5,
                    "FALSIFIED TE-002-prop: tied[{}]={} != explicit[{}]={}",
                    i, t, i, e
                );
            }
        }
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(30))]
        #[test]
        fn falsify_te_004_prop_finite(
            token_ids in proptest::collection::vec(0_u32..999, 1..16),
        ) {
            // TE-004 over random token sequences: every logit stays finite.
            let config = TransformerConfig::tiny();
            let transformer = Transformer::new(&config);
            let logits = transformer.forward(&token_ids);
            let data = logits.data();
            for (i, &v) in data.iter().enumerate() {
                prop_assert!(
                    v.is_finite(),
                    "FALSIFIED TE-004-prop: logits[{}]={} non-finite (n_tokens={})",
                    i, v, token_ids.len()
                );
            }
        }
    }
}
/// PIPE-001: end-to-end embed -> tied projection -> softmax pipeline.
/// Checks logit shape/finiteness, per-row softmax normalization and
/// non-negativity, and that softmax preserves the argmax.
#[test]
fn falsify_pipe_001_embed_tied_softmax_pipeline() {
    let config = TransformerConfig::tiny();
    let transformer = Transformer::new(&config);
    let tokens = vec![0u32, 3, 7, 15, 42];
    let seq_len = tokens.len();
    let vocab_size = config.vocab_size;
    let logits = transformer.forward(&tokens);
    let logits_data = logits.data();
    assert_eq!(
        logits_data.len(),
        seq_len * vocab_size,
        "FALSIFIED PIPE-001/TE-001: logits len={} != seq_len({seq_len}) * vocab({vocab_size})",
        logits_data.len()
    );
    for (i, &l) in logits_data.iter().enumerate() {
        assert!(l.is_finite(), "FALSIFIED PIPE-001/TE-004: logits[{i}] = {l} not finite");
    }
    let logits_slice = logits_data.as_slice().expect("operation should succeed");
    for row in 0..seq_len {
        let start = row * vocab_size;
        let end = start + vocab_size;
        let row_logits = &logits_slice[start..end];
        // Numerically stable softmax: subtract the row max before exp.
        let max_val = row_logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row_logits.iter().map(|&x| (x - max_val).exp()).collect();
        let sum: f32 = exps.iter().sum();
        let probs: Vec<f32> = exps.iter().map(|&e| e / sum).collect();
        let prob_sum: f32 = probs.iter().sum();
        // SM-001: probabilities must sum to 1 (within float tolerance).
        assert!(
            (prob_sum - 1.0).abs() < 1e-4,
            "FALSIFIED PIPE-001/SM-001: row {row} prob sum={prob_sum}"
        );
        // SM-002: probabilities must be non-negative.
        for (i, &p) in probs.iter().enumerate() {
            assert!(p >= 0.0, "FALSIFIED PIPE-001/SM-002: row {row} prob[{i}]={p} negative");
        }
        // SM-003: softmax is monotone, so the argmax must be unchanged.
        let logit_argmax = row_logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
            .expect("operation should succeed")
            .0;
        let prob_argmax = probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
            .expect("operation should succeed")
            .0;
        assert_eq!(
            logit_argmax, prob_argmax,
            "FALSIFIED PIPE-001/SM-003: row {row} argmax changed {logit_argmax} → {prob_argmax}"
        );
    }
}
mod safetensors_tests {
use super::*;
use safetensors::serialize;
use safetensors::tensor::{Dtype, TensorView};
use tempfile::TempDir;
/// Writes a minimal f32 safetensors file matching `TransformerConfig::tiny()`
/// into `dir` and returns its path. Norm weights are filled with 1.0, all
/// projection/embedding weights with 0.01.
fn create_tiny_safetensors(dir: &std::path::Path) -> std::path::PathBuf {
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    // Constant-filled little-endian f32 payload of n elements.
    let make_f32 = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
    };
    let mut tensors_data: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        (
            "model.embed_tokens.weight".to_string(),
            make_f32(vocab * hidden, 0.01),
            vec![vocab, hidden],
        ),
        ("model.norm.weight".to_string(), make_f32(hidden, 1.0), vec![hidden]),
    ];
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        // Per-layer tensors: (name, shape, fill value); element counts are
        // derived from the shape product.
        let specs: Vec<(String, Vec<usize>, f32)> = vec![
            (format!("{p}.input_layernorm.weight"), vec![hidden], 1.0),
            (format!("{p}.post_attention_layernorm.weight"), vec![hidden], 1.0),
            (format!("{p}.self_attn.q_proj.weight"), vec![hidden, hidden], 0.01),
            (format!("{p}.self_attn.k_proj.weight"), vec![kv_hidden, hidden], 0.01),
            (format!("{p}.self_attn.v_proj.weight"), vec![kv_hidden, hidden], 0.01),
            (format!("{p}.self_attn.o_proj.weight"), vec![hidden, hidden], 0.01),
            (format!("{p}.mlp.gate_proj.weight"), vec![intermediate, hidden], 0.01),
            (format!("{p}.mlp.up_proj.weight"), vec![intermediate, hidden], 0.01),
            (format!("{p}.mlp.down_proj.weight"), vec![hidden, intermediate], 0.01),
        ];
        for (name, shape, fill) in specs {
            let n: usize = shape.iter().product();
            tensors_data.push((name, make_f32(n, fill), shape));
        }
    }
    // Views borrow the byte buffers; names are zipped back for serialization.
    let views: Vec<TensorView<'_>> = tensors_data
        .iter()
        .map(|(_, bytes, shape)| {
            TensorView::new(Dtype::F32, shape.clone(), bytes).expect("valid tensor view")
        })
        .collect();
    let named_views: Vec<(&str, &TensorView<'_>)> = tensors_data
        .iter()
        .zip(views.iter())
        .map(|((name, _, _), view)| (name.as_str(), view))
        .collect();
    let file_path = dir.join("model.safetensors");
    let serialized = serialize(named_views, None::<std::collections::HashMap<String, String>>)
        .expect("serialize safetensors");
    std::fs::write(&file_path, serialized).expect("write safetensors file");
    file_path
}
/// Writes a minimal BF16 checkpoint for `TransformerConfig::tiny()` into `dir`
/// and returns the path to the created `model.safetensors` file.
/// Norm weights are filled with 1.0, projection/embedding weights with 0.01.
fn create_tiny_bf16_safetensors(dir: &std::path::Path) -> std::path::PathBuf {
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    // Encode `n` copies of `val` as little-endian BF16 bytes.
    let make_bf16 = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(half::bf16::from_f32(val), n)
            .flat_map(half::bf16::to_le_bytes)
            .collect()
    };
    let mut tensors_data: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        (
            "model.embed_tokens.weight".to_string(),
            make_bf16(vocab * hidden, 0.01),
            vec![vocab, hidden],
        ),
        ("model.norm.weight".to_string(), make_bf16(hidden, 1.0), vec![hidden]),
    ];
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        // (suffix, shape, fill value) for every per-layer tensor, in the same
        // order the original builder emitted them.
        let layer_specs: [(&str, Vec<usize>, f32); 9] = [
            ("input_layernorm.weight", vec![hidden], 1.0),
            ("post_attention_layernorm.weight", vec![hidden], 1.0),
            ("self_attn.q_proj.weight", vec![hidden, hidden], 0.01),
            ("self_attn.k_proj.weight", vec![kv_hidden, hidden], 0.01),
            ("self_attn.v_proj.weight", vec![kv_hidden, hidden], 0.01),
            ("self_attn.o_proj.weight", vec![hidden, hidden], 0.01),
            ("mlp.gate_proj.weight", vec![intermediate, hidden], 0.01),
            ("mlp.up_proj.weight", vec![intermediate, hidden], 0.01),
            ("mlp.down_proj.weight", vec![hidden, intermediate], 0.01),
        ];
        for (suffix, shape, fill) in layer_specs {
            let numel: usize = shape.iter().product();
            tensors_data.push((format!("{p}.{suffix}"), make_bf16(numel, fill), shape));
        }
    }
    // Views borrow the byte buffers above; `tensors_data` must outlive them.
    let views: Vec<TensorView<'_>> = tensors_data
        .iter()
        .map(|(_, bytes, shape)| {
            TensorView::new(Dtype::BF16, shape.clone(), bytes).expect("valid tensor view")
        })
        .collect();
    let named_views: Vec<(&str, &TensorView<'_>)> = tensors_data
        .iter()
        .map(|(name, _, _)| name.as_str())
        .zip(views.iter())
        .collect();
    let file_path = dir.join("model.safetensors");
    let serialized = serialize(named_views, None::<std::collections::HashMap<String, String>>)
        .expect("serialize safetensors");
    std::fs::write(&file_path, serialized).expect("write safetensors file");
    file_path
}
#[test]
fn test_ssc024_from_safetensors_f32_success() {
    // A well-formed F32 checkpoint must load, build every layer, and leave
    // lm_head unset (the fixture file contains no lm_head.weight tensor).
    let dir = TempDir::new().expect("create temp dir");
    create_tiny_safetensors(dir.path());
    let config = TransformerConfig::tiny();
    let result = Transformer::from_safetensors(dir.path(), &config);
    let err_text =
        result.as_ref().err().map_or(String::new(), std::string::ToString::to_string);
    assert!(result.is_ok(), "from_safetensors should succeed: {}", err_text);
    let transformer = result.expect("validated above");
    assert_eq!(transformer.layers.len(), config.num_hidden_layers);
    assert!(transformer.lm_head.is_none());
}
#[test]
fn test_ssc024_from_safetensors_bf16_conversion() {
    // BF16 weights must be converted on load and yield a usable model:
    // a forward pass over a few tokens produces finite logits.
    let dir = TempDir::new().expect("create temp dir");
    create_tiny_bf16_safetensors(dir.path());
    let config = TransformerConfig::tiny();
    let result = Transformer::from_safetensors(dir.path(), &config);
    let err_text =
        result.as_ref().err().map_or(String::new(), std::string::ToString::to_string);
    assert!(result.is_ok(), "BF16 loading should succeed: {}", err_text);
    let transformer = result.expect("validated above");
    assert_eq!(transformer.layers.len(), config.num_hidden_layers);
    let tokens = vec![1u32, 2, 3];
    let logits = transformer.forward(&tokens);
    // One logit row of vocab_size per input token.
    assert_eq!(logits.len(), 3 * config.vocab_size);
    assert!(
        logits.data().iter().all(|v| v.is_finite()),
        "BF16-loaded model should produce finite outputs"
    );
}
#[test]
fn test_ssc024_from_safetensors_single_file_path() {
    // from_safetensors accepts a direct path to a .safetensors file,
    // not only a directory containing one.
    let dir = TempDir::new().expect("create temp dir");
    let file_path = create_tiny_safetensors(dir.path());
    let config = TransformerConfig::tiny();
    let result = Transformer::from_safetensors(&file_path, &config);
    let err_text =
        result.as_ref().err().map_or(String::new(), std::string::ToString::to_string);
    assert!(result.is_ok(), "Direct file path should work: {}", err_text);
}
#[test]
fn test_ssc024_loaded_model_forward_produces_finite() {
    // A model loaded from disk must produce logits free of NaN and Inf.
    let dir = TempDir::new().expect("create temp dir");
    create_tiny_safetensors(dir.path());
    let config = TransformerConfig::tiny();
    let transformer =
        Transformer::from_safetensors(dir.path(), &config).expect("loading should succeed");
    let tokens = vec![0u32, 5, 42, 99];
    let logits = transformer.forward(&tokens);
    assert_eq!(logits.len(), tokens.len() * config.vocab_size);
    let data = logits.data();
    // Count offenders separately so a failure names the exact problem.
    let nan_count = data.iter().filter(|v| v.is_nan()).count();
    let inf_count = data.iter().filter(|v| v.is_infinite()).count();
    assert_eq!(nan_count, 0, "Loaded model output must not contain NaN");
    assert_eq!(inf_count, 0, "Loaded model output must not contain Inf");
}
#[test]
fn test_ssc024_from_safetensors_no_files() {
    // An empty directory must fail with an error naming the missing files.
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();
    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err());
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(
        err_msg.contains("No SafeTensors files"),
        "Error should mention missing files: {err_msg}"
    );
}
#[test]
fn test_ssc024_from_safetensors_wrong_embedding_shape() {
    // Checkpoint with a deliberately mis-shaped embedding (42 elements instead
    // of vocab*hidden); every other tensor is valid, so the failure must be
    // attributed to the embedding shape.
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let make_f32 = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
    };
    let wrong_embed_bytes: Vec<u8> =
        std::iter::repeat_n(0.01_f32, 42).flat_map(f32::to_le_bytes).collect();
    let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        ("model.embed_tokens.weight".to_string(), wrong_embed_bytes, vec![42]),
        ("model.norm.weight".to_string(), make_f32(hidden, 1.0), vec![hidden]),
    ];
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        let layer_specs: [(&str, Vec<usize>, f32); 9] = [
            ("input_layernorm.weight", vec![hidden], 1.0),
            ("post_attention_layernorm.weight", vec![hidden], 1.0),
            ("self_attn.q_proj.weight", vec![hidden, hidden], 0.01),
            ("self_attn.k_proj.weight", vec![kv_hidden, hidden], 0.01),
            ("self_attn.v_proj.weight", vec![kv_hidden, hidden], 0.01),
            ("self_attn.o_proj.weight", vec![hidden, hidden], 0.01),
            ("mlp.gate_proj.weight", vec![intermediate, hidden], 0.01),
            ("mlp.up_proj.weight", vec![intermediate, hidden], 0.01),
            ("mlp.down_proj.weight", vec![hidden, intermediate], 0.01),
        ];
        for (suffix, shape, fill) in layer_specs {
            let numel: usize = shape.iter().product();
            td.push((format!("{p}.{suffix}"), make_f32(numel, fill), shape));
        }
    }
    let views: Vec<TensorView<'_>> = td
        .iter()
        .map(|(_, bytes, shape)| {
            TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view")
        })
        .collect();
    let named: Vec<(&str, &TensorView<'_>)> =
        td.iter().map(|(n, _, _)| n.as_str()).zip(views.iter()).collect();
    let file_path = dir.path().join("model.safetensors");
    let serialized =
        serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
    std::fs::write(&file_path, serialized).expect("write");
    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err(), "Wrong embedding shape should fail");
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(
        err_msg.contains("Shape mismatch") || err_msg.contains("embed_tokens"),
        "Error should indicate shape issue: {err_msg}"
    );
}
#[test]
fn test_ssc024_from_safetensors_nan_detection() {
    // Poison a single embedding element with NaN; loading must refuse the
    // checkpoint and name NaN in the error.
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    let make_f32 = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
    };
    let mut embed_vals: Vec<f32> = vec![0.01; vocab * hidden];
    embed_vals[42] = f32::NAN;
    let embed_bytes: Vec<u8> = embed_vals.iter().flat_map(|v| v.to_le_bytes()).collect();
    let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        ("model.embed_tokens.weight".to_string(), embed_bytes, vec![vocab, hidden]),
        ("model.norm.weight".to_string(), make_f32(hidden, 1.0), vec![hidden]),
    ];
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        let layer_specs: [(&str, Vec<usize>, f32); 9] = [
            ("input_layernorm.weight", vec![hidden], 1.0),
            ("post_attention_layernorm.weight", vec![hidden], 1.0),
            ("self_attn.q_proj.weight", vec![hidden, hidden], 0.01),
            ("self_attn.k_proj.weight", vec![kv_hidden, hidden], 0.01),
            ("self_attn.v_proj.weight", vec![kv_hidden, hidden], 0.01),
            ("self_attn.o_proj.weight", vec![hidden, hidden], 0.01),
            ("mlp.gate_proj.weight", vec![intermediate, hidden], 0.01),
            ("mlp.up_proj.weight", vec![intermediate, hidden], 0.01),
            ("mlp.down_proj.weight", vec![hidden, intermediate], 0.01),
        ];
        for (suffix, shape, fill) in layer_specs {
            let numel: usize = shape.iter().product();
            td.push((format!("{p}.{suffix}"), make_f32(numel, fill), shape));
        }
    }
    let views: Vec<TensorView<'_>> = td
        .iter()
        .map(|(_, bytes, shape)| {
            TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view")
        })
        .collect();
    let named: Vec<(&str, &TensorView<'_>)> =
        td.iter().map(|(n, _, _)| n.as_str()).zip(views.iter()).collect();
    let file_path = dir.path().join("model.safetensors");
    let serialized =
        serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
    std::fs::write(&file_path, serialized).expect("write");
    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err(), "NaN in weights should fail");
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(err_msg.contains("NaN"), "Error should mention NaN: {err_msg}");
}
#[test]
fn test_ssc024_from_safetensors_inf_detection() {
    // Poison the final norm weight with +Inf; loading must refuse the
    // checkpoint and name Inf in the error.
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    let make_f32 = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
    };
    let mut norm_vals: Vec<f32> = vec![1.0; hidden];
    norm_vals[0] = f32::INFINITY;
    let norm_bytes: Vec<u8> = norm_vals.iter().flat_map(|v| v.to_le_bytes()).collect();
    let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        (
            "model.embed_tokens.weight".to_string(),
            make_f32(vocab * hidden, 0.01),
            vec![vocab, hidden],
        ),
        ("model.norm.weight".to_string(), norm_bytes, vec![hidden]),
    ];
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        let layer_specs: [(&str, Vec<usize>, f32); 9] = [
            ("input_layernorm.weight", vec![hidden], 1.0),
            ("post_attention_layernorm.weight", vec![hidden], 1.0),
            ("self_attn.q_proj.weight", vec![hidden, hidden], 0.01),
            ("self_attn.k_proj.weight", vec![kv_hidden, hidden], 0.01),
            ("self_attn.v_proj.weight", vec![kv_hidden, hidden], 0.01),
            ("self_attn.o_proj.weight", vec![hidden, hidden], 0.01),
            ("mlp.gate_proj.weight", vec![intermediate, hidden], 0.01),
            ("mlp.up_proj.weight", vec![intermediate, hidden], 0.01),
            ("mlp.down_proj.weight", vec![hidden, intermediate], 0.01),
        ];
        for (suffix, shape, fill) in layer_specs {
            let numel: usize = shape.iter().product();
            td.push((format!("{p}.{suffix}"), make_f32(numel, fill), shape));
        }
    }
    let views: Vec<TensorView<'_>> = td
        .iter()
        .map(|(_, bytes, shape)| {
            TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view")
        })
        .collect();
    let named: Vec<(&str, &TensorView<'_>)> =
        td.iter().map(|(n, _, _)| n.as_str()).zip(views.iter()).collect();
    let file_path = dir.path().join("model.safetensors");
    let serialized =
        serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
    std::fs::write(&file_path, serialized).expect("write");
    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err(), "Inf in weights should fail");
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(err_msg.contains("Inf"), "Error should mention Inf: {err_msg}");
}
#[test]
fn test_ssc024_from_safetensors_missing_layer() {
    // The fixture holds fewer layers than the config demands (3), so loading
    // must fail and point at the missing layer.
    let dir = TempDir::new().expect("create temp dir");
    create_tiny_safetensors(dir.path());
    let mut config = TransformerConfig::tiny();
    config.num_hidden_layers = 3;
    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err(), "Missing layer 2 should fail");
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(
        err_msg.contains("Missing") || err_msg.contains("layers.2"),
        "Error should mention missing layer: {err_msg}"
    );
}
#[test]
fn test_ssc024_from_safetensors_wrong_q_proj_shape() {
    // Only layer 0's q_proj is corrupted (7 elements, rank 1); every other
    // tensor is correct, so the failure must be attributed to q_proj.
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let q_dim = config.q_dim();
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    let make_f32 = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
    };
    let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        (
            "model.embed_tokens.weight".to_string(),
            make_f32(vocab * hidden, 0.01),
            vec![vocab, hidden],
        ),
        ("model.norm.weight".to_string(), make_f32(hidden, 1.0), vec![hidden]),
    ];
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        td.push((format!("{p}.input_layernorm.weight"), make_f32(hidden, 1.0), vec![hidden]));
        td.push((
            format!("{p}.post_attention_layernorm.weight"),
            make_f32(hidden, 1.0),
            vec![hidden],
        ));
        if i == 0 {
            // Deliberately wrong: rank 1, element count 7.
            td.push((format!("{p}.self_attn.q_proj.weight"), make_f32(7, 0.01), vec![7]));
        } else {
            td.push((
                format!("{p}.self_attn.q_proj.weight"),
                make_f32(q_dim * hidden, 0.01),
                vec![q_dim, hidden],
            ));
        }
        let rest: [(&str, Vec<usize>); 6] = [
            ("self_attn.k_proj.weight", vec![kv_hidden, hidden]),
            ("self_attn.v_proj.weight", vec![kv_hidden, hidden]),
            ("self_attn.o_proj.weight", vec![hidden, q_dim]),
            ("mlp.gate_proj.weight", vec![intermediate, hidden]),
            ("mlp.up_proj.weight", vec![intermediate, hidden]),
            ("mlp.down_proj.weight", vec![hidden, intermediate]),
        ];
        for (suffix, shape) in rest {
            let numel: usize = shape.iter().product();
            td.push((format!("{p}.{suffix}"), make_f32(numel, 0.01), shape));
        }
    }
    let views: Vec<TensorView<'_>> = td
        .iter()
        .map(|(_, bytes, shape)| {
            TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view")
        })
        .collect();
    let named: Vec<(&str, &TensorView<'_>)> =
        td.iter().map(|(n, _, _)| n.as_str()).zip(views.iter()).collect();
    let file_path = dir.path().join("model.safetensors");
    let serialized =
        serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
    std::fs::write(&file_path, serialized).expect("write");
    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(result.is_err(), "Wrong q_proj shape should fail");
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(
        err_msg.contains("Shape mismatch") && err_msg.contains("q_proj"),
        "Error should mention q_proj shape mismatch: {err_msg}"
    );
}
#[test]
fn test_ssc024_validate_weight_shapes_success() {
    // A complete, correctly sized weight map for tiny() must validate.
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    let mut weights = HashMap::new();
    weights.insert(
        "model.embed_tokens.weight".to_string(),
        Tensor::from_vec(vec![0.1; vocab * hidden], true),
    );
    weights.insert("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; hidden], true));
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        // (suffix, element count, fill value) for each per-layer weight.
        let layer_specs: [(&str, usize, f32); 9] = [
            ("input_layernorm.weight", hidden, 1.0),
            ("post_attention_layernorm.weight", hidden, 1.0),
            ("self_attn.q_proj.weight", hidden * hidden, 0.1),
            ("self_attn.k_proj.weight", hidden * kv_hidden, 0.1),
            ("self_attn.v_proj.weight", hidden * kv_hidden, 0.1),
            ("self_attn.o_proj.weight", hidden * hidden, 0.1),
            ("mlp.gate_proj.weight", hidden * intermediate, 0.1),
            ("mlp.up_proj.weight", hidden * intermediate, 0.1),
            ("mlp.down_proj.weight", intermediate * hidden, 0.1),
        ];
        for (suffix, numel, fill) in layer_specs {
            weights.insert(format!("{p}.{suffix}"), Tensor::from_vec(vec![fill; numel], true));
        }
    }
    let result = Transformer::validate_weight_shapes(&weights, &config);
    let err_text =
        result.as_ref().err().map_or(String::new(), std::string::ToString::to_string);
    assert!(result.is_ok(), "Valid shapes should pass: {}", err_text);
}
#[test]
fn test_ssc024_validate_weight_shapes_wrong_norm() {
    // A norm weight of the wrong length (3 instead of hidden_size) must be
    // rejected with an error naming the offending tensor.
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let vocab = config.vocab_size;
    let mut weights = HashMap::new();
    weights.insert(
        "model.embed_tokens.weight".to_string(),
        Tensor::from_vec(vec![0.1; vocab * hidden], true),
    );
    weights.insert("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; 3], true));
    let result = Transformer::validate_weight_shapes(&weights, &config);
    assert!(result.is_err());
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(err_msg.contains("model.norm.weight"));
}
#[test]
fn test_ssc024_validate_weight_values_clean() {
    // All-finite tensors (including negatives and zero) must pass validation.
    let mut weights = HashMap::new();
    weights.insert("a".to_string(), Tensor::from_vec(vec![0.1, 0.2, 0.3], true));
    weights.insert("b".to_string(), Tensor::from_vec(vec![1.0, -1.0, 0.0], true));
    assert!(Transformer::validate_weight_values(&weights).is_ok());
}
#[test]
fn test_ssc024_validate_weight_values_nan() {
    // One NaN in one tensor must fail, and the error must name both the
    // problem (NaN) and the offending tensor key.
    let mut weights = HashMap::new();
    weights.insert("clean".to_string(), Tensor::from_vec(vec![0.1, 0.2], true));
    weights.insert("poisoned".to_string(), Tensor::from_vec(vec![0.1, f32::NAN, 0.3], true));
    let result = Transformer::validate_weight_values(&weights);
    assert!(result.is_err());
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(err_msg.contains("NaN"));
    assert!(err_msg.contains("poisoned"));
}
#[test]
fn test_ssc024_validate_weight_values_inf() {
    // Negative infinity is rejected just like positive infinity.
    let mut weights = HashMap::new();
    weights.insert("w".to_string(), Tensor::from_vec(vec![f32::NEG_INFINITY, 0.2], true));
    let result = Transformer::validate_weight_values(&weights);
    assert!(result.is_err());
    let Err(e) = result else { panic!("expected error") };
    let err_msg = e.to_string();
    assert!(err_msg.contains("Inf"));
}
#[test]
fn test_gh262_qwen3_4b_weight_shapes_q_dim_ne_hidden() {
    // Qwen3-like geometry: the head_dim override makes q_dim (4 heads * 32)
    // differ from hidden_size (80). Shape validation and from_params must
    // both accept q_proj of q_dim*hidden and o_proj of hidden*q_dim.
    let config = TransformerConfig {
        hidden_size: 80,
        num_attention_heads: 4,
        num_kv_heads: 2,
        intermediate_size: 128,
        num_hidden_layers: 1,
        vocab_size: 256,
        max_position_embeddings: 512,
        rms_norm_eps: 1e-6,
        rope_theta: 10000.0,
        use_bias: false,
        head_dim_override: Some(32),
        architecture: crate::transformer::config::ModelArchitecture::Decoder,
        hf_architecture: None,
        hf_model_type: None,
        tie_word_embeddings: false,
    };
    let hidden = config.hidden_size;
    let q_dim = config.q_dim();
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    assert_ne!(q_dim, hidden, "test requires q_dim != hidden_size");
    let mut weights = HashMap::new();
    weights.insert(
        "model.embed_tokens.weight".to_string(),
        Tensor::from_vec(vec![0.1; vocab * hidden], true),
    );
    weights.insert("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; hidden], true));
    let p = "model.layers.0";
    let layer_specs: [(&str, usize, f32); 9] = [
        ("input_layernorm.weight", hidden, 1.0),
        ("post_attention_layernorm.weight", hidden, 1.0),
        ("self_attn.q_proj.weight", q_dim * hidden, 0.1),
        ("self_attn.k_proj.weight", kv_hidden * hidden, 0.1),
        ("self_attn.v_proj.weight", kv_hidden * hidden, 0.1),
        ("self_attn.o_proj.weight", hidden * q_dim, 0.1),
        ("mlp.gate_proj.weight", hidden * intermediate, 0.1),
        ("mlp.up_proj.weight", hidden * intermediate, 0.1),
        ("mlp.down_proj.weight", intermediate * hidden, 0.1),
    ];
    for (suffix, numel, fill) in layer_specs {
        weights.insert(format!("{p}.{suffix}"), Tensor::from_vec(vec![fill; numel], true));
    }
    let result = Transformer::validate_weight_shapes(&weights, &config);
    assert!(
        result.is_ok(),
        "Qwen3-like shapes (q_dim={q_dim} != hidden={hidden}) should validate: {:?}",
        result.err()
    );
    let model = Transformer::from_params(&config, &weights);
    assert!(model.is_some(), "Qwen3-like model with q_dim != hidden should construct");
}
#[test]
fn test_gh262_wrong_q_proj_size_hidden_instead_of_q_dim() {
    // Regression for GH-262: with q_dim != hidden, a q_proj sized
    // hidden*hidden (the pre-fix assumption) must now be rejected.
    let config = TransformerConfig {
        hidden_size: 80,
        num_attention_heads: 4,
        num_kv_heads: 2,
        intermediate_size: 128,
        num_hidden_layers: 1,
        vocab_size: 256,
        max_position_embeddings: 512,
        rms_norm_eps: 1e-6,
        rope_theta: 10000.0,
        use_bias: false,
        head_dim_override: Some(32),
        architecture: crate::transformer::config::ModelArchitecture::Decoder,
        hf_architecture: None,
        hf_model_type: None,
        tie_word_embeddings: false,
    };
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    let mut weights = HashMap::new();
    weights.insert(
        "model.embed_tokens.weight".to_string(),
        Tensor::from_vec(vec![0.1; vocab * hidden], true),
    );
    weights.insert("model.norm.weight".to_string(), Tensor::from_vec(vec![1.0; hidden], true));
    let p = "model.layers.0";
    // q_proj and o_proj are sized hidden*hidden here — wrong for this config.
    let layer_specs: [(&str, usize, f32); 9] = [
        ("input_layernorm.weight", hidden, 1.0),
        ("post_attention_layernorm.weight", hidden, 1.0),
        ("self_attn.q_proj.weight", hidden * hidden, 0.1),
        ("self_attn.k_proj.weight", kv_hidden * hidden, 0.1),
        ("self_attn.v_proj.weight", kv_hidden * hidden, 0.1),
        ("self_attn.o_proj.weight", hidden * hidden, 0.1),
        ("mlp.gate_proj.weight", hidden * intermediate, 0.1),
        ("mlp.up_proj.weight", hidden * intermediate, 0.1),
        ("mlp.down_proj.weight", intermediate * hidden, 0.1),
    ];
    for (suffix, numel, fill) in layer_specs {
        weights.insert(format!("{p}.{suffix}"), Tensor::from_vec(vec![fill; numel], true));
    }
    let result = Transformer::validate_weight_shapes(&weights, &config);
    assert!(result.is_err(), "hidden*hidden q_proj should fail when q_dim != hidden");
    let err_msg = result.err().map(|e| e.to_string()).unwrap_or_default();
    assert!(
        err_msg.contains("q_proj") && err_msg.contains("Shape mismatch"),
        "Error should mention q_proj shape mismatch, got: {err_msg}"
    );
}
#[test]
fn test_ssc024_from_safetensors_with_extra_bias_tensors() {
    // A checkpoint containing attention bias tensors the loader does not use
    // must still load — unknown extras are tolerated, not rejected.
    let dir = TempDir::new().expect("create temp dir");
    let config = TransformerConfig::tiny();
    let hidden = config.hidden_size;
    let kv_hidden = config.num_kv_heads * config.head_dim();
    let intermediate = config.intermediate_size;
    let vocab = config.vocab_size;
    let make_f32 = |n: usize, val: f32| -> Vec<u8> {
        std::iter::repeat_n(val, n).flat_map(f32::to_le_bytes).collect()
    };
    let mut td: Vec<(String, Vec<u8>, Vec<usize>)> = vec![
        (
            "model.embed_tokens.weight".to_string(),
            make_f32(vocab * hidden, 0.01),
            vec![vocab, hidden],
        ),
        ("model.norm.weight".to_string(), make_f32(hidden, 1.0), vec![hidden]),
    ];
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{i}");
        // Standard per-layer weights, then extra bias tensors (value 0.0).
        let layer_specs: [(&str, Vec<usize>, f32); 12] = [
            ("input_layernorm.weight", vec![hidden], 1.0),
            ("post_attention_layernorm.weight", vec![hidden], 1.0),
            ("self_attn.q_proj.weight", vec![hidden, hidden], 0.01),
            ("self_attn.k_proj.weight", vec![kv_hidden, hidden], 0.01),
            ("self_attn.v_proj.weight", vec![kv_hidden, hidden], 0.01),
            ("self_attn.o_proj.weight", vec![hidden, hidden], 0.01),
            ("mlp.gate_proj.weight", vec![intermediate, hidden], 0.01),
            ("mlp.up_proj.weight", vec![intermediate, hidden], 0.01),
            ("mlp.down_proj.weight", vec![hidden, intermediate], 0.01),
            ("self_attn.q_proj.bias", vec![hidden], 0.0),
            ("self_attn.k_proj.bias", vec![kv_hidden], 0.0),
            ("self_attn.v_proj.bias", vec![kv_hidden], 0.0),
        ];
        for (suffix, shape, fill) in layer_specs {
            let numel: usize = shape.iter().product();
            td.push((format!("{p}.{suffix}"), make_f32(numel, fill), shape));
        }
    }
    let views: Vec<TensorView<'_>> = td
        .iter()
        .map(|(_, bytes, shape)| {
            TensorView::new(Dtype::F32, shape.clone(), bytes).expect("view")
        })
        .collect();
    let named: Vec<(&str, &TensorView<'_>)> =
        td.iter().map(|(n, _, _)| n.as_str()).zip(views.iter()).collect();
    let file_path = dir.path().join("model.safetensors");
    let serialized =
        serialize(named, None::<std::collections::HashMap<String, String>>).expect("ser");
    std::fs::write(&file_path, serialized).expect("write");
    let result = Transformer::from_safetensors(dir.path(), &config);
    assert!(
        result.is_ok(),
        "Extra bias tensors should not cause failure: {}",
        result.as_ref().err().map_or(String::new(), std::string::ToString::to_string)
    );
}
}
}