use super::Architecture;
/// Translate a HuggingFace RoBERTa/BERT checkpoint tensor name into this
/// crate's internal encoder naming scheme. Names that match no rule are
/// returned unchanged (including any "roberta."/"bert." prefix they carried).
fn map_roberta_weight_name(name: &str) -> String {
    // Drop the optional model prefix before matching.
    let stripped = name
        .strip_prefix("roberta.")
        .or_else(|| name.strip_prefix("bert."))
        .unwrap_or(name);

    // Embedding tensors map one-to-one by exact name.
    match stripped {
        "embeddings.word_embeddings.weight" => {
            return "encoder.embed_tokens.weight".to_string();
        }
        "embeddings.position_embeddings.weight" => {
            return "encoder.position_embeddings.weight".to_string();
        }
        "embeddings.token_type_embeddings.weight" => {
            return "encoder.token_type_embeddings.weight".to_string();
        }
        "embeddings.LayerNorm.weight" => {
            return "encoder.embeddings_layernorm.weight".to_string();
        }
        "embeddings.LayerNorm.bias" => {
            return "encoder.embeddings_layernorm.bias".to_string();
        }
        _ => {}
    }

    // Per-layer tensors: "encoder.layer.<idx>.<tail>".
    if let Some(rest) = stripped.strip_prefix("encoder.layer.") {
        if let Some((layer_idx, tail)) = rest.split_once('.') {
            // Rule order matters: "attention.output.*" must be rewritten before
            // the bare "output.*" rules so the latter cannot clobber them.
            const RENAME_RULES: [(&str, &str); 8] = [
                ("attention.self.query", "self_attn.q_proj"),
                ("attention.self.key", "self_attn.k_proj"),
                ("attention.self.value", "self_attn.v_proj"),
                ("attention.output.dense", "self_attn.o_proj"),
                ("attention.output.LayerNorm", "input_layernorm"),
                ("intermediate.dense", "mlp.intermediate.dense"),
                ("output.dense", "mlp.output.dense"),
                ("output.LayerNorm", "post_attention_layernorm"),
            ];
            let mapped = RENAME_RULES
                .iter()
                .fold(tail.to_string(), |acc, (from, to)| acc.replace(from, to));
            return format!("encoder.layers.{layer_idx}.{mapped}");
        }
    }

    // Pooler tensors keep their name under the "encoder." namespace.
    if stripped.starts_with("pooler.") {
        return format!("encoder.{stripped}");
    }

    name.to_string()
}
/// Translate a GGUF tensor name into the HuggingFace-style "model.*" naming
/// scheme. Unrecognized names pass through unchanged; unrecognized per-layer
/// tensors keep their GGUF tail under "model.layers.<idx>.".
fn map_gguf_weight_name(name: &str) -> String {
    // Whole-model (non-layer) tensors.
    match name {
        "token_embd.weight" => return "model.embed_tokens.weight".to_string(),
        "output_norm.weight" => return "model.norm.weight".to_string(),
        "output_norm.bias" => return "model.norm.bias".to_string(),
        "output.weight" => return "lm_head.weight".to_string(),
        _ => {}
    }

    // Per-layer tensors: "blk.<idx>.<tensor>".
    if let Some(rest) = name.strip_prefix("blk.") {
        if let Some((layer_idx, tensor)) = rest.split_once('.') {
            const LAYER_TENSORS: [(&str, &str); 15] = [
                ("attn_q.weight", "self_attn.q_proj.weight"),
                ("attn_k.weight", "self_attn.k_proj.weight"),
                ("attn_v.weight", "self_attn.v_proj.weight"),
                ("attn_output.weight", "self_attn.o_proj.weight"),
                ("attn_q.bias", "self_attn.q_proj.bias"),
                ("attn_k.bias", "self_attn.k_proj.bias"),
                ("attn_v.bias", "self_attn.v_proj.bias"),
                ("attn_output.bias", "self_attn.o_proj.bias"),
                ("attn_norm.weight", "input_layernorm.weight"),
                ("attn_norm.bias", "input_layernorm.bias"),
                ("ffn_norm.weight", "post_attention_layernorm.weight"),
                ("ffn_norm.bias", "post_attention_layernorm.bias"),
                ("ffn_gate.weight", "mlp.gate_proj.weight"),
                ("ffn_up.weight", "mlp.up_proj.weight"),
                ("ffn_down.weight", "mlp.down_proj.weight"),
            ];
            // Fall back to the raw GGUF tail when the tensor is not in the table.
            let renamed = LAYER_TENSORS
                .iter()
                .find(|(gguf, _)| *gguf == tensor)
                .map_or(tensor, |(_, hf)| hf);
            return format!("model.layers.{layer_idx}.{renamed}");
        }
    }

    name.to_string()
}
/// Dispatch tensor-name translation by architecture. Architectures whose
/// checkpoints already use the internal naming scheme return the name as-is.
pub(crate) fn map_weight_name(name: &str, arch: Architecture) -> String {
    match arch {
        Architecture::RoBERTa => map_roberta_weight_name(name),
        Architecture::Gguf => map_gguf_weight_name(name),
        Architecture::Qwen2 => {
            // Rewrite ".attn." to ".self_attn.", skipping names that already
            // use the normalized form.
            let needs_rewrite =
                name.contains(".attn.") && !name.contains(".self_attn.");
            if needs_rewrite {
                name.replace(".attn.", ".self_attn.")
            } else {
                name.to_string()
            }
        }
        Architecture::Mistral | Architecture::Llama | Architecture::Auto => name.to_string(),
    }
}