// entrenar/transformer/weights/mapping.rs
use super::Architecture;

/// Translates a HuggingFace RoBERTa/BERT checkpoint tensor name into the
/// internal encoder naming scheme.
///
/// A leading `roberta.` or `bert.` model prefix is ignored while matching.
/// Recognised names are rewritten as follows:
///
/// * the five `embeddings.*` tensors map to fixed `encoder.*` names;
/// * `encoder.layer.<n>.<rest>` maps to `encoder.layers.<n>.<rest'>`, where
///   `<rest'>` is `<rest>` after the substring renames in `LAYER_RENAMES`;
/// * `pooler.*` maps to `encoder.pooler.*`.
///
/// Any other name is returned unchanged, including its original prefix.
fn map_roberta_weight_name(name: &str) -> String {
    // Substring rewrites applied in sequence to the part of a layer tensor name
    // that follows the layer index. Order matters: the `attention.output.*`
    // entries must run before the bare `output.*` entries, because
    // `output.dense` and `output.LayerNorm` are substrings of the attention names.
    const LAYER_RENAMES: [(&str, &str); 8] = [
        ("attention.self.query", "self_attn.q_proj"),
        ("attention.self.key", "self_attn.k_proj"),
        ("attention.self.value", "self_attn.v_proj"),
        ("attention.output.dense", "self_attn.o_proj"),
        ("attention.output.LayerNorm", "input_layernorm"),
        ("intermediate.dense", "mlp.intermediate.dense"),
        ("output.dense", "mlp.output.dense"),
        ("output.LayerNorm", "post_attention_layernorm"),
    ];

    // Drop the model-level prefix, trying `roberta.` first, then `bert.`.
    let body = ["roberta.", "bert."]
        .iter()
        .find_map(|&prefix| name.strip_prefix(prefix))
        .unwrap_or(name);

    let embedding = match body {
        "embeddings.word_embeddings.weight" => Some("encoder.embed_tokens.weight"),
        "embeddings.position_embeddings.weight" => Some("encoder.position_embeddings.weight"),
        "embeddings.token_type_embeddings.weight" => Some("encoder.token_type_embeddings.weight"),
        "embeddings.LayerNorm.weight" => Some("encoder.embeddings_layernorm.weight"),
        "embeddings.LayerNorm.bias" => Some("encoder.embeddings_layernorm.bias"),
        _ => None,
    };
    if let Some(mapped) = embedding {
        return mapped.to_owned();
    }

    // `encoder.layer.<index>.<tail>`; a name with no `.` after the index falls through.
    if let Some((index, tail)) = body
        .strip_prefix("encoder.layer.")
        .and_then(|rest| rest.split_once('.'))
    {
        let renamed = LAYER_RENAMES
            .iter()
            .fold(tail.to_owned(), |acc, &(from, to)| acc.replace(from, to));
        return format!("encoder.layers.{index}.{renamed}");
    }

    if body.starts_with("pooler.") {
        format!("encoder.{body}")
    } else {
        name.to_owned()
    }
}
/// Translates a GGUF tensor name into the HuggingFace-style Llama naming
/// scheme.
///
/// * Global tensors: `token_embd.weight` -> `model.embed_tokens.weight`,
///   `output_norm.{weight,bias}` -> `model.norm.{weight,bias}`,
///   `output.weight` -> `lm_head.weight`.
/// * Block tensors `blk.<n>.<tensor>.<param>`, where `<param>` is `weight` or
///   `bias` and `<tensor>` is a known attention/norm/MLP tensor, become
///   `model.layers.<n>.<module>.<param>`.
/// * Block tensors with an unknown `<tensor>` or `<param>` keep their GGUF
///   suffix: `blk.<n>.<rest>` -> `model.layers.<n>.<rest>`.
/// * Every other name is returned unchanged.
fn map_gguf_weight_name(name: &str) -> String {
    // Tensors that live outside the per-block namespace.
    match name {
        "token_embd.weight" => return "model.embed_tokens.weight".to_string(),
        "output_norm.weight" => return "model.norm.weight".to_string(),
        "output_norm.bias" => return "model.norm.bias".to_string(),
        "output.weight" => return "lm_head.weight".to_string(),
        _ => {}
    }

    // `blk.<num>.<layer_rest>`; a name with no `.` after the index falls through.
    if let Some((num, layer_rest)) = name
        .strip_prefix("blk.")
        .and_then(|rest| rest.split_once('.'))
    {
        // Split `<tensor>.<param>` so weights and biases share one module table.
        // Listing only some `.bias` variants explicitly used to leave MLP biases
        // (e.g. `ffn_up.bias`) under their raw GGUF name.
        let mapped = layer_rest
            .rsplit_once('.')
            .filter(|&(_, param)| param == "weight" || param == "bias")
            .and_then(|(tensor, param)| {
                let module = match tensor {
                    "attn_q" => "self_attn.q_proj",
                    "attn_k" => "self_attn.k_proj",
                    "attn_v" => "self_attn.v_proj",
                    "attn_output" => "self_attn.o_proj",
                    "attn_norm" => "input_layernorm",
                    "ffn_norm" => "post_attention_layernorm",
                    "ffn_gate" => "mlp.gate_proj",
                    "ffn_up" => "mlp.up_proj",
                    "ffn_down" => "mlp.down_proj",
                    _ => return None,
                };
                Some(format!("{module}.{param}"))
            })
            // Unknown tensors keep their GGUF suffix, as before.
            .unwrap_or_else(|| layer_rest.to_string());
        return format!("model.layers.{num}.{mapped}");
    }

    name.to_string()
}
105pub(crate) fn map_weight_name(name: &str, arch: Architecture) -> String {
121 match arch {
122 Architecture::Qwen2 => {
123 if name.contains(".attn.") && !name.contains(".self_attn.") {
130 name.replace(".attn.", ".self_attn.")
131 } else {
132 name.to_string()
133 }
134 }
135 Architecture::Mistral | Architecture::Llama => {
136 name.to_string()
138 }
139 Architecture::RoBERTa => map_roberta_weight_name(name),
140 Architecture::Gguf => map_gguf_weight_name(name),
141 Architecture::Auto => {
142 name.to_string()
144 }
145 }
146}