use std::collections::HashMap;
use crate::model::Architecture;
/// Bidirectional lookup between HuggingFace tensor names and internal tensor
/// names (for example `model.layers.0.self_attn.q_proj.weight` <->
/// `blk.0.attn_q.weight`).
///
/// Both maps are filled together by [`TensorNameMapper::from_tensor_names`];
/// only names that have a known translation are stored. The `Default` value
/// is an empty mapper.
#[derive(Debug, Clone, Default)]
pub struct TensorNameMapper {
    /// HuggingFace tensor name -> internal tensor name.
    hf_to_internal: HashMap<String, String>,
    /// Internal tensor name -> HuggingFace tensor name.
    internal_to_hf: HashMap<String, String>,
}
impl TensorNameMapper {
/// Builds a mapper from the tensor names of a HuggingFace checkpoint.
///
/// Every entry of `hf_names` is run through the name-translation rules;
/// entries without a known translation are skipped. `architecture` is only
/// consulted through `is_gemma()`, which selects the Gemma-specific rule
/// set. If two HuggingFace names translate to the same internal name, the
/// reverse lookup keeps whichever appears later in `hf_names`.
pub fn from_tensor_names(hf_names: &[String], architecture: Architecture) -> Self {
    let gemma_rules = architecture.is_gemma();
    let mut forward = HashMap::new();
    let mut reverse = HashMap::new();

    // Pair each translatable HF name with its internal counterpart.
    let translated = hf_names.iter().filter_map(|hf| {
        Self::map_hf_to_internal(hf, gemma_rules).map(|internal| (hf, internal))
    });

    for (hf, internal) in translated {
        forward.insert(hf.clone(), internal.clone());
        reverse.insert(internal, hf.clone());
    }

    Self {
        hf_to_internal: forward,
        internal_to_hf: reverse,
    }
}
/// Returns the internal tensor name stored for `hf_name`, or `None` when
/// the mapper holds no translation for it.
pub fn to_internal(&self, hf_name: &str) -> Option<&str> {
    let internal = self.hf_to_internal.get(hf_name)?;
    Some(internal.as_str())
}
/// Returns the HuggingFace tensor name stored for `internal_name`, or
/// `None` when no HuggingFace name was mapped onto it.
pub fn to_hf(&self, internal_name: &str) -> Option<&str> {
    let hf = self.internal_to_hf.get(internal_name)?;
    Some(hf.as_str())
}
pub fn len(&self) -> usize {
self.hf_to_internal.len()
}
pub fn is_empty(&self) -> bool {
self.hf_to_internal.is_empty()
}
/// Translates one HuggingFace tensor name into its internal name.
///
/// A leading `model.language_model.` or `model.` prefix is ignored before
/// matching. Returns `None` for any name without a known translation.
///
/// `is_gemma` only affects `post_attention_layernorm.weight`: it becomes
/// `post_attention_norm.weight` when set and `ffn_norm.weight` otherwise.
///
/// Per-layer names have the form `layers.<N>.<suffix>` and become
/// `blk.<N>.<internal suffix>`. `<N>` must consist solely of ASCII digits:
/// `str::parse::<usize>` on its own also accepts a leading `+`, which would
/// let a malformed name such as `layers.+5.…` alias the internal name of
/// the real layer 5.
fn map_hf_to_internal(hf_name: &str, is_gemma: bool) -> Option<String> {
    let normalized = hf_name
        .strip_prefix("model.language_model.")
        .or_else(|| hf_name.strip_prefix("model."))
        .unwrap_or(hf_name);

    // Tensors that do not belong to a numbered layer.
    let global = match normalized {
        "embed_tokens.weight" => Some("token_embd.weight"),
        "norm.weight" => Some("output_norm.weight"),
        "per_layer_model_projection.weight" | "per_layer_model_proj.weight" => {
            Some("per_layer_model_proj.weight")
        }
        "embed_tokens_per_layer.weight" | "per_layer_token_embd.weight" => {
            Some("per_layer_token_embd.weight")
        }
        "per_layer_projection_norm.weight" | "per_layer_proj_norm.weight" => {
            Some("per_layer_proj_norm.weight")
        }
        _ => None,
    };
    if let Some(internal) = global {
        return Some(internal.to_owned());
    }

    // Compared against the raw name, so only the un-prefixed spelling
    // matches (`model.lm_head.weight` is not translated).
    if hf_name == "lm_head.weight" {
        return Some("output.weight".to_owned());
    }

    let rest = normalized.strip_prefix("layers.")?;
    let (index_text, suffix) = rest.split_once('.')?;

    // Reject anything that is not plain digits before parsing; `parse`
    // still guards against an empty index and against overflow.
    if !index_text.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    let layer_num: usize = index_text.parse().ok()?;

    let internal_suffix = match suffix {
        "self_attn.q_proj.weight" => "attn_q.weight",
        "self_attn.k_proj.weight" => "attn_k.weight",
        "self_attn.v_proj.weight" => "attn_v.weight",
        "self_attn.o_proj.weight" => "attn_output.weight",
        "self_attn.q_proj.bias" => "attn_q.bias",
        "self_attn.k_proj.bias" => "attn_k.bias",
        "self_attn.v_proj.bias" => "attn_v.bias",
        "self_attn.o_proj.bias" => "attn_output.bias",
        "self_attn.q_norm.weight" => "attn_q_norm.weight",
        "self_attn.k_norm.weight" => "attn_k_norm.weight",
        "mlp.gate_proj.weight" => "ffn_gate.weight",
        "mlp.up_proj.weight" => "ffn_up.weight",
        "mlp.down_proj.weight" => "ffn_down.weight",
        "mlp.gate_proj.bias" => "ffn_gate.bias",
        "mlp.up_proj.bias" => "ffn_up.bias",
        "mlp.down_proj.bias" => "ffn_down.bias",
        "input_layernorm.weight" => "attn_norm.weight",
        // The one architecture-dependent rule.
        "post_attention_layernorm.weight" => {
            if is_gemma {
                "post_attention_norm.weight"
            } else {
                "ffn_norm.weight"
            }
        }
        "pre_feedforward_layernorm.weight" => "ffn_norm.weight",
        "post_feedforward_layernorm.weight" => "post_ffw_norm.weight",
        "per_layer_input_gate.weight" => "inp_gate.weight",
        "per_layer_projection.weight" | "per_layer_proj.weight" => "proj.weight",
        "post_per_layer_input_norm.weight" | "per_layer_post_norm.weight" => {
            "post_norm.weight"
        }
        "layer_scalar" | "layer_output_scale" => "layer_output_scale.weight",
        _ => return None,
    };

    Some(format!("blk.{}.{}", layer_num, internal_suffix))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mapper for `arch` from borrowed tensor names.
    fn make_mapper(names: &[&str], arch: Architecture) -> TensorNameMapper {
        let owned: Vec<String> = names.iter().copied().map(String::from).collect();
        TensorNameMapper::from_tensor_names(&owned, arch)
    }

    /// Asserts that each `(hf, internal)` pair translates HF -> internal.
    fn assert_forward(mapper: &TensorNameMapper, expected: &[(&str, &str)]) {
        for &(hf, internal) in expected {
            assert_eq!(mapper.to_internal(hf), Some(internal));
        }
    }

    #[test]
    fn test_llama_basic_mapping() {
        let names = [
            "model.embed_tokens.weight",
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.0.self_attn.k_proj.weight",
            "model.layers.0.self_attn.v_proj.weight",
            "model.layers.0.self_attn.o_proj.weight",
            "model.layers.0.mlp.gate_proj.weight",
            "model.layers.0.mlp.up_proj.weight",
            "model.layers.0.mlp.down_proj.weight",
            "model.layers.0.input_layernorm.weight",
            "model.layers.0.post_attention_layernorm.weight",
            "model.norm.weight",
            "lm_head.weight",
        ];
        let mapper = make_mapper(&names, Architecture::Llama3);

        assert_forward(
            &mapper,
            &[
                ("model.embed_tokens.weight", "token_embd.weight"),
                ("model.layers.0.self_attn.q_proj.weight", "blk.0.attn_q.weight"),
                ("model.layers.0.mlp.gate_proj.weight", "blk.0.ffn_gate.weight"),
                ("model.layers.0.post_attention_layernorm.weight", "blk.0.ffn_norm.weight"),
                ("model.norm.weight", "output_norm.weight"),
                ("lm_head.weight", "output.weight"),
            ],
        );

        // Reverse direction for a global tensor and a per-layer tensor.
        assert_eq!(mapper.to_hf("token_embd.weight"), Some("model.embed_tokens.weight"));
        assert_eq!(
            mapper.to_hf("blk.0.attn_q.weight"),
            Some("model.layers.0.self_attn.q_proj.weight")
        );
    }

    #[test]
    fn test_gemma4_actual_safetensors_names() {
        let names = [
            "model.language_model.embed_tokens.weight",
            "model.language_model.norm.weight",
            "model.language_model.layers.0.self_attn.q_proj.weight",
            "model.language_model.layers.0.self_attn.k_proj.weight",
            "model.language_model.layers.0.self_attn.v_proj.weight",
            "model.language_model.layers.0.self_attn.o_proj.weight",
            "model.language_model.layers.0.self_attn.q_norm.weight",
            "model.language_model.layers.0.self_attn.k_norm.weight",
            "model.language_model.layers.0.mlp.gate_proj.weight",
            "model.language_model.layers.0.mlp.up_proj.weight",
            "model.language_model.layers.0.mlp.down_proj.weight",
            "model.language_model.layers.0.input_layernorm.weight",
            "model.language_model.layers.0.post_attention_layernorm.weight",
            "model.language_model.layers.0.pre_feedforward_layernorm.weight",
            "model.language_model.layers.0.post_feedforward_layernorm.weight",
            "model.language_model.layers.0.per_layer_input_gate.weight",
            "model.language_model.layers.0.per_layer_projection.weight",
            "model.language_model.layers.0.post_per_layer_input_norm.weight",
            "model.language_model.layers.0.layer_scalar",
            "model.language_model.embed_tokens_per_layer.weight",
            "model.language_model.per_layer_model_projection.weight",
            "model.language_model.per_layer_projection_norm.weight",
        ];
        let mapper = make_mapper(&names, Architecture::Gemma4);

        assert_forward(
            &mapper,
            &[
                ("model.language_model.embed_tokens.weight", "token_embd.weight"),
                ("model.language_model.norm.weight", "output_norm.weight"),
                (
                    "model.language_model.layers.0.self_attn.q_proj.weight",
                    "blk.0.attn_q.weight",
                ),
                (
                    "model.language_model.layers.0.self_attn.o_proj.weight",
                    "blk.0.attn_output.weight",
                ),
                (
                    "model.language_model.layers.0.self_attn.q_norm.weight",
                    "blk.0.attn_q_norm.weight",
                ),
                (
                    "model.language_model.layers.0.mlp.gate_proj.weight",
                    "blk.0.ffn_gate.weight",
                ),
                (
                    "model.language_model.layers.0.input_layernorm.weight",
                    "blk.0.attn_norm.weight",
                ),
                (
                    "model.language_model.layers.0.post_attention_layernorm.weight",
                    "blk.0.post_attention_norm.weight",
                ),
                (
                    "model.language_model.layers.0.pre_feedforward_layernorm.weight",
                    "blk.0.ffn_norm.weight",
                ),
                (
                    "model.language_model.layers.0.post_feedforward_layernorm.weight",
                    "blk.0.post_ffw_norm.weight",
                ),
                (
                    "model.language_model.layers.0.per_layer_input_gate.weight",
                    "blk.0.inp_gate.weight",
                ),
                (
                    "model.language_model.layers.0.per_layer_projection.weight",
                    "blk.0.proj.weight",
                ),
                (
                    "model.language_model.layers.0.post_per_layer_input_norm.weight",
                    "blk.0.post_norm.weight",
                ),
                (
                    "model.language_model.layers.0.layer_scalar",
                    "blk.0.layer_output_scale.weight",
                ),
                (
                    "model.language_model.embed_tokens_per_layer.weight",
                    "per_layer_token_embd.weight",
                ),
                (
                    "model.language_model.per_layer_model_projection.weight",
                    "per_layer_model_proj.weight",
                ),
                (
                    "model.language_model.per_layer_projection_norm.weight",
                    "per_layer_proj_norm.weight",
                ),
            ],
        );
    }

    #[test]
    fn test_multi_layer_mapping() {
        let names = [
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.15.self_attn.q_proj.weight",
            "model.layers.34.self_attn.q_proj.weight",
        ];
        let mapper = make_mapper(&names, Architecture::Llama3);

        assert_forward(
            &mapper,
            &[
                ("model.layers.15.self_attn.q_proj.weight", "blk.15.attn_q.weight"),
                ("model.layers.34.self_attn.q_proj.weight", "blk.34.attn_q.weight"),
            ],
        );
    }

    #[test]
    fn test_unknown_name_returns_none() {
        let mapper = make_mapper(&["some.random.tensor"], Architecture::Llama3);

        assert!(mapper.to_internal("some.random.tensor").is_none());
        assert!(mapper.to_hf("nonexistent").is_none());
    }

    #[test]
    fn test_gemma_vs_llama_post_attention_norm() {
        let hf = "model.layers.0.post_attention_layernorm.weight";

        // Same HF name, different internal name depending on architecture.
        let llama = make_mapper(&[hf], Architecture::Llama3);
        assert_forward(&llama, &[(hf, "blk.0.ffn_norm.weight")]);

        let gemma = make_mapper(&[hf], Architecture::Gemma4);
        assert_forward(&gemma, &[(hf, "blk.0.post_attention_norm.weight")]);
    }

    #[test]
    fn test_bias_tensors() {
        let names = [
            "model.layers.0.self_attn.q_proj.bias",
            "model.layers.0.mlp.gate_proj.bias",
        ];
        let mapper = make_mapper(&names, Architecture::Llama3);

        assert_forward(
            &mapper,
            &[
                ("model.layers.0.self_attn.q_proj.bias", "blk.0.attn_q.bias"),
                ("model.layers.0.mlp.gate_proj.bias", "blk.0.ffn_gate.bias"),
            ],
        );
    }
}