1use crate::config::LLaDA2MoeConfig;
19use anyhow::{Result, anyhow};
20use rlx_core::weight_loader::WeightLoader;
21use std::collections::{HashMap, HashSet};
22
/// Weights of a dense (non-MoE) feed-forward block.
///
/// Each field holds the flattened `f32` data of the corresponding
/// `mlp.{gate,up,down}_proj.weight` checkpoint tensor, read through the
/// loader's `take_transposed`.
#[derive(Debug, Clone)]
pub struct DenseFfnWeights {
    /// `mlp.gate_proj.weight`.
    pub gate: Vec<f32>,
    /// `mlp.up_proj.weight`.
    pub up: Vec<f32>,
    /// `mlp.down_proj.weight`.
    pub down: Vec<f32>,
}
29
/// Weights of a mixture-of-experts feed-forward block.
///
/// The routed experts are packed into one flat buffer per projection. Each
/// expert contributes `hidden_size * expert_ffn_dim` values, stored back to
/// back in expert-index order.
#[derive(Debug, Clone)]
pub struct MoeLayerWeights {
    /// Router weight (`mlp.gate.weight`), read through `take_transposed`.
    pub router: Vec<f32>,
    /// Per-expert routing bias (`mlp.gate.expert_bias`). When that tensor
    /// cannot be taken from the loader, this is `num_experts` zeros.
    pub expert_bias: Vec<f32>,
    /// Packed `gate_proj` weights of all routed experts.
    pub gate_exps: Vec<f32>,
    /// Packed `up_proj` weights of all routed experts.
    pub up_exps: Vec<f32>,
    /// Packed `down_proj` weights of all routed experts.
    pub down_exps: Vec<f32>,
    /// Shared-expert `gate_proj` weight; `Some` only when the config declares
    /// at least one shared expert.
    pub shared_gate: Option<Vec<f32>>,
    /// Shared-expert `up_proj` weight (same presence rule as `shared_gate`).
    pub shared_up: Option<Vec<f32>>,
    /// Shared-expert `down_proj` weight (same presence rule as `shared_gate`).
    pub shared_down: Option<Vec<f32>>,
}
41
/// All weights belonging to one decoder layer.
#[derive(Debug, Clone)]
pub struct LayerWeights {
    /// `input_layernorm.weight`.
    pub input_norm: Vec<f32>,
    /// `post_attention_layernorm.weight`.
    pub post_attn_norm: Vec<f32>,
    /// Fused query/key/value projection (`query_key_value.weight` under the
    /// `attention.` or `self_attn.` prefix), read through `take_transposed`.
    pub qkv: Vec<f32>,
    /// Query norm weight (`query_layernorm.weight`); `Some` only when the
    /// config enables `use_qk_norm`.
    pub q_norm: Option<Vec<f32>>,
    /// Key norm weight (`key_layernorm.weight`); `Some` only when the config
    /// enables `use_qk_norm`.
    pub k_norm: Option<Vec<f32>>,
    /// Attention output projection (`dense.weight`), read through
    /// `take_transposed`.
    pub o_proj: Vec<f32>,
    /// Feed-forward block of this layer (dense or mixture-of-experts).
    pub ffn: LayerFfn,
}
52
/// Feed-forward weights of a decoder layer. The variant is chosen per layer
/// by the config's `is_moe_layer`.
#[derive(Debug, Clone)]
pub enum LayerFfn {
    /// Dense MLP with gate/up/down projections.
    Dense(DenseFfnWeights),
    /// Mixture-of-experts block.
    Moe(MoeLayerWeights),
}
58
/// Complete weight set of an LLaDA2 MoE model, as produced by
/// `LLaDA2Weights::load`.
#[derive(Debug, Clone)]
pub struct LLaDA2Weights {
    /// Token embedding (`model.word_embeddings.weight` or
    /// `model.embed_tokens.weight`). `load` checks that it holds
    /// `vocab_size * hidden_size` values.
    pub embed: Vec<f32>,
    /// Final norm weight (`model.norm.weight`).
    pub final_norm: Vec<f32>,
    /// Output projection: `lm_head.weight`, with the embedding tensors as a
    /// fallback for checkpoints that lack it.
    pub lm_head: Vec<f32>,
    /// One entry per hidden layer, in layer order.
    pub layers: Vec<LayerWeights>,
}
66
67pub fn tensor_keys_for_config(cfg: &LLaDA2MoeConfig) -> HashSet<String> {
69 let mut keys = HashSet::new();
70 keys.insert("model.word_embeddings.weight".into());
71 keys.insert("model.embed_tokens.weight".into());
72 keys.insert("model.norm.weight".into());
73 keys.insert("lm_head.weight".into());
74 for il in 0..cfg.num_hidden_layers {
75 keys.extend(layer_tensor_keys(cfg, il));
76 }
77 keys
78}
79
80fn layer_tensor_keys(cfg: &LLaDA2MoeConfig, il: usize) -> HashSet<String> {
81 let mut keys = HashSet::new();
82 let p = |tail: &str| format!("model.layers.{il}.{tail}");
83 for stem in ["attention", "self_attn"] {
84 keys.insert(p(&format!("{stem}.query_key_value.weight")));
85 keys.insert(p(&format!("{stem}.dense.weight")));
86 if cfg.use_qk_norm {
87 keys.insert(p(&format!("{stem}.query_layernorm.weight")));
88 keys.insert(p(&format!("{stem}.key_layernorm.weight")));
89 }
90 }
91 keys.insert(p("input_layernorm.weight"));
92 keys.insert(p("post_attention_layernorm.weight"));
93 if cfg.is_moe_layer(il) {
94 keys.insert(format!("model.layers.{il}.mlp.gate.weight"));
95 keys.insert(format!("model.layers.{il}.mlp.gate.expert_bias"));
96 for ei in 0..cfg.num_experts {
97 let base = format!("model.layers.{il}.mlp.experts.{ei}");
98 keys.insert(format!("{base}.gate_proj.weight"));
99 keys.insert(format!("{base}.up_proj.weight"));
100 keys.insert(format!("{base}.down_proj.weight"));
101 }
102 if cfg.num_shared_experts.unwrap_or(0) > 0 {
103 keys.insert(format!(
104 "model.layers.{il}.mlp.shared_experts.gate_proj.weight"
105 ));
106 keys.insert(format!(
107 "model.layers.{il}.mlp.shared_experts.up_proj.weight"
108 ));
109 keys.insert(format!(
110 "model.layers.{il}.mlp.shared_experts.down_proj.weight"
111 ));
112 }
113 } else {
114 keys.insert(p("mlp.gate_proj.weight"));
115 keys.insert(p("mlp.up_proj.weight"));
116 keys.insert(p("mlp.down_proj.weight"));
117 }
118 keys
119}
120
121fn take_any(loader: &mut dyn WeightLoader, keys: &[&str]) -> Result<(Vec<f32>, Vec<usize>)> {
122 for key in keys {
123 if let Ok(v) = loader.take(key) {
124 return Ok(v);
125 }
126 }
127 Err(anyhow!("weight not found: {}", keys.join(" | ")))
128}
129
130fn take_transposed_any(
131 loader: &mut dyn WeightLoader,
132 keys: &[&str],
133) -> Result<(Vec<f32>, Vec<usize>)> {
134 for key in keys {
135 if let Ok(v) = loader.take_transposed(key) {
136 return Ok(v);
137 }
138 }
139 Err(anyhow!("weight not found: {}", keys.join(" | ")))
140}
141
142impl LLaDA2Weights {
143 pub fn load(cfg: &LLaDA2MoeConfig, loader: &mut dyn WeightLoader) -> Result<Self> {
144 let h = cfg.hidden_size;
145 let vocab = cfg.vocab_size;
146 let embed = take_any(
147 loader,
148 &["model.word_embeddings.weight", "model.embed_tokens.weight"],
149 )?
150 .0;
151 let final_norm = loader.take("model.norm.weight")?.0;
152 let lm_head = take_any(
153 loader,
154 &[
155 "lm_head.weight",
156 "model.word_embeddings.weight",
157 "model.embed_tokens.weight",
158 ],
159 )?
160 .0;
161
162 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
163 for il in 0..cfg.num_hidden_layers {
164 layers.push(load_layer(cfg, loader, il)?);
165 }
166
167 if embed.len() != vocab * h {
168 return Err(anyhow!(
169 "embed len {} != vocab*hidden ({vocab}*{h})",
170 embed.len()
171 ));
172 }
173 Ok(Self {
174 embed,
175 final_norm,
176 lm_head,
177 layers,
178 })
179 }
180}
181
182fn load_layer(
183 cfg: &LLaDA2MoeConfig,
184 loader: &mut dyn WeightLoader,
185 il: usize,
186) -> Result<LayerWeights> {
187 let p = |tail: &str| format!("model.layers.{il}.{tail}");
188 let h = cfg.hidden_size;
189 let qkv_out = (cfg.num_attention_heads + 2 * cfg.num_kv_heads()) * cfg.head_dim();
190
191 let qkv = take_transposed_any(
192 loader,
193 &[
194 &p("attention.query_key_value.weight"),
195 &p("self_attn.query_key_value.weight"),
196 ],
197 )?
198 .0;
199 let o_proj = take_transposed_any(
200 loader,
201 &[&p("attention.dense.weight"), &p("self_attn.dense.weight")],
202 )?
203 .0;
204
205 let q_norm = if cfg.use_qk_norm {
206 Some(
207 take_any(
208 loader,
209 &[
210 &p("attention.query_layernorm.weight"),
211 &p("self_attn.query_layernorm.weight"),
212 ],
213 )?
214 .0,
215 )
216 } else {
217 None
218 };
219 let k_norm = if cfg.use_qk_norm {
220 Some(
221 take_any(
222 loader,
223 &[
224 &p("attention.key_layernorm.weight"),
225 &p("self_attn.key_layernorm.weight"),
226 ],
227 )?
228 .0,
229 )
230 } else {
231 None
232 };
233
234 if qkv.len() != h * qkv_out {
235 return Err(anyhow!("layer {il} qkv size mismatch"));
236 }
237
238 let ffn = if cfg.is_moe_layer(il) {
239 let e = cfg.num_experts;
240 let ff = cfg.expert_ffn_dim();
241 let router =
242 take_transposed_any(loader, &[&format!("model.layers.{il}.mlp.gate.weight")])?.0;
243 let expert_bias = loader
244 .take(&format!("model.layers.{il}.mlp.gate.expert_bias"))
245 .map(|(d, _)| d)
246 .unwrap_or_else(|_| vec![0f32; e]);
247 let mut gate_exps = vec![0f32; e * h * ff];
248 let mut up_exps = vec![0f32; e * h * ff];
249 let mut down_exps = vec![0f32; e * ff * h];
250 for ei in 0..e {
251 let base = format!("model.layers.{il}.mlp.experts.{ei}");
252 let g = take_transposed_any(loader, &[&format!("{base}.gate_proj.weight")])?.0;
253 let u = take_transposed_any(loader, &[&format!("{base}.up_proj.weight")])?.0;
254 let d = take_transposed_any(loader, &[&format!("{base}.down_proj.weight")])?.0;
255 let stride_in = h * ff;
256 let stride_out = ff * h;
257 gate_exps[ei * stride_in..(ei + 1) * stride_in].copy_from_slice(&g);
258 up_exps[ei * stride_in..(ei + 1) * stride_in].copy_from_slice(&u);
259 down_exps[ei * stride_out..(ei + 1) * stride_out].copy_from_slice(&d);
260 }
261 let (shared_gate, shared_up, shared_down) = if cfg.num_shared_experts.unwrap_or(0) > 0 {
262 let sg = take_transposed_any(
263 loader,
264 &[&format!(
265 "model.layers.{il}.mlp.shared_experts.gate_proj.weight"
266 )],
267 )?
268 .0;
269 let su = take_transposed_any(
270 loader,
271 &[&format!(
272 "model.layers.{il}.mlp.shared_experts.up_proj.weight"
273 )],
274 )?
275 .0;
276 let sd = take_transposed_any(
277 loader,
278 &[&format!(
279 "model.layers.{il}.mlp.shared_experts.down_proj.weight"
280 )],
281 )?
282 .0;
283 (Some(sg), Some(su), Some(sd))
284 } else {
285 (None, None, None)
286 };
287 LayerFfn::Moe(MoeLayerWeights {
288 router,
289 expert_bias,
290 gate_exps,
291 up_exps,
292 down_exps,
293 shared_gate,
294 shared_up,
295 shared_down,
296 })
297 } else {
298 LayerFfn::Dense(DenseFfnWeights {
299 gate: take_transposed_any(loader, &[&p("mlp.gate_proj.weight")])?.0,
300 up: take_transposed_any(loader, &[&p("mlp.up_proj.weight")])?.0,
301 down: take_transposed_any(loader, &[&p("mlp.down_proj.weight")])?.0,
302 })
303 };
304
305 Ok(LayerWeights {
306 input_norm: loader.take(&p("input_layernorm.weight"))?.0,
307 post_attn_norm: loader.take(&p("post_attention_layernorm.weight"))?.0,
308 qkv,
309 q_norm,
310 k_norm,
311 o_proj,
312 ffn,
313 })
314}
315
316pub fn register_params(
318 cfg: &LLaDA2MoeConfig,
319 weights: &LLaDA2Weights,
320 params: &mut HashMap<String, Vec<f32>>,
321) {
322 params.insert("model.embed_tokens.weight".into(), weights.embed.clone());
323 params.insert("model.norm.weight".into(), weights.final_norm.clone());
324 params.insert("lm_head.weight".into(), weights.lm_head.clone());
325 let inv = crate::rope::inv_freq(cfg);
326 let (cos, sin) = crate::rope::build_rope_tables(cfg, &inv, cfg.max_position_embeddings);
327 params.insert("rope.cos".into(), cos);
328 params.insert("rope.sin".into(), sin);
329 for (il, layer) in weights.layers.iter().enumerate() {
330 let p = |t: &str| format!("model.layers.{il}.{t}");
331 params.insert(p("input_layernorm.weight"), layer.input_norm.clone());
332 params.insert(
333 p("post_attention_layernorm.weight"),
334 layer.post_attn_norm.clone(),
335 );
336 params.insert(p("self_attn.query_key_value.weight"), layer.qkv.clone());
337 params.insert(p("self_attn.dense.weight"), layer.o_proj.clone());
338 if let Some(q) = &layer.q_norm {
339 params.insert(p("self_attn.query_layernorm.weight"), q.clone());
340 }
341 if let Some(k) = &layer.k_norm {
342 params.insert(p("self_attn.key_layernorm.weight"), k.clone());
343 }
344 match &layer.ffn {
345 LayerFfn::Dense(d) => {
346 params.insert(p("mlp.gate_proj.weight"), d.gate.clone());
347 params.insert(p("mlp.up_proj.weight"), d.up.clone());
348 params.insert(p("mlp.down_proj.weight"), d.down.clone());
349 }
350 LayerFfn::Moe(m) => {
351 params.insert(p("mlp.gate.weight"), m.router.clone());
352 params.insert(p("mlp.gate.expert_bias"), m.expert_bias.clone());
353 params.insert(p("mlp.gate_exps.weight"), m.gate_exps.clone());
354 params.insert(p("mlp.up_exps.weight"), m.up_exps.clone());
355 params.insert(p("mlp.down_exps.weight"), m.down_exps.clone());
356 if let Some(w) = &m.shared_gate {
357 params.insert(p("mlp.shared_experts.gate_proj.weight"), w.clone());
358 }
359 if let Some(w) = &m.shared_up {
360 params.insert(p("mlp.shared_experts.up_proj.weight"), w.clone());
361 }
362 if let Some(w) = &m.shared_down {
363 params.insert(p("mlp.shared_experts.down_proj.weight"), w.clone());
364 }
365 }
366 }
367 }
368}