use std::fmt;
use std::io::Read as _;
use std::path::Path;
use serde_json::Value;
use crate::error::{MIError, Result};
/// HF `model_type` values that have a dedicated, hand-written parser in
/// [`TransformerConfig::from_hf_config`]. Anything else goes through the
/// tensor-name heuristics in `from_hf_config_auto`.
pub const SUPPORTED_MODEL_TYPES: &[&str] = &[
"gemma",
"gemma2",
"llama",
"mistral",
"phi3",
"qwen2",
"starcoder2",
];
/// Normalization flavor used by a model's layer norms.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormType {
/// Standard RMSNorm (default for llama-family models).
RmsNorm,
/// Classic LayerNorm; selected for StarCoder2 (`norm_type = "layer_norm"`)
/// or when auto-detection finds a norm bias tensor.
LayerNorm,
/// RMSNorm variant used by Gemma / Gemma-2 models.
GemmaRmsNorm,
}
impl fmt::Display for NormType {
    /// Writes the variant's canonical name (same spelling as `Debug`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::RmsNorm => "RmsNorm",
            Self::LayerNorm => "LayerNorm",
            Self::GemmaRmsNorm => "GemmaRmsNorm",
        };
        f.write_str(name)
    }
}
/// MLP activation function.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activation {
/// SiLU (default for llama / qwen2 / phi3 / mistral parsers).
Silu,
/// Exact GELU (`hidden_act = "gelu"`).
Gelu,
/// Tanh-approximated GELU (`hidden_act = "gelu_pytorch_tanh"`; Gemma, StarCoder2).
GeluApprox,
}
impl fmt::Display for Activation {
    /// Writes a human-readable label for the activation function.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Silu => "SiLU",
            Self::Gelu => "GELU",
            Self::GeluApprox => "GELU (tanh approx)",
        };
        f.write_str(label)
    }
}
/// How the attention Q/K/V projection weights are stored in the checkpoint.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QkvLayout {
/// Separate `q_proj` / `k_proj` / `v_proj` tensors.
Separate,
/// Single fused `qkv_proj` tensor (e.g. Phi-3).
Fused,
}
impl fmt::Display for QkvLayout {
    /// Writes the variant name (same spelling as `Debug`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Separate => "Separate",
            Self::Fused => "Fused",
        })
    }
}
/// How the feed-forward (MLP) weights are stored in the checkpoint.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlpLayout {
/// Gated MLP with separate `gate_proj` / `up_proj` / `down_proj` tensors.
GatedSeparate,
/// Gated MLP with a fused `gate_up_proj` tensor (e.g. Phi-3).
GatedFused,
/// Plain two-layer MLP (`c_fc` style, e.g. StarCoder2).
Plain,
}
impl fmt::Display for MlpLayout {
    /// Writes the variant name (same spelling as `Debug`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::GatedSeparate => "GatedSeparate",
            Self::GatedFused => "GatedFused",
            Self::Plain => "Plain",
        };
        f.write_str(name)
    }
}
/// Architecture-agnostic description of a decoder-only transformer,
/// normalized from a Hugging Face `config.json`.
///
/// Produced by [`TransformerConfig::from_hf_config`] for known `model_type`s
/// or [`TransformerConfig::from_hf_config_auto`] via heuristics.
#[derive(Debug, Clone, PartialEq)]
#[allow(clippy::struct_excessive_bools)] pub struct TransformerConfig {
/// Residual stream width (`hidden_size`).
pub hidden_size: usize,
/// Number of decoder layers (`num_hidden_layers`).
pub num_layers: usize,
pub num_attention_heads: usize,
/// KV heads; fewer than `num_attention_heads` implies grouped-query attention.
pub num_kv_heads: usize,
/// Per-head dimension: explicit `head_dim` or `hidden_size / num_attention_heads`.
pub head_dim: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub norm_type: NormType,
pub norm_eps: f64,
pub activation: Activation,
pub qkv_layout: QkvLayout,
pub mlp_layout: MlpLayout,
pub qkv_bias: bool,
pub o_proj_bias: bool,
pub mlp_bias: bool,
/// Multiplier applied to token embeddings (Gemma uses sqrt(hidden_size)).
pub embedding_scale: Option<f64>,
/// When true the LM head reuses the embedding matrix (no `lm_head.weight`).
pub tie_word_embeddings: bool,
pub rope_theta: f64,
pub max_position_embeddings: usize,
/// Optional soft-capping value for attention logits (Gemma-2).
pub attn_logit_softcapping: Option<f64>,
/// Optional soft-capping value for final logits (Gemma-2).
pub final_logit_softcapping: Option<f64>,
/// Overrides the attention-scale denominator when set (Gemma-2 default 256).
pub query_pre_attn_scalar: Option<f64>,
/// Model has pre/post feed-forward layer norms (Gemma-2 pattern).
pub use_post_norms: bool,
pub sliding_window: Option<usize>,
/// Apply `sliding_window` only on alternating layers (Gemma-2 pattern).
pub alternating_sliding_window: bool,
}
impl TransformerConfig {
    /// Parses a Hugging Face `config.json` value into a [`TransformerConfig`],
    /// dispatching on its `model_type` field.
    ///
    /// Errors when `model_type` is missing, non-string, or not one of the
    /// architectures listed in [`SUPPORTED_MODEL_TYPES`].
    pub fn from_hf_config(config: &Value) -> Result<Self> {
        let Some(model_type) = config.get("model_type").and_then(Value::as_str) else {
            return Err(MIError::Config("missing 'model_type' field".into()));
        };
        match model_type {
            "gemma" => Self::parse_gemma(config),
            "gemma2" => Self::parse_gemma2(config),
            "llama" => Self::parse_llama(config),
            "mistral" => Self::parse_mistral(config),
            "phi3" => Self::parse_phi3(config),
            "qwen2" => Self::parse_qwen2(config),
            "starcoder2" => Self::parse_starcoder2(config),
            other => Err(MIError::Config(format!(
                "unsupported model_type: '{other}'"
            ))),
        }
    }
}
// Per-architecture parsers. Each one reads the handful of keys its
// architecture actually sets and hard-codes the rest of the layout.
impl TransformerConfig {
/// Llama family: RMSNorm + SiLU + separate gated MLP, no biases.
fn parse_llama(config: &Value) -> Result<Self> {
let hidden_size = get_usize(config, "hidden_size")?;
let num_attention_heads = get_usize(config, "num_attention_heads")?;
Ok(Self {
hidden_size,
num_layers: get_usize(config, "num_hidden_layers")?,
num_attention_heads,
// Absent key means plain MHA: kv heads == attention heads.
num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
intermediate_size: get_usize(config, "intermediate_size")?,
vocab_size: get_usize(config, "vocab_size")?,
norm_type: NormType::RmsNorm,
norm_eps: get_f64_or(config, "rms_norm_eps", 1e-5),
activation: Activation::Silu,
qkv_layout: QkvLayout::Separate,
mlp_layout: MlpLayout::GatedSeparate,
qkv_bias: false,
o_proj_bias: false,
mlp_bias: false,
embedding_scale: None,
tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),
rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
max_position_embeddings: get_usize_or(config, "max_position_embeddings", 4096),
attn_logit_softcapping: None,
final_logit_softcapping: None,
query_pre_attn_scalar: None,
use_post_norms: false,
sliding_window: None,
alternating_sliding_window: false,
})
}
/// Qwen2: llama-like, but with Q/K/V projection biases enabled by default.
fn parse_qwen2(config: &Value) -> Result<Self> {
let hidden_size = get_usize(config, "hidden_size")?;
let num_attention_heads = get_usize(config, "num_attention_heads")?;
Ok(Self {
hidden_size,
num_layers: get_usize(config, "num_hidden_layers")?,
num_attention_heads,
num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
intermediate_size: get_usize(config, "intermediate_size")?,
vocab_size: get_usize(config, "vocab_size")?,
norm_type: NormType::RmsNorm,
norm_eps: get_f64_or(config, "rms_norm_eps", 1e-6),
activation: Activation::Silu,
qkv_layout: QkvLayout::Separate,
mlp_layout: MlpLayout::GatedSeparate,
// NOTE(review): defaults to true when 'attention_bias' is absent —
// Qwen2 checkpoints ship Q/K/V bias tensors; confirm against upstream.
qkv_bias: get_bool_or(config, "attention_bias", true),
o_proj_bias: false,
mlp_bias: false,
embedding_scale: None,
tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),
rope_theta: get_f64_or(config, "rope_theta", 1_000_000.0),
max_position_embeddings: get_usize_or(config, "max_position_embeddings", 32_768),
attn_logit_softcapping: None,
final_logit_softcapping: None,
query_pre_attn_scalar: None,
use_post_norms: false,
sliding_window: None,
alternating_sliding_window: false,
})
}
/// Gemma 1: Gemma-style RMSNorm, tanh-GELU, sqrt(hidden) embedding scaling,
/// tied embeddings by default.
fn parse_gemma(config: &Value) -> Result<Self> {
let hidden_size = get_usize(config, "hidden_size")?;
let num_attention_heads = get_usize(config, "num_attention_heads")?;
Ok(Self {
hidden_size,
num_layers: get_usize(config, "num_hidden_layers")?,
num_attention_heads,
num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
intermediate_size: get_usize(config, "intermediate_size")?,
vocab_size: get_usize(config, "vocab_size")?,
norm_type: NormType::GemmaRmsNorm,
norm_eps: get_f64_or(config, "rms_norm_eps", 1e-6),
activation: Activation::GeluApprox,
qkv_layout: QkvLayout::Separate,
mlp_layout: MlpLayout::GatedSeparate,
qkv_bias: false,
o_proj_bias: false,
mlp_bias: false,
// Gemma scales token embeddings by sqrt(hidden_size).
#[allow(clippy::cast_precision_loss, clippy::as_conversions)]
embedding_scale: Some((hidden_size as f64).sqrt()),
tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", true),
rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
max_position_embeddings: get_usize_or(
config,
"max_position_embeddings",
8192,
),
attn_logit_softcapping: None,
final_logit_softcapping: None,
query_pre_attn_scalar: None,
use_post_norms: false,
sliding_window: None,
alternating_sliding_window: false,
})
}
/// Gemma 2: like Gemma 1 plus logit soft-capping, extra pre/post-FFN norms,
/// and alternating sliding-window attention.
fn parse_gemma2(config: &Value) -> Result<Self> {
let hidden_size = get_usize(config, "hidden_size")?;
let num_attention_heads = get_usize(config, "num_attention_heads")?;
Ok(Self {
hidden_size,
num_layers: get_usize(config, "num_hidden_layers")?,
num_attention_heads,
num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
intermediate_size: get_usize(config, "intermediate_size")?,
vocab_size: get_usize(config, "vocab_size")?,
norm_type: NormType::GemmaRmsNorm,
norm_eps: get_f64_or(config, "rms_norm_eps", 1e-6),
activation: Activation::GeluApprox,
qkv_layout: QkvLayout::Separate,
mlp_layout: MlpLayout::GatedSeparate,
qkv_bias: false,
o_proj_bias: false,
mlp_bias: false,
#[allow(clippy::cast_precision_loss, clippy::as_conversions)]
embedding_scale: Some((hidden_size as f64).sqrt()),
tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", true),
rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
max_position_embeddings: get_usize_or(
config,
"max_position_embeddings",
8192,
),
attn_logit_softcapping: get_optional_f64(config, "attn_logit_softcapping"),
final_logit_softcapping: get_optional_f64(config, "final_logit_softcapping"),
// Gemma-2 scales queries by query_pre_attn_scalar (default 256).
query_pre_attn_scalar: get_optional_f64(config, "query_pre_attn_scalar")
.or(Some(256.0)),
use_post_norms: true,
sliding_window: get_optional_usize(config, "sliding_window"),
alternating_sliding_window: true,
})
}
/// Phi-3: fused QKV projection and fused gate+up MLP projection.
fn parse_phi3(config: &Value) -> Result<Self> {
let hidden_size = get_usize(config, "hidden_size")?;
let num_attention_heads = get_usize(config, "num_attention_heads")?;
Ok(Self {
hidden_size,
num_layers: get_usize(config, "num_hidden_layers")?,
num_attention_heads,
num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
intermediate_size: get_usize(config, "intermediate_size")?,
vocab_size: get_usize(config, "vocab_size")?,
norm_type: NormType::RmsNorm,
norm_eps: get_f64_or(config, "rms_norm_eps", 1e-5),
activation: Activation::Silu,
qkv_layout: QkvLayout::Fused,
mlp_layout: MlpLayout::GatedFused,
qkv_bias: false,
o_proj_bias: false,
mlp_bias: false,
embedding_scale: None,
tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),
rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
max_position_embeddings: get_usize_or(config, "max_position_embeddings", 4096),
attn_logit_softcapping: None,
final_logit_softcapping: None,
query_pre_attn_scalar: None,
use_post_norms: false,
sliding_window: None,
alternating_sliding_window: false,
})
}
/// StarCoder2: plain (non-gated) MLP, biases everywhere by default, and an
/// optional LayerNorm selected via the 'norm_type' config key.
fn parse_starcoder2(config: &Value) -> Result<Self> {
let hidden_size = get_usize(config, "hidden_size")?;
let num_attention_heads = get_usize(config, "num_attention_heads")?;
// One flag drives all three bias settings for this architecture.
let use_bias = get_bool_or(config, "use_bias", true);
let norm_type = match config.get("norm_type").and_then(Value::as_str) {
Some("layer_norm") => NormType::LayerNorm,
_ => NormType::RmsNorm,
};
Ok(Self {
hidden_size,
num_layers: get_usize(config, "num_hidden_layers")?,
num_attention_heads,
num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
intermediate_size: get_usize(config, "intermediate_size")?,
vocab_size: get_usize(config, "vocab_size")?,
norm_type,
// StarCoder2 names its epsilon 'norm_epsilon', not 'rms_norm_eps'.
norm_eps: get_f64_or(config, "norm_epsilon", 1e-5),
activation: Activation::GeluApprox,
qkv_layout: QkvLayout::Separate,
mlp_layout: MlpLayout::Plain,
qkv_bias: use_bias,
o_proj_bias: use_bias,
mlp_bias: use_bias,
embedding_scale: None,
tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", true),
rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
max_position_embeddings: get_usize_or(config, "max_position_embeddings", 16_384),
attn_logit_softcapping: None,
final_logit_softcapping: None,
query_pre_attn_scalar: None,
use_post_norms: false,
sliding_window: get_optional_usize(config, "sliding_window"),
alternating_sliding_window: false,
})
}
/// Mistral: llama-like with an optional (non-alternating) sliding window.
fn parse_mistral(config: &Value) -> Result<Self> {
let hidden_size = get_usize(config, "hidden_size")?;
let num_attention_heads = get_usize(config, "num_attention_heads")?;
Ok(Self {
hidden_size,
num_layers: get_usize(config, "num_hidden_layers")?,
num_attention_heads,
num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
intermediate_size: get_usize(config, "intermediate_size")?,
vocab_size: get_usize(config, "vocab_size")?,
norm_type: NormType::RmsNorm,
norm_eps: get_f64_or(config, "rms_norm_eps", 1e-5),
activation: Activation::Silu,
qkv_layout: QkvLayout::Separate,
mlp_layout: MlpLayout::GatedSeparate,
qkv_bias: false,
o_proj_bias: false,
mlp_bias: false,
embedding_scale: None,
tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),
rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
max_position_embeddings: get_usize_or(config, "max_position_embeddings", 32_768),
attn_logit_softcapping: None,
final_logit_softcapping: None,
query_pre_attn_scalar: None,
use_post_norms: false,
sliding_window: get_optional_usize(config, "sliding_window"),
alternating_sliding_window: false,
})
}
}
/// Reads a required `usize` field, erroring when the key is absent,
/// not an unsigned integer, or too large for `usize`.
pub(crate) fn get_usize(config: &Value, key: &str) -> Result<usize> {
    let Some(val) = config.get(key).and_then(Value::as_u64) else {
        return Err(MIError::Config(format!("missing or invalid field '{key}'")));
    };
    usize::try_from(val)
        .map_err(|_| MIError::Config(format!("field '{key}' value {val} overflows usize")))
}
/// Reads an optional `usize` field, falling back to `default` when the key
/// is absent, not an unsigned integer, or out of `usize` range.
pub(crate) fn get_usize_or(config: &Value, key: &str, default: usize) -> usize {
    match config.get(key).and_then(Value::as_u64) {
        Some(v) => usize::try_from(v).unwrap_or(default),
        None => default,
    }
}
pub(crate) fn get_optional_usize(config: &Value, key: &str) -> Option<usize> {
config
.get(key)
.and_then(Value::as_u64)
.and_then(|v| usize::try_from(v).ok())
}
/// Reads `key` as an `f64`, falling back to `default` when absent or non-numeric.
pub(crate) fn get_f64_or(config: &Value, key: &str, default: f64) -> f64 {
    match config.get(key).and_then(Value::as_f64) {
        Some(value) => value,
        None => default,
    }
}
pub(crate) fn get_optional_f64(config: &Value, key: &str) -> Option<f64> {
config.get(key).and_then(Value::as_f64)
}
/// Reads `key` as a `bool`, falling back to `default` when absent or non-boolean.
pub(crate) fn get_bool_or(config: &Value, key: &str, default: bool) -> bool {
    if let Some(flag) = config.get(key).and_then(Value::as_bool) {
        flag
    } else {
        default
    }
}
/// Resolves the per-head dimension: an explicit `head_dim` config key wins;
/// otherwise it is derived as `hidden_size / num_attention_heads`.
///
/// Errors when the explicit value overflows `usize` or when the derived
/// value would divide by zero heads.
pub(crate) fn get_head_dim(
    config: &Value,
    hidden_size: usize,
    num_attention_heads: usize,
) -> Result<usize> {
    if let Some(hd) = config.get("head_dim").and_then(Value::as_u64) {
        return usize::try_from(hd)
            .map_err(|_| MIError::Config("head_dim overflows usize".into()));
    }
    if num_attention_heads == 0 {
        return Err(MIError::Config(
            "num_attention_heads is 0, cannot compute head_dim".into(),
        ));
    }
    Ok(hidden_size / num_attention_heads)
}
/// Resolves the configured hidden activation function.
///
/// Checks `hidden_activation` first (Gemma-style key), then `hidden_act`
/// (the common HF key). Unknown or missing values fall back to SiLU, the
/// default for the llama-family architectures handled here.
fn parse_activation_str(config: &Value) -> Activation {
    let act_str = config
        .get("hidden_activation")
        .or_else(|| config.get("hidden_act"))
        .and_then(Value::as_str);
    match act_str {
        // "gelu_new" and "gelu_fast" are HF's other spellings for the tanh
        // approximation of GELU; previously they silently fell back to SiLU,
        // which is the wrong activation family.
        Some("gelu_pytorch_tanh" | "gelu_new" | "gelu_fast") => Activation::GeluApprox,
        Some("gelu") => Activation::Gelu,
        _ => Activation::Silu,
    }
}
/// Reads the tensor names from a `.safetensors` file by parsing only its
/// JSON header (the first 8 bytes encode the header length, little-endian).
///
/// # Errors
/// Returns an error if the file cannot be read, the declared header length
/// is implausible (larger than the file itself), or the header is not a
/// JSON object.
pub fn tensor_names_from_safetensors(path: &Path) -> Result<Vec<String>> {
    let mut file = std::fs::File::open(path)?;
    let mut len_buf = [0u8; 8];
    file.read_exact(&mut len_buf)?;
    let header_len = u64::from_le_bytes(len_buf);
    // Sanity-check the declared header length against the actual file size so
    // a corrupt or truncated file produces a clear error instead of a huge
    // allocation followed by an opaque short-read failure.
    let file_len = file.metadata()?.len();
    if header_len > file_len.saturating_sub(8) {
        return Err(MIError::Config(format!(
            "safetensors header length {header_len} exceeds file size {file_len}"
        )));
    }
    let header_len = usize::try_from(header_len)
        .map_err(|_| MIError::Config("safetensors header length overflows usize".into()))?;
    let mut header_buf = vec![0u8; header_len];
    file.read_exact(&mut header_buf)?;
    let header: Value = serde_json::from_slice(&header_buf)
        .map_err(|e| MIError::Config(format!("failed to parse safetensors header: {e}")))?;
    let obj = header
        .as_object()
        .ok_or_else(|| MIError::Config("safetensors header is not a JSON object".into()))?;
    // "__metadata__" is an optional free-form entry, not a tensor.
    Ok(obj
        .keys()
        .filter(|k| *k != "__metadata__")
        .cloned()
        .collect())
}
/// Reads the tensor names from a `model.safetensors.index.json` file by
/// collecting the keys of its `weight_map` object.
///
/// # Errors
/// Returns an error if the file cannot be read, is not valid JSON, or lacks
/// a `weight_map` object.
pub fn tensor_names_from_index(path: &Path) -> Result<Vec<String>> {
    let content = std::fs::read_to_string(path)?;
    let index: Value = serde_json::from_str(&content)
        .map_err(|e| MIError::Config(format!("failed to parse safetensors index: {e}")))?;
    match index.get("weight_map").and_then(Value::as_object) {
        Some(map) => Ok(map.keys().cloned().collect()),
        None => Err(MIError::Config(
            "missing 'weight_map' in safetensors index".into(),
        )),
    }
}
impl TransformerConfig {
/// Builds a config for an arbitrary checkpoint: uses the dedicated parser
/// when `model_type` is in [`SUPPORTED_MODEL_TYPES`], otherwise falls back
/// to heuristics over the config keys and the checkpoint's tensor names.
pub fn from_hf_config_auto(config: &Value, tensor_names: &[String]) -> Result<Self> {
let model_type = config
.get("model_type")
.and_then(Value::as_str)
.ok_or_else(|| MIError::Config("missing 'model_type' field".into()))?;
if SUPPORTED_MODEL_TYPES.contains(&model_type) {
return Self::from_hf_config(config);
}
Self::parse_auto(config, tensor_names, model_type)
}
/// Infers layouts, biases, and norm flavor from config keys plus the
/// presence/absence of layer-0 tensor names.
#[allow(clippy::cast_precision_loss, clippy::as_conversions)]
fn parse_auto(config: &Value, tensor_names: &[String], model_type: &str) -> Result<Self> {
// Layer 0 is representative: probing one layer avoids scanning all names.
let has_layer0 = |suffix: &str| {
tensor_names
.iter()
.any(|n| n.contains("layers.0.") && n.ends_with(suffix))
};
let hidden_size = get_usize(config, "hidden_size")?;
let num_attention_heads = get_usize(config, "num_attention_heads")?;
// Accept either epsilon key spelling ('rms_norm_eps' or 'norm_epsilon').
let norm_eps = config
.get("rms_norm_eps")
.and_then(Value::as_f64)
.or_else(|| config.get("norm_epsilon").and_then(Value::as_f64))
.unwrap_or(1e-5);
let activation = parse_activation_str(config);
// Respect an explicit 'use_sliding_window: false' (qwen2-style) even when
// a 'sliding_window' size is present.
let sliding_window =
if config.get("use_sliding_window").and_then(Value::as_bool) == Some(false) {
None
} else {
get_optional_usize(config, "sliding_window")
};
// When the config is silent, infer tying from the absence of a separate
// lm_head tensor in the checkpoint.
let tie_word_embeddings = config
.get("tie_word_embeddings")
.and_then(Value::as_bool)
.unwrap_or_else(|| !tensor_names.iter().any(|n| n == "lm_head.weight"));
let attn_logit_softcapping = get_optional_f64(config, "attn_logit_softcapping");
let final_logit_softcapping = get_optional_f64(config, "final_logit_softcapping");
let query_pre_attn_scalar = get_optional_f64(config, "query_pre_attn_scalar");
// Attention layout: a fused qkv_proj tensor marks phi3-style fusion.
let qkv_layout = if has_layer0("self_attn.qkv_proj.weight") {
QkvLayout::Fused
} else {
QkvLayout::Separate
};
// MLP layout, probed from most specific to least; gated-separate is the
// most common default when nothing matches.
let mlp_layout = if has_layer0("mlp.gate_up_proj.weight") {
MlpLayout::GatedFused
} else if has_layer0("mlp.gate_proj.weight") {
MlpLayout::GatedSeparate
} else if has_layer0("mlp.c_fc.weight") {
MlpLayout::Plain
} else {
MlpLayout::GatedSeparate };
// Bias flags come straight from bias tensor presence.
let qkv_bias = has_layer0("self_attn.q_proj.bias") || has_layer0("self_attn.qkv_proj.bias");
let o_proj_bias = has_layer0("self_attn.o_proj.bias");
let mlp_bias = has_layer0("mlp.down_proj.bias")
|| has_layer0("mlp.c_fc.bias")
|| has_layer0("mlp.gate_proj.bias")
|| has_layer0("mlp.gate_up_proj.bias");
// A norm bias tensor implies LayerNorm; RMSNorm has no bias.
let has_norm_bias = has_layer0("input_layernorm.bias");
let base_norm_type = if has_norm_bias {
NormType::LayerNorm
} else {
NormType::RmsNorm
};
let use_post_norms = has_layer0("post_feedforward_layernorm.weight")
|| has_layer0("pre_feedforward_layernorm.weight");
// Gemma derivatives get the Gemma norm flavor and sqrt(hidden) embedding
// scaling regardless of what the tensor probes said.
let is_gemma = model_type.contains("gemma");
let norm_type = if is_gemma {
NormType::GemmaRmsNorm
} else {
base_norm_type
};
let embedding_scale = if is_gemma {
Some((hidden_size as f64).sqrt())
} else {
None
};
// Gemma-2-like models (gemma + extra FFN norms) default to alternating
// sliding windows and a 256 query pre-attention scalar.
let alternating_sliding_window = is_gemma && use_post_norms;
let query_pre_attn_scalar = if is_gemma && use_post_norms {
query_pre_attn_scalar.or(Some(256.0))
} else {
query_pre_attn_scalar
};
Ok(Self {
hidden_size,
num_layers: get_usize(config, "num_hidden_layers")?,
num_attention_heads,
num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
intermediate_size: get_usize(config, "intermediate_size")?,
vocab_size: get_usize(config, "vocab_size")?,
norm_type,
norm_eps,
activation,
qkv_layout,
mlp_layout,
qkv_bias,
o_proj_bias,
mlp_bias,
embedding_scale,
tie_word_embeddings,
rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
max_position_embeddings: get_usize_or(config, "max_position_embeddings", 4096),
attn_logit_softcapping,
final_logit_softcapping,
query_pre_attn_scalar,
use_post_norms,
sliding_window,
alternating_sliding_window,
})
}
}
/// Outcome of a compatibility pre-check: an overall flag plus one
/// human-readable description per detected problem.
#[derive(Debug, Clone)]
pub struct CompatibilityReport {
/// True when no issues were found.
pub compatible: bool,
/// Human-readable descriptions of each detected problem.
pub issues: Vec<String>,
}
impl CompatibilityReport {
    /// Converts the report into a `Result`, folding all issues into a single
    /// bulleted error message when the model is incompatible.
    pub fn into_result(self) -> Result<()> {
        if !self.compatible {
            return Err(MIError::Config(format!(
                "model is not compatible with GenericTransformer:\n - {}",
                self.issues.join("\n - ")
            )));
        }
        Ok(())
    }
}
impl TransformerConfig {
/// Checks that the config JSON carries the numeric fields every parser
/// requires; reports each missing/invalid field as an issue.
#[must_use]
pub fn check_config_fields(config: &Value) -> CompatibilityReport {
let required = [
"hidden_size",
"num_hidden_layers",
"num_attention_heads",
"intermediate_size",
"vocab_size",
];
let mut issues = Vec::new();
for key in &required {
if config.get(*key).and_then(Value::as_u64).is_none() {
issues.push(format!("missing or invalid required field '{key}'"));
}
}
CompatibilityReport {
compatible: issues.is_empty(),
issues,
}
}
/// Full pre-flight check for the heuristic loader: validates config fields
/// and expected tensor names, appending a naming-convention hint when the
/// checkpoint looks like a known-but-unsupported layout.
#[must_use]
pub fn check_auto_compatibility(
config: &Value,
tensor_names: &[String],
) -> CompatibilityReport {
let mut issues = Vec::new();
let field_report = Self::check_config_fields(config);
issues.extend(field_report.issues);
let has_tensor_issues = check_tensor_names(config, tensor_names, &mut issues);
// Only offer the naming-convention hint when structural tensors were
// missing and there are names to diagnose (let-chain syntax).
if has_tensor_issues
&& !tensor_names.is_empty()
&& let Some(hint) = detect_naming_convention(tensor_names)
{
issues.push(hint);
}
CompatibilityReport {
compatible: issues.is_empty(),
issues,
}
}
}
/// Verifies that `tensor_names` contains the HF-standard tensors the generic
/// transformer expects, appending a human-readable issue for each gap.
///
/// Returns true when a *structural* tensor (embedding, norms, attention or
/// MLP projections) is missing. The final tie/lm_head check appends an issue
/// WITHOUT setting the return flag, so the caller's naming-convention hint
/// only fires for structural gaps.
#[allow(clippy::too_many_lines)]
fn check_tensor_names(config: &Value, tensor_names: &[String], issues: &mut Vec<String>) -> bool {
let has = |name: &str| tensor_names.iter().any(|n| n == name);
// Layer 0 stands in for all layers when probing per-layer tensors.
let has_layer0 = |suffix: &str| {
tensor_names
.iter()
.any(|n| n.contains("layers.0.") && n.ends_with(suffix))
};
// Collects up to `limit` tensor names containing `keyword`, used to give
// the user a hint about what the checkpoint calls the missing tensor.
let find_matching = |keyword: &str, limit: usize| -> Vec<&str> {
tensor_names
.iter()
.filter(|n| n.to_lowercase().contains(keyword))
.take(limit)
.map(String::as_str)
.collect::<Vec<_>>()
};
let mut has_issues = false;
if !has("model.embed_tokens.weight") {
has_issues = true;
let found: Vec<&str> = tensor_names
.iter()
.filter(|n| n.contains("embed") || n.contains("wte") || n.contains("word_embeddings"))
.take(3)
.map(String::as_str)
.collect();
let hint = if found.is_empty() {
String::new()
} else {
format!("; found embedding-like tensors: {}", found.join(", "))
};
issues.push(format!(
"missing embedding tensor 'model.embed_tokens.weight'{hint}"
));
}
if !has_layer0("input_layernorm.weight") {
has_issues = true;
let found = find_matching("norm", 4);
let hint = if found.is_empty() {
String::new()
} else {
format!("; found norm-like tensors: {}", found.join(", "))
};
issues.push(format!(
"missing normalization tensor \
'model.layers.0.input_layernorm.weight'{hint}"
));
}
// Gemma-2-style checkpoints use pre_feedforward_layernorm instead, so
// either name satisfies the post-attention norm requirement.
if !has_layer0("post_attention_layernorm.weight")
&& !has_layer0("pre_feedforward_layernorm.weight")
{
has_issues = true;
issues.push(
"missing normalization tensor \
'model.layers.0.post_attention_layernorm.weight'"
.into(),
);
}
if !has("model.norm.weight") {
has_issues = true;
let found: Vec<&str> = tensor_names
.iter()
.filter(|n| {
(n.contains("ln_f") || n.contains("final_layer_norm") || n.contains("ln_out"))
&& n.ends_with(".weight")
})
.take(2)
.map(String::as_str)
.collect();
let hint = if found.is_empty() {
String::new()
} else {
format!("; found final-norm-like tensors: {}", found.join(", "))
};
issues.push(format!(
"missing final norm tensor 'model.norm.weight'{hint}"
));
}
// Either separate q/k/v projections or a fused qkv projection is fine.
let has_separate_attn = has_layer0("self_attn.q_proj.weight");
let has_fused_attn = has_layer0("self_attn.qkv_proj.weight");
if !has_separate_attn && !has_fused_attn {
has_issues = true;
let found = find_matching("attn", 4);
let hint = if found.is_empty() {
String::new()
} else {
format!("; found attention-like tensors: {}", found.join(", "))
};
issues.push(format!(
"missing attention projections: expected \
'self_attn.q_proj.weight' or 'self_attn.qkv_proj.weight'{hint}"
));
}
// Any of the supported MLP layouts (gated-separate, gated-fused, plain)
// satisfies the requirement.
let has_gated_separate = has_layer0("mlp.gate_proj.weight");
let has_gated_fused = has_layer0("mlp.gate_up_proj.weight");
let has_plain = has_layer0("mlp.c_fc.weight");
let has_down = has_layer0("mlp.down_proj.weight");
if !has_gated_separate && !has_gated_fused && !has_plain && !has_down {
has_issues = true;
let found: Vec<&str> = tensor_names
.iter()
.filter(|n| n.contains("mlp") || n.contains("ffn") || n.contains("fc"))
.take(4)
.map(String::as_str)
.collect();
let hint = if found.is_empty() {
String::new()
} else {
format!("; found MLP-like tensors: {}", found.join(", "))
};
issues.push(format!(
"missing MLP projections: expected 'mlp.gate_proj.weight', \
'mlp.gate_up_proj.weight', or 'mlp.c_fc.weight'{hint}"
));
}
// Same tie-inference rule as parse_auto: config wins, otherwise infer from
// the presence of a separate lm_head tensor.
let tie = config
.get("tie_word_embeddings")
.and_then(Value::as_bool)
.unwrap_or_else(|| !tensor_names.iter().any(|n| n == "lm_head.weight"));
// NOTE(review): this issue deliberately does not set has_issues — it is a
// consistency problem, not a naming-convention one; confirm intent.
if !tie && !has("lm_head.weight") {
issues.push("tie_word_embeddings is false but 'lm_head.weight' tensor is missing".into());
}
has_issues
}
/// Heuristically identifies which (unsupported) weight-naming convention a
/// checkpoint uses, returning a human-readable hint when one is recognized
/// or when the names match nothing known. Returns `None` for checkpoints
/// already using HF-standard `model.layers.{i}` naming.
fn detect_naming_convention(tensor_names: &[String]) -> Option<String> {
    // Known non-standard layer prefixes mapped to a family description.
    let patterns: &[(&str, &str)] = &[
        (
            "transformer.h.",
            "GPT-2 / GPT-J / GPT-NeoX (uses 'transformer.h.{i}' prefix)",
        ),
        (
            "transformer.blocks.",
            "Falcon / MPT (uses 'transformer.blocks.{i}' prefix)",
        ),
        (
            "gpt_neox.layers.",
            "GPT-NeoX / Pythia (uses 'gpt_neox.layers.{i}' prefix)",
        ),
        (
            "transformer.layer.",
            "BLOOM (uses 'transformer.layer.{i}' prefix)",
        ),
    ];
    let hit = patterns
        .iter()
        .find(|(prefix, _)| tensor_names.iter().any(|n| n.starts_with(*prefix)));
    if let Some((_, description)) = hit {
        return Some(format!(
            "this model uses {description} — candle-mi currently requires \
             HF-standard 'model.layers.{{i}}' weight naming. \
             Support for this architecture is planned in Phase 9 \
             (tensor name remapping)"
        ));
    }
    let uses_standard_layers = tensor_names.iter().any(|n| n.starts_with("model.layers."));
    if uses_standard_layers {
        return None;
    }
    // Nothing recognized at all: show a small sample to help the user.
    let sample: Vec<&str> = tensor_names.iter().take(5).map(String::as_str).collect();
    Some(format!(
        "weight tensors use an unrecognized naming convention \
         (first 5: {}). candle-mi expects 'model.layers.{{i}}.self_attn.*' / \
         'model.layers.{{i}}.mlp.*' naming",
        sample.join(", ")
    ))
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
// Minimal Llama-3.2-1B-shaped config shared by several tests below.
fn llama_config_json() -> Value {
serde_json::json!({
"model_type": "llama",
"hidden_size": 2048,
"num_hidden_layers": 16,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"intermediate_size": 8192,
"vocab_size": 128256,
"rms_norm_eps": 1e-5,
"rope_theta": 500000.0,
"max_position_embeddings": 131072
})
}
// Full field mapping for llama, including GQA heads and derived head_dim.
#[test]
fn parse_llama_basic() {
let config = TransformerConfig::from_hf_config(&llama_config_json()).unwrap();
assert_eq!(config.hidden_size, 2048);
assert_eq!(config.num_layers, 16);
assert_eq!(config.num_attention_heads, 32);
assert_eq!(config.num_kv_heads, 8);
assert_eq!(config.head_dim, 64);
assert_eq!(config.intermediate_size, 8192);
assert_eq!(config.vocab_size, 128256);
assert_eq!(config.norm_type, NormType::RmsNorm);
assert_eq!(config.activation, Activation::Silu);
assert_eq!(config.qkv_layout, QkvLayout::Separate);
assert_eq!(config.mlp_layout, MlpLayout::GatedSeparate);
assert!(!config.qkv_bias);
assert!(!config.o_proj_bias);
assert!(!config.mlp_bias);
assert!(config.embedding_scale.is_none());
assert!(!config.tie_word_embeddings);
assert!((config.rope_theta - 500_000.0).abs() < f64::EPSILON);
assert!(config.attn_logit_softcapping.is_none());
assert!(config.sliding_window.is_none());
}
// Qwen2 enables Q/K/V bias but not the output projection bias.
#[test]
fn parse_qwen2_bias() {
let json = serde_json::json!({
"model_type": "qwen2",
"hidden_size": 896,
"num_hidden_layers": 24,
"num_attention_heads": 14,
"num_key_value_heads": 2,
"intermediate_size": 4864,
"vocab_size": 151936,
"attention_bias": true,
"tie_word_embeddings": true
});
let config = TransformerConfig::from_hf_config(&json).unwrap();
assert!(config.qkv_bias);
assert!(!config.o_proj_bias);
assert!(config.tie_word_embeddings);
}
// Gemma-2 extras: soft-capping, query scalar, post norms, alternating window.
#[test]
fn parse_gemma2_extensions() {
let json = serde_json::json!({
"model_type": "gemma2",
"hidden_size": 2304,
"num_hidden_layers": 26,
"num_attention_heads": 8,
"num_key_value_heads": 4,
"head_dim": 256,
"intermediate_size": 9216,
"vocab_size": 256000,
"attn_logit_softcapping": 50.0,
"final_logit_softcapping": 30.0,
"query_pre_attn_scalar": 256,
"sliding_window": 4096
});
let config = TransformerConfig::from_hf_config(&json).unwrap();
assert_eq!(config.norm_type, NormType::GemmaRmsNorm);
assert_eq!(config.head_dim, 256);
assert!(config.embedding_scale.is_some());
assert!((config.attn_logit_softcapping.unwrap() - 50.0).abs() < f64::EPSILON);
assert!((config.final_logit_softcapping.unwrap() - 30.0).abs() < f64::EPSILON);
assert!((config.query_pre_attn_scalar.unwrap() - 256.0).abs() < f64::EPSILON);
assert!(config.use_post_norms);
assert_eq!(config.sliding_window, Some(4096));
assert!(config.alternating_sliding_window);
}
// Phi-3 uses fused QKV and fused gate+up MLP projections.
#[test]
fn parse_phi3_fused() {
let json = serde_json::json!({
"model_type": "phi3",
"hidden_size": 3072,
"num_hidden_layers": 32,
"num_attention_heads": 32,
"num_key_value_heads": 32,
"intermediate_size": 8192,
"vocab_size": 32064
});
let config = TransformerConfig::from_hf_config(&json).unwrap();
assert_eq!(config.qkv_layout, QkvLayout::Fused);
assert_eq!(config.mlp_layout, MlpLayout::GatedFused);
}
// StarCoder2: plain MLP, LayerNorm via 'norm_type', biases from 'use_bias'.
#[test]
fn parse_starcoder2_bias_and_plain_mlp() {
let json = serde_json::json!({
"model_type": "starcoder2",
"hidden_size": 3072,
"num_hidden_layers": 30,
"num_attention_heads": 24,
"num_key_value_heads": 2,
"intermediate_size": 12288,
"vocab_size": 49152,
"use_bias": true,
"norm_type": "layer_norm"
});
let config = TransformerConfig::from_hf_config(&json).unwrap();
assert_eq!(config.mlp_layout, MlpLayout::Plain);
assert_eq!(config.activation, Activation::GeluApprox);
assert_eq!(config.norm_type, NormType::LayerNorm);
assert!(config.qkv_bias);
assert!(config.o_proj_bias);
assert!(config.mlp_bias);
}
// Mistral reads 'sliding_window' but never alternates it across layers.
#[test]
fn parse_mistral_sliding_window() {
let json = serde_json::json!({
"model_type": "mistral",
"hidden_size": 4096,
"num_hidden_layers": 32,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"intermediate_size": 14336,
"vocab_size": 32000,
"sliding_window": 4096
});
let config = TransformerConfig::from_hf_config(&json).unwrap();
assert_eq!(config.sliding_window, Some(4096));
assert!(!config.alternating_sliding_window);
}
// Unknown architectures are rejected by the strict (non-auto) entry point.
#[test]
fn unsupported_model_type_errors() {
let json = serde_json::json!({ "model_type": "bert" });
let result = TransformerConfig::from_hf_config(&json);
assert!(result.is_err());
}
// 'model_type' is mandatory for both entry points.
#[test]
fn missing_model_type_errors() {
let json = serde_json::json!({ "hidden_size": 768 });
let result = TransformerConfig::from_hf_config(&json);
assert!(result.is_err());
}
// Turns a slice of string literals into owned `String`s for test fixtures.
fn tensor_names(names: &[&str]) -> Vec<String> {
    names.iter().map(|s| s.to_string()).collect()
}
#[test]
fn auto_config_matches_llama() {
let json = serde_json::json!({
"model_type": "llama",
"hidden_size": 2048,
"num_hidden_layers": 16,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"head_dim": 64,
"intermediate_size": 8192,
"vocab_size": 128256,
"rms_norm_eps": 1e-5,
"rope_theta": 500000.0,
"max_position_embeddings": 131072,
"hidden_act": "silu",
"attention_bias": false,
"mlp_bias": false,
"tie_word_embeddings": true
});
let names = tensor_names(&[
"model.embed_tokens.weight",
"model.layers.0.input_layernorm.weight",
"model.layers.0.mlp.down_proj.weight",
"model.layers.0.mlp.gate_proj.weight",
"model.layers.0.mlp.up_proj.weight",
"model.layers.0.post_attention_layernorm.weight",
"model.layers.0.self_attn.k_proj.weight",
"model.layers.0.self_attn.o_proj.weight",
"model.layers.0.self_attn.q_proj.weight",
"model.layers.0.self_attn.v_proj.weight",
"model.norm.weight",
]);
let manual = TransformerConfig::from_hf_config(&json).unwrap();
let auto = TransformerConfig::parse_auto(&json, &names, "llama").unwrap();
assert_eq!(auto, manual);
}
#[test]
fn auto_config_matches_qwen2() {
let json = serde_json::json!({
"model_type": "qwen2",
"hidden_size": 2048,
"num_hidden_layers": 36,
"num_attention_heads": 16,
"num_key_value_heads": 2,
"intermediate_size": 11008,
"vocab_size": 151936,
"rms_norm_eps": 1e-6,
"rope_theta": 1000000.0,
"max_position_embeddings": 32768,
"hidden_act": "silu",
"tie_word_embeddings": true,
"sliding_window": 32768,
"use_sliding_window": false
});
let names = tensor_names(&[
"model.embed_tokens.weight",
"model.layers.0.input_layernorm.weight",
"model.layers.0.mlp.down_proj.weight",
"model.layers.0.mlp.gate_proj.weight",
"model.layers.0.mlp.up_proj.weight",
"model.layers.0.post_attention_layernorm.weight",
"model.layers.0.self_attn.k_proj.bias",
"model.layers.0.self_attn.k_proj.weight",
"model.layers.0.self_attn.o_proj.weight",
"model.layers.0.self_attn.q_proj.bias",
"model.layers.0.self_attn.q_proj.weight",
"model.layers.0.self_attn.v_proj.bias",
"model.layers.0.self_attn.v_proj.weight",
"model.norm.weight",
]);
let manual = TransformerConfig::from_hf_config(&json).unwrap();
let auto = TransformerConfig::parse_auto(&json, &names, "qwen2").unwrap();
assert_eq!(auto, manual);
}
#[test]
fn auto_config_matches_gemma() {
    // Gemma v1: MHA (kv heads == attention heads), explicit head_dim,
    // tanh-approx GELU spelled via `hidden_activation`.
    let hf_json = serde_json::json!({
        "model_type": "gemma",
        "hidden_size": 3072,
        "num_hidden_layers": 28,
        "num_attention_heads": 16,
        "num_key_value_heads": 16,
        "head_dim": 256,
        "intermediate_size": 24576,
        "vocab_size": 256000,
        "rms_norm_eps": 1e-6,
        "rope_theta": 10000.0,
        "max_position_embeddings": 8192,
        "hidden_activation": "gelu_pytorch_tanh"
    });
    let tensors = tensor_names(&[
        "model.embed_tokens.weight",
        "model.layers.0.input_layernorm.weight",
        "model.layers.0.mlp.down_proj.weight",
        "model.layers.0.mlp.gate_proj.weight",
        "model.layers.0.mlp.up_proj.weight",
        "model.layers.0.post_attention_layernorm.weight",
        "model.layers.0.self_attn.k_proj.weight",
        "model.layers.0.self_attn.o_proj.weight",
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.self_attn.v_proj.weight",
        "model.norm.weight",
    ]);
    let expected = TransformerConfig::from_hf_config(&hf_json).unwrap();
    let detected = TransformerConfig::parse_auto(&hf_json, &tensors, "gemma").unwrap();
    assert_eq!(detected, expected);
}
#[test]
fn auto_config_matches_gemma2() {
    // Gemma 2 adds logit softcapping, pre/post-FFN norms, and sliding-window
    // attention on top of the Gemma v1 layout.
    let hf_json = serde_json::json!({
        "model_type": "gemma2",
        "hidden_size": 2304,
        "num_hidden_layers": 26,
        "num_attention_heads": 8,
        "num_key_value_heads": 4,
        "head_dim": 256,
        "intermediate_size": 9216,
        "vocab_size": 256000,
        "rms_norm_eps": 1e-6,
        "rope_theta": 10000.0,
        "max_position_embeddings": 8192,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_activation": "gelu_pytorch_tanh",
        "attn_logit_softcapping": 50.0,
        "final_logit_softcapping": 30.0,
        "query_pre_attn_scalar": 256,
        "sliding_window": 4096
    });
    let tensors = tensor_names(&[
        "model.embed_tokens.weight",
        "model.layers.0.input_layernorm.weight",
        "model.layers.0.mlp.down_proj.weight",
        "model.layers.0.mlp.gate_proj.weight",
        "model.layers.0.mlp.up_proj.weight",
        "model.layers.0.post_attention_layernorm.weight",
        "model.layers.0.post_feedforward_layernorm.weight",
        "model.layers.0.pre_feedforward_layernorm.weight",
        "model.layers.0.self_attn.k_proj.weight",
        "model.layers.0.self_attn.o_proj.weight",
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.self_attn.v_proj.weight",
        "model.norm.weight",
    ]);
    let expected = TransformerConfig::from_hf_config(&hf_json).unwrap();
    let detected = TransformerConfig::parse_auto(&hf_json, &tensors, "gemma2").unwrap();
    assert_eq!(detected, expected);
}
#[test]
fn auto_config_matches_phi3() {
    // Phi-3 fuses QKV into one tensor and gate/up into one tensor.
    // Known divergence: parse_auto keeps the config's sliding_window while
    // the manual phi3 path drops it, so compare after normalizing that field.
    let hf_json = serde_json::json!({
        "model_type": "phi3",
        "hidden_size": 3072,
        "num_hidden_layers": 32,
        "num_attention_heads": 32,
        "num_key_value_heads": 32,
        "intermediate_size": 8192,
        "vocab_size": 32064,
        "rms_norm_eps": 1e-5,
        "rope_theta": 10000.0,
        "max_position_embeddings": 4096,
        "hidden_act": "silu",
        "tie_word_embeddings": false,
        "sliding_window": 2047,
        "attention_bias": false
    });
    let tensors = tensor_names(&[
        "lm_head.weight",
        "model.embed_tokens.weight",
        "model.layers.0.input_layernorm.weight",
        "model.layers.0.mlp.down_proj.weight",
        "model.layers.0.mlp.gate_up_proj.weight",
        "model.layers.0.post_attention_layernorm.weight",
        "model.layers.0.self_attn.o_proj.weight",
        "model.layers.0.self_attn.qkv_proj.weight",
        "model.norm.weight",
    ]);
    let manual = TransformerConfig::from_hf_config(&hf_json).unwrap();
    let auto = TransformerConfig::parse_auto(&hf_json, &tensors, "phi3").unwrap();
    assert_eq!(manual.sliding_window, None);
    assert_eq!(auto.sliding_window, Some(2047));
    let normalized = TransformerConfig { sliding_window: None, ..auto };
    assert_eq!(normalized, manual);
}
#[test]
fn auto_config_matches_starcoder2() {
    // StarCoder2: LayerNorm (not RMSNorm), biases everywhere, and a plain
    // (non-gated) c_fc/c_proj MLP.
    let hf_json = serde_json::json!({
        "model_type": "starcoder2",
        "hidden_size": 3072,
        "num_hidden_layers": 30,
        "num_attention_heads": 24,
        "num_key_value_heads": 2,
        "intermediate_size": 12288,
        "vocab_size": 49152,
        "norm_epsilon": 1e-5,
        "norm_type": "layer_norm",
        "rope_theta": 999999.4420358813,
        "max_position_embeddings": 16384,
        "hidden_act": "gelu_pytorch_tanh",
        "use_bias": true,
        "sliding_window": 4096
    });
    let tensors = tensor_names(&[
        "model.embed_tokens.weight",
        "model.layers.0.input_layernorm.bias",
        "model.layers.0.input_layernorm.weight",
        "model.layers.0.mlp.c_fc.bias",
        "model.layers.0.mlp.c_fc.weight",
        "model.layers.0.mlp.c_proj.bias",
        "model.layers.0.mlp.c_proj.weight",
        "model.layers.0.post_attention_layernorm.bias",
        "model.layers.0.post_attention_layernorm.weight",
        "model.layers.0.self_attn.k_proj.bias",
        "model.layers.0.self_attn.k_proj.weight",
        "model.layers.0.self_attn.o_proj.bias",
        "model.layers.0.self_attn.o_proj.weight",
        "model.layers.0.self_attn.q_proj.bias",
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.self_attn.v_proj.bias",
        "model.layers.0.self_attn.v_proj.weight",
        "model.norm.bias",
        "model.norm.weight",
    ]);
    let expected = TransformerConfig::from_hf_config(&hf_json).unwrap();
    let detected = TransformerConfig::parse_auto(&hf_json, &tensors, "starcoder2").unwrap();
    assert_eq!(detected, expected);
}
#[test]
fn auto_config_matches_mistral() {
    // Mistral 7B-style config: GQA, untied lm_head, sliding-window attention.
    let hf_json = serde_json::json!({
        "model_type": "mistral",
        "hidden_size": 4096,
        "num_hidden_layers": 32,
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
        "intermediate_size": 14336,
        "vocab_size": 32000,
        "rms_norm_eps": 1e-5,
        "rope_theta": 10000.0,
        "max_position_embeddings": 32768,
        "hidden_act": "silu",
        "tie_word_embeddings": false,
        "sliding_window": 4096
    });
    let tensors = tensor_names(&[
        "lm_head.weight",
        "model.embed_tokens.weight",
        "model.layers.0.input_layernorm.weight",
        "model.layers.0.mlp.down_proj.weight",
        "model.layers.0.mlp.gate_proj.weight",
        "model.layers.0.mlp.up_proj.weight",
        "model.layers.0.post_attention_layernorm.weight",
        "model.layers.0.self_attn.k_proj.weight",
        "model.layers.0.self_attn.o_proj.weight",
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.self_attn.v_proj.weight",
        "model.norm.weight",
    ]);
    let expected = TransformerConfig::from_hf_config(&hf_json).unwrap();
    let detected = TransformerConfig::parse_auto(&hf_json, &tensors, "mistral").unwrap();
    assert_eq!(detected, expected);
}
#[test]
fn auto_config_unknown_model_type() {
    // An unrecognized model_type with llama-style tensors should still be
    // parsed generically, with sensible defaults for every derived field.
    let hf_json = serde_json::json!({
        "model_type": "my_custom_llama",
        "hidden_size": 2048,
        "num_hidden_layers": 16,
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
        "intermediate_size": 8192,
        "vocab_size": 32000,
        "rms_norm_eps": 1e-5,
        "rope_theta": 10000.0,
        "max_position_embeddings": 4096,
        "hidden_act": "silu"
    });
    let tensors = tensor_names(&[
        "lm_head.weight",
        "model.embed_tokens.weight",
        "model.layers.0.input_layernorm.weight",
        "model.layers.0.mlp.down_proj.weight",
        "model.layers.0.mlp.gate_proj.weight",
        "model.layers.0.mlp.up_proj.weight",
        "model.layers.0.post_attention_layernorm.weight",
        "model.layers.0.self_attn.k_proj.weight",
        "model.layers.0.self_attn.o_proj.weight",
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.self_attn.v_proj.weight",
        "model.norm.weight",
    ]);
    let cfg = TransformerConfig::from_hf_config_auto(&hf_json, &tensors).unwrap();
    // Dimensions straight from the JSON (head_dim derived as 2048 / 32).
    assert_eq!(cfg.hidden_size, 2048);
    assert_eq!(cfg.num_layers, 16);
    assert_eq!(cfg.num_attention_heads, 32);
    assert_eq!(cfg.num_kv_heads, 8);
    assert_eq!(cfg.head_dim, 64);
    // Layout inferred from tensor names; the rest are defaults.
    assert_eq!(cfg.norm_type, NormType::RmsNorm);
    assert_eq!(cfg.activation, Activation::Silu);
    assert_eq!(cfg.qkv_layout, QkvLayout::Separate);
    assert_eq!(cfg.mlp_layout, MlpLayout::GatedSeparate);
    assert!(!cfg.qkv_bias);
    assert!(!cfg.o_proj_bias);
    assert!(!cfg.mlp_bias);
    assert!(cfg.embedding_scale.is_none());
    assert!(!cfg.tie_word_embeddings);
    assert!(cfg.sliding_window.is_none());
}
#[test]
fn auto_config_dispatches_known_families() {
    // A known model_type must route through the family-specific parser,
    // yielding the same result as calling it directly.
    let hf_json = llama_config_json();
    let tensors = tensor_names(&["model.embed_tokens.weight"]);
    let via_auto = TransformerConfig::from_hf_config_auto(&hf_json, &tensors).unwrap();
    let via_family = TransformerConfig::from_hf_config(&hf_json).unwrap();
    assert_eq!(via_auto, via_family);
}
#[test]
fn compatibility_check_passes_standard_model() {
    // Minimal llama-convention checkpoint with tied embeddings: no issues.
    let hf_json = serde_json::json!({
        "model_type": "my_custom",
        "hidden_size": 2048,
        "num_hidden_layers": 16,
        "num_attention_heads": 32,
        "intermediate_size": 8192,
        "vocab_size": 32000,
        "tie_word_embeddings": true
    });
    let tensors = tensor_names(&[
        "model.embed_tokens.weight",
        "model.layers.0.input_layernorm.weight",
        "model.layers.0.post_attention_layernorm.weight",
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.mlp.gate_proj.weight",
        "model.norm.weight",
    ]);
    let report = TransformerConfig::check_auto_compatibility(&hf_json, &tensors);
    assert!(report.compatible, "issues: {:?}", report.issues);
}
#[test]
fn compatibility_check_detects_missing_norms() {
    // Tensor list omits input_layernorm, post_attention_layernorm, and the
    // final model.norm — each should surface as a separate issue.
    let hf_json = serde_json::json!({
        "model_type": "olmo",
        "hidden_size": 2048,
        "num_hidden_layers": 16,
        "num_attention_heads": 16,
        "intermediate_size": 8192,
        "vocab_size": 50304
    });
    let tensors = tensor_names(&[
        "model.embed_tokens.weight",
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.mlp.gate_proj.weight",
        "model.layers.0.mlp.down_proj.weight",
    ]);
    let report = TransformerConfig::check_auto_compatibility(&hf_json, &tensors);
    assert!(!report.compatible);
    assert!(report.issues.len() >= 3, "issues: {:?}", report.issues);
    let mentions = |needle: &str| report.issues.iter().any(|issue| issue.contains(needle));
    assert!(mentions("input_layernorm"), "should mention input_layernorm");
    assert!(mentions("model.norm"), "should mention model.norm");
}
#[test]
fn compatibility_check_detects_missing_config_fields() {
    // A config missing required scalar fields is flagged even before any
    // tensor-name checks can run.
    let hf_json = serde_json::json!({
        "model_type": "mystery",
        "hidden_size": 768
    });
    let report = TransformerConfig::check_auto_compatibility(&hf_json, &tensor_names(&[]));
    assert!(!report.compatible);
    let mentions_layers = report
        .issues
        .iter()
        .any(|issue| issue.contains("num_hidden_layers"));
    assert!(mentions_layers, "should mention num_hidden_layers");
}
#[test]
fn compatibility_check_detects_missing_lm_head() {
    // With tie_word_embeddings=false a separate lm_head tensor is mandatory;
    // its absence must be reported.
    let hf_json = serde_json::json!({
        "model_type": "custom",
        "hidden_size": 2048,
        "num_hidden_layers": 16,
        "num_attention_heads": 32,
        "intermediate_size": 8192,
        "vocab_size": 32000,
        "tie_word_embeddings": false
    });
    let tensors = tensor_names(&[
        "model.embed_tokens.weight",
        "model.layers.0.input_layernorm.weight",
        "model.layers.0.post_attention_layernorm.weight",
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.mlp.gate_proj.weight",
        "model.norm.weight",
    ]);
    let report = TransformerConfig::check_auto_compatibility(&hf_json, &tensors);
    assert!(!report.compatible);
    let mentions_head = report.issues.iter().any(|issue| issue.contains("lm_head"));
    assert!(mentions_head, "should mention lm_head");
}
#[test]
fn compatibility_check_config_only() {
    // Config-field check in isolation: complete config passes, a config with
    // only hidden_size reports one issue per missing required field.
    let complete = serde_json::json!({
        "hidden_size": 2048,
        "num_hidden_layers": 16,
        "num_attention_heads": 32,
        "intermediate_size": 8192,
        "vocab_size": 32000
    });
    assert!(TransformerConfig::check_config_fields(&complete).compatible);

    let incomplete = serde_json::json!({
        "hidden_size": 2048
    });
    let report = TransformerConfig::check_config_fields(&incomplete);
    assert!(!report.compatible);
    // Four of the five required fields are absent.
    assert_eq!(report.issues.len(), 4);
}
#[test]
fn compatibility_into_result_error_message() {
    // into_result() on a failing report should produce an error whose text
    // explains the incompatibility.
    let hf_json = serde_json::json!({
        "model_type": "olmo",
        "hidden_size": 2048,
        "num_hidden_layers": 16,
        "num_attention_heads": 16,
        "intermediate_size": 8192,
        "vocab_size": 50304
    });
    let tensors = tensor_names(&[
        "model.embed_tokens.weight",
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.mlp.gate_proj.weight",
    ]);
    let outcome = TransformerConfig::check_auto_compatibility(&hf_json, &tensors).into_result();
    assert!(outcome.is_err());
    let msg = outcome.unwrap_err().to_string();
    assert!(
        msg.contains("not compatible with GenericTransformer"),
        "error should explain incompatibility: {msg}"
    );
}
#[test]
fn compatibility_check_shows_gpt2_naming_hint() {
    // GPT-2 checkpoints use transformer.h.N.* naming; the report should both
    // identify the convention and echo the tensors it actually found.
    let hf_json = serde_json::json!({
        "model_type": "gpt2",
        "hidden_size": 768,
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
        "intermediate_size": 3072,
        "vocab_size": 50257
    });
    let tensors = tensor_names(&[
        "transformer.wte.weight",
        "transformer.wpe.weight",
        "transformer.h.0.ln_1.weight",
        "transformer.h.0.attn.c_attn.weight",
        "transformer.h.0.mlp.c_fc.weight",
        "transformer.ln_f.weight",
    ]);
    let report = TransformerConfig::check_auto_compatibility(&hf_json, &tensors);
    assert!(!report.compatible);
    let mentions = |needle: &str| report.issues.iter().any(|issue| issue.contains(needle));
    assert!(
        mentions("GPT-2"),
        "should detect GPT-2 naming convention: {:?}",
        report.issues
    );
    assert!(
        mentions("transformer.wte.weight"),
        "should show found embedding tensor: {:?}",
        report.issues
    );
    assert!(
        mentions("c_attn"),
        "should show found attention tensor: {:?}",
        report.issues
    );
}
#[test]
fn compatibility_check_shows_found_tensors_for_unknown_naming() {
    // BERT-ish encoder.* names match no known convention; the report should
    // say so and list what it did find to aid debugging.
    let hf_json = serde_json::json!({
        "model_type": "custom_arch",
        "hidden_size": 512,
        "num_hidden_layers": 6,
        "num_attention_heads": 8,
        "intermediate_size": 2048,
        "vocab_size": 30000
    });
    let tensors = tensor_names(&[
        "encoder.layer.0.attention.query.weight",
        "encoder.layer.0.attention.key.weight",
        "encoder.layer.0.ffn.dense.weight",
        "encoder.embeddings.weight",
    ]);
    let report = TransformerConfig::check_auto_compatibility(&hf_json, &tensors);
    assert!(!report.compatible);
    let mentions = |needle: &str| report.issues.iter().any(|issue| issue.contains(needle));
    assert!(
        mentions("unrecognized naming convention"),
        "should flag unrecognized naming: {:?}",
        report.issues
    );
    assert!(
        mentions("encoder.embeddings.weight"),
        "should show found embedding: {:?}",
        report.issues
    );
}
#[test]
fn compatibility_check_shows_found_norm_tensors() {
    // Norm tensors exist but under non-standard names (attention_norm etc.);
    // the report should surface them so the user can see what was found.
    let hf_json = serde_json::json!({
        "model_type": "custom",
        "hidden_size": 2048,
        "num_hidden_layers": 16,
        "num_attention_heads": 32,
        "intermediate_size": 8192,
        "vocab_size": 32000,
        "tie_word_embeddings": true
    });
    let tensors = tensor_names(&[
        "model.embed_tokens.weight",
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.mlp.gate_proj.weight",
        "model.layers.0.attention_norm.weight",
        "model.layers.0.ffn_norm.weight",
        "model.final_norm.weight",
    ]);
    let report = TransformerConfig::check_auto_compatibility(&hf_json, &tensors);
    assert!(!report.compatible);
    let mentions_norm = report
        .issues
        .iter()
        .any(|issue| issue.contains("attention_norm"));
    assert!(mentions_norm, "should show found norm tensors: {:?}", report.issues);
}
}