Skip to main content

entrenar/transformer/weights/
mapping.rs

1//! Weight name mapping between architectures
2
3use super::Architecture;
4
/// Map RoBERTa/CodeBERT weight names to entrenar encoder convention (ENC-006).
///
/// A leading `roberta.` or `bert.` prefix is ignored when matching. Four kinds
/// of names are recognised:
///
/// - the five `embeddings.*` tensors, which map to fixed `encoder.*` names;
/// - `encoder.layer.{i}.<component>...`, which becomes
///   `encoder.layers.{i}.<mapped component>...`;
/// - `pooler.*`, which is re-rooted under `encoder.`;
/// - anything else, which is returned exactly as passed in (prefix included).
///
/// Per-layer component rules are anchored at the start of the per-layer suffix
/// and must end on a `.` boundary, so names that merely *contain* a known
/// component (for example `crossattention.self.query`) are left untouched
/// instead of being partially rewritten.
fn map_roberta_weight_name(name: &str) -> String {
    // Exact-match embedding tensors (after prefix stripping).
    const EMBEDDINGS: [(&str, &str); 5] = [
        ("embeddings.word_embeddings.weight", "encoder.embed_tokens.weight"),
        ("embeddings.position_embeddings.weight", "encoder.position_embeddings.weight"),
        ("embeddings.token_type_embeddings.weight", "encoder.token_type_embeddings.weight"),
        ("embeddings.LayerNorm.weight", "encoder.embeddings_layernorm.weight"),
        ("embeddings.LayerNorm.bias", "encoder.embeddings_layernorm.bias"),
    ];

    // Leading component of a per-layer tensor name -> replacement component.
    // Whatever follows the component (e.g. `.weight` / `.bias`) is kept as is.
    const LAYER_COMPONENTS: [(&str, &str); 8] = [
        ("attention.self.query", "self_attn.q_proj"),
        ("attention.self.key", "self_attn.k_proj"),
        ("attention.self.value", "self_attn.v_proj"),
        ("attention.output.dense", "self_attn.o_proj"),
        ("attention.output.LayerNorm", "input_layernorm"),
        ("intermediate.dense", "mlp.intermediate.dense"),
        ("output.dense", "mlp.output.dense"),
        ("output.LayerNorm", "post_attention_layernorm"),
    ];

    // Strip roberta. or bert. prefix
    let stripped =
        name.strip_prefix("roberta.").or_else(|| name.strip_prefix("bert.")).unwrap_or(name);

    // Embeddings
    if let Some(&(_, target)) = EMBEDDINGS.iter().find(|&&(source, _)| source == stripped) {
        return target.to_string();
    }

    // Encoder layers: encoder.layer.{i}.XXX
    let layer_parts = stripped.strip_prefix("encoder.layer.").and_then(|rest| rest.split_once('.'));
    if let Some((num, layer_rest)) = layer_parts {
        let mapped = LAYER_COMPONENTS
            .iter()
            .find_map(|&(source, target)| {
                layer_rest
                    .strip_prefix(source)
                    // Only accept a match that ends on a component boundary.
                    .filter(|tail| tail.is_empty() || tail.starts_with('.'))
                    .map(|tail| format!("{target}{tail}"))
            })
            // Unknown per-layer tensors keep their original suffix.
            .unwrap_or_else(|| layer_rest.to_string());

        return format!("encoder.layers.{num}.{mapped}");
    }

    // Pooler (optional)
    if stripped.starts_with("pooler.") {
        return format!("encoder.{stripped}");
    }

    // Pass through anything else, deliberately using the unstripped name.
    name.to_string()
}
53
/// Translate a GGUF tensor name into the HF/LLaMA-style name used downstream.
///
/// GGUF uses short names like `token_embd.weight`, `blk.0.attn_q.weight`.
/// The training pipeline expects HF-style names like `model.embed_tokens.weight`.
///
/// Names of the form `blk.{N}.{component}` become `model.layers.{N}.{mapped}`;
/// a component without a table entry is carried over verbatim. Names that match
/// neither a top-level tensor nor the `blk.{N}.` shape are returned unchanged.
fn map_gguf_weight_name(name: &str) -> String {
    // Per-block GGUF component -> HF component.
    const BLOCK_TENSORS: [(&str, &str); 15] = [
        // Attention
        ("attn_q.weight", "self_attn.q_proj.weight"),
        ("attn_k.weight", "self_attn.k_proj.weight"),
        ("attn_v.weight", "self_attn.v_proj.weight"),
        ("attn_output.weight", "self_attn.o_proj.weight"),
        ("attn_q.bias", "self_attn.q_proj.bias"),
        ("attn_k.bias", "self_attn.k_proj.bias"),
        ("attn_v.bias", "self_attn.v_proj.bias"),
        ("attn_output.bias", "self_attn.o_proj.bias"),
        // Norms
        ("attn_norm.weight", "input_layernorm.weight"),
        ("attn_norm.bias", "input_layernorm.bias"),
        ("ffn_norm.weight", "post_attention_layernorm.weight"),
        ("ffn_norm.bias", "post_attention_layernorm.bias"),
        // FFN (Qwen2/LLaMA style)
        ("ffn_gate.weight", "mlp.gate_proj.weight"),
        ("ffn_up.weight", "mlp.up_proj.weight"),
        ("ffn_down.weight", "mlp.down_proj.weight"),
    ];

    // Tensors that live outside any block (embeddings, final norm, LM head).
    let top_level = match name {
        "token_embd.weight" => Some("model.embed_tokens.weight"),
        "output_norm.weight" => Some("model.norm.weight"),
        "output_norm.bias" => Some("model.norm.bias"),
        "output.weight" => Some("lm_head.weight"),
        _ => None,
    };
    if let Some(hf_name) = top_level {
        return hf_name.to_owned();
    }

    // Layer tensors: blk.{N}.{component}
    match name.strip_prefix("blk.").and_then(|tail| tail.split_once('.')) {
        Some((layer_idx, component)) => {
            let hf_component = BLOCK_TENSORS
                .iter()
                .find(|&&(gguf, _)| gguf == component)
                .map_or(component, |&(_, hf)| hf);
            format!("model.layers.{layer_idx}.{hf_component}")
        }
        // Not a block tensor: hand the name back untouched.
        None => name.to_owned(),
    }
}
104
105/// Map weight name from source architecture to standard LLaMA convention
106///
107/// Standard names expected by `Transformer::from_params`:
108/// - `model.embed_tokens.weight`
109/// - `model.layers.{i}.input_layernorm.weight`
110/// - `model.layers.{i}.self_attn.q_proj.weight`
111/// - `model.layers.{i}.self_attn.k_proj.weight`
112/// - `model.layers.{i}.self_attn.v_proj.weight`
113/// - `model.layers.{i}.self_attn.o_proj.weight`
114/// - `model.layers.{i}.post_attention_layernorm.weight`
115/// - `model.layers.{i}.mlp.gate_proj.weight`
116/// - `model.layers.{i}.mlp.up_proj.weight`
117/// - `model.layers.{i}.mlp.down_proj.weight`
118/// - `model.norm.weight`
119/// - `lm_head.weight`
120pub(crate) fn map_weight_name(name: &str, arch: Architecture) -> String {
121    match arch {
122        Architecture::Qwen2 => {
123            // Qwen2 weight names are nearly identical to LLaMA convention
124            // Main differences:
125            // - Qwen2 has bias tensors (which we preserve)
126            // - Some Qwen2 models use "attn" instead of "self_attn" (rare)
127
128            // Handle potential "attn" -> "self_attn" mapping
129            if name.contains(".attn.") && !name.contains(".self_attn.") {
130                name.replace(".attn.", ".self_attn.")
131            } else {
132                name.to_string()
133            }
134        }
135        Architecture::Mistral | Architecture::Llama => {
136            // LLaMA and Mistral use same naming convention as our standard
137            name.to_string()
138        }
139        Architecture::RoBERTa => map_roberta_weight_name(name),
140        Architecture::Gguf => map_gguf_weight_name(name),
141        Architecture::Auto => {
142            // Should not reach here after detection
143            name.to_string()
144        }
145    }
146}