Skip to main content

rlx_llada2/llada2/
weights.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

// RLX — LLaDA2 MoE weight layout (HuggingFace / TIDE naming).

18use crate::config::LLaDA2MoeConfig;
19use anyhow::{Result, anyhow};
20use rlx_core::weight_loader::WeightLoader;
21use std::collections::{HashMap, HashSet};
22
/// Feed-forward weights of a dense (non-MoE) decoder layer.
///
/// Each matrix is read through the loader's `take_transposed`, so its layout is whatever
/// that call produces (presumably the transpose of the HF tensor — confirm in `rlx_core`).
#[derive(Debug, Clone)]
pub struct DenseFfnWeights {
    /// Gate projection, from `mlp.gate_proj.weight`.
    pub gate: Vec<f32>,
    /// Up projection, from `mlp.up_proj.weight`.
    pub up: Vec<f32>,
    /// Down projection, from `mlp.down_proj.weight`.
    pub down: Vec<f32>,
}

/// Mixture-of-experts feed-forward weights of one decoder layer.
///
/// Routed-expert projections are packed expert-major into a single buffer per projection,
/// so expert `ei` occupies the `ei`-th equally sized slice.
#[derive(Debug, Clone)]
pub struct MoeLayerWeights {
    /// Router projection, from `mlp.gate.weight` (read via `take_transposed`).
    pub router: Vec<f32>,
    /// Per-expert routing bias, from `mlp.gate.expert_bias`; `num_experts` zeros when the
    /// checkpoint does not provide it.
    pub expert_bias: Vec<f32>,
    /// Gate projections of all routed experts: `num_experts * hidden_size * expert_ffn_dim`
    /// elements, expert-major.
    pub gate_exps: Vec<f32>,
    /// Up projections of all routed experts: `num_experts * hidden_size * expert_ffn_dim`
    /// elements, expert-major.
    pub up_exps: Vec<f32>,
    /// Down projections of all routed experts: `num_experts * expert_ffn_dim * hidden_size`
    /// elements, expert-major.
    pub down_exps: Vec<f32>,
    /// Shared-expert gate projection; `Some` only when `num_shared_experts > 0`.
    pub shared_gate: Option<Vec<f32>>,
    /// Shared-expert up projection; `Some` only when `num_shared_experts > 0`.
    pub shared_up: Option<Vec<f32>>,
    /// Shared-expert down projection; `Some` only when `num_shared_experts > 0`.
    pub shared_down: Option<Vec<f32>>,
}

42#[derive(Debug, Clone)]
43pub struct LayerWeights {
44    pub input_norm: Vec<f32>,
45    pub post_attn_norm: Vec<f32>,
46    pub qkv: Vec<f32>,
47    pub q_norm: Option<Vec<f32>>,
48    pub k_norm: Option<Vec<f32>>,
49    pub o_proj: Vec<f32>,
50    pub ffn: LayerFfn,
51}
52
53#[derive(Debug, Clone)]
54pub enum LayerFfn {
55    Dense(DenseFfnWeights),
56    Moe(MoeLayerWeights),
57}
58
59#[derive(Debug, Clone)]
60pub struct LLaDA2Weights {
61    pub embed: Vec<f32>,
62    pub final_norm: Vec<f32>,
63    pub lm_head: Vec<f32>,
64    pub layers: Vec<LayerWeights>,
65}
66
67/// HF tensor names required to build a graph with `cfg.num_hidden_layers` blocks.
68pub fn tensor_keys_for_config(cfg: &LLaDA2MoeConfig) -> HashSet<String> {
69    let mut keys = HashSet::new();
70    keys.insert("model.word_embeddings.weight".into());
71    keys.insert("model.embed_tokens.weight".into());
72    keys.insert("model.norm.weight".into());
73    keys.insert("lm_head.weight".into());
74    for il in 0..cfg.num_hidden_layers {
75        keys.extend(layer_tensor_keys(cfg, il));
76    }
77    keys
78}
79
80fn layer_tensor_keys(cfg: &LLaDA2MoeConfig, il: usize) -> HashSet<String> {
81    let mut keys = HashSet::new();
82    let p = |tail: &str| format!("model.layers.{il}.{tail}");
83    for stem in ["attention", "self_attn"] {
84        keys.insert(p(&format!("{stem}.query_key_value.weight")));
85        keys.insert(p(&format!("{stem}.dense.weight")));
86        if cfg.use_qk_norm {
87            keys.insert(p(&format!("{stem}.query_layernorm.weight")));
88            keys.insert(p(&format!("{stem}.key_layernorm.weight")));
89        }
90    }
91    keys.insert(p("input_layernorm.weight"));
92    keys.insert(p("post_attention_layernorm.weight"));
93    if cfg.is_moe_layer(il) {
94        keys.insert(format!("model.layers.{il}.mlp.gate.weight"));
95        keys.insert(format!("model.layers.{il}.mlp.gate.expert_bias"));
96        for ei in 0..cfg.num_experts {
97            let base = format!("model.layers.{il}.mlp.experts.{ei}");
98            keys.insert(format!("{base}.gate_proj.weight"));
99            keys.insert(format!("{base}.up_proj.weight"));
100            keys.insert(format!("{base}.down_proj.weight"));
101        }
102        if cfg.num_shared_experts.unwrap_or(0) > 0 {
103            keys.insert(format!(
104                "model.layers.{il}.mlp.shared_experts.gate_proj.weight"
105            ));
106            keys.insert(format!(
107                "model.layers.{il}.mlp.shared_experts.up_proj.weight"
108            ));
109            keys.insert(format!(
110                "model.layers.{il}.mlp.shared_experts.down_proj.weight"
111            ));
112        }
113    } else {
114        keys.insert(p("mlp.gate_proj.weight"));
115        keys.insert(p("mlp.up_proj.weight"));
116        keys.insert(p("mlp.down_proj.weight"));
117    }
118    keys
119}
120
121fn take_any(loader: &mut dyn WeightLoader, keys: &[&str]) -> Result<(Vec<f32>, Vec<usize>)> {
122    for key in keys {
123        if let Ok(v) = loader.take(key) {
124            return Ok(v);
125        }
126    }
127    Err(anyhow!("weight not found: {}", keys.join(" | ")))
128}
129
130fn take_transposed_any(
131    loader: &mut dyn WeightLoader,
132    keys: &[&str],
133) -> Result<(Vec<f32>, Vec<usize>)> {
134    for key in keys {
135        if let Ok(v) = loader.take_transposed(key) {
136            return Ok(v);
137        }
138    }
139    Err(anyhow!("weight not found: {}", keys.join(" | ")))
140}
141
142impl LLaDA2Weights {
143    pub fn load(cfg: &LLaDA2MoeConfig, loader: &mut dyn WeightLoader) -> Result<Self> {
144        let h = cfg.hidden_size;
145        let vocab = cfg.vocab_size;
146        let embed = take_any(
147            loader,
148            &["model.word_embeddings.weight", "model.embed_tokens.weight"],
149        )?
150        .0;
151        let final_norm = loader.take("model.norm.weight")?.0;
152        let lm_head = take_any(
153            loader,
154            &[
155                "lm_head.weight",
156                "model.word_embeddings.weight",
157                "model.embed_tokens.weight",
158            ],
159        )?
160        .0;
161
162        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
163        for il in 0..cfg.num_hidden_layers {
164            layers.push(load_layer(cfg, loader, il)?);
165        }
166
167        if embed.len() != vocab * h {
168            return Err(anyhow!(
169                "embed len {} != vocab*hidden ({vocab}*{h})",
170                embed.len()
171            ));
172        }
173        Ok(Self {
174            embed,
175            final_norm,
176            lm_head,
177            layers,
178        })
179    }
180}
181
182fn load_layer(
183    cfg: &LLaDA2MoeConfig,
184    loader: &mut dyn WeightLoader,
185    il: usize,
186) -> Result<LayerWeights> {
187    let p = |tail: &str| format!("model.layers.{il}.{tail}");
188    let h = cfg.hidden_size;
189    let qkv_out = (cfg.num_attention_heads + 2 * cfg.num_kv_heads()) * cfg.head_dim();
190
191    let qkv = take_transposed_any(
192        loader,
193        &[
194            &p("attention.query_key_value.weight"),
195            &p("self_attn.query_key_value.weight"),
196        ],
197    )?
198    .0;
199    let o_proj = take_transposed_any(
200        loader,
201        &[&p("attention.dense.weight"), &p("self_attn.dense.weight")],
202    )?
203    .0;
204
205    let q_norm = if cfg.use_qk_norm {
206        Some(
207            take_any(
208                loader,
209                &[
210                    &p("attention.query_layernorm.weight"),
211                    &p("self_attn.query_layernorm.weight"),
212                ],
213            )?
214            .0,
215        )
216    } else {
217        None
218    };
219    let k_norm = if cfg.use_qk_norm {
220        Some(
221            take_any(
222                loader,
223                &[
224                    &p("attention.key_layernorm.weight"),
225                    &p("self_attn.key_layernorm.weight"),
226                ],
227            )?
228            .0,
229        )
230    } else {
231        None
232    };
233
234    if qkv.len() != h * qkv_out {
235        return Err(anyhow!("layer {il} qkv size mismatch"));
236    }
237
238    let ffn = if cfg.is_moe_layer(il) {
239        let e = cfg.num_experts;
240        let ff = cfg.expert_ffn_dim();
241        let router =
242            take_transposed_any(loader, &[&format!("model.layers.{il}.mlp.gate.weight")])?.0;
243        let expert_bias = loader
244            .take(&format!("model.layers.{il}.mlp.gate.expert_bias"))
245            .map(|(d, _)| d)
246            .unwrap_or_else(|_| vec![0f32; e]);
247        let mut gate_exps = vec![0f32; e * h * ff];
248        let mut up_exps = vec![0f32; e * h * ff];
249        let mut down_exps = vec![0f32; e * ff * h];
250        for ei in 0..e {
251            let base = format!("model.layers.{il}.mlp.experts.{ei}");
252            let g = take_transposed_any(loader, &[&format!("{base}.gate_proj.weight")])?.0;
253            let u = take_transposed_any(loader, &[&format!("{base}.up_proj.weight")])?.0;
254            let d = take_transposed_any(loader, &[&format!("{base}.down_proj.weight")])?.0;
255            let stride_in = h * ff;
256            let stride_out = ff * h;
257            gate_exps[ei * stride_in..(ei + 1) * stride_in].copy_from_slice(&g);
258            up_exps[ei * stride_in..(ei + 1) * stride_in].copy_from_slice(&u);
259            down_exps[ei * stride_out..(ei + 1) * stride_out].copy_from_slice(&d);
260        }
261        let (shared_gate, shared_up, shared_down) = if cfg.num_shared_experts.unwrap_or(0) > 0 {
262            let sg = take_transposed_any(
263                loader,
264                &[&format!(
265                    "model.layers.{il}.mlp.shared_experts.gate_proj.weight"
266                )],
267            )?
268            .0;
269            let su = take_transposed_any(
270                loader,
271                &[&format!(
272                    "model.layers.{il}.mlp.shared_experts.up_proj.weight"
273                )],
274            )?
275            .0;
276            let sd = take_transposed_any(
277                loader,
278                &[&format!(
279                    "model.layers.{il}.mlp.shared_experts.down_proj.weight"
280                )],
281            )?
282            .0;
283            (Some(sg), Some(su), Some(sd))
284        } else {
285            (None, None, None)
286        };
287        LayerFfn::Moe(MoeLayerWeights {
288            router,
289            expert_bias,
290            gate_exps,
291            up_exps,
292            down_exps,
293            shared_gate,
294            shared_up,
295            shared_down,
296        })
297    } else {
298        LayerFfn::Dense(DenseFfnWeights {
299            gate: take_transposed_any(loader, &[&p("mlp.gate_proj.weight")])?.0,
300            up: take_transposed_any(loader, &[&p("mlp.up_proj.weight")])?.0,
301            down: take_transposed_any(loader, &[&p("mlp.down_proj.weight")])?.0,
302        })
303    };
304
305    Ok(LayerWeights {
306        input_norm: loader.take(&p("input_layernorm.weight"))?.0,
307        post_attn_norm: loader.take(&p("post_attention_layernorm.weight"))?.0,
308        qkv,
309        q_norm,
310        k_norm,
311        o_proj,
312        ffn,
313    })
314}
315
316/// Register all tensors into `params` for graph compile.
317pub fn register_params(
318    cfg: &LLaDA2MoeConfig,
319    weights: &LLaDA2Weights,
320    params: &mut HashMap<String, Vec<f32>>,
321) {
322    params.insert("model.embed_tokens.weight".into(), weights.embed.clone());
323    params.insert("model.norm.weight".into(), weights.final_norm.clone());
324    params.insert("lm_head.weight".into(), weights.lm_head.clone());
325    let inv = crate::rope::inv_freq(cfg);
326    let (cos, sin) = crate::rope::build_rope_tables(cfg, &inv, cfg.max_position_embeddings);
327    params.insert("rope.cos".into(), cos);
328    params.insert("rope.sin".into(), sin);
329    for (il, layer) in weights.layers.iter().enumerate() {
330        let p = |t: &str| format!("model.layers.{il}.{t}");
331        params.insert(p("input_layernorm.weight"), layer.input_norm.clone());
332        params.insert(
333            p("post_attention_layernorm.weight"),
334            layer.post_attn_norm.clone(),
335        );
336        params.insert(p("self_attn.query_key_value.weight"), layer.qkv.clone());
337        params.insert(p("self_attn.dense.weight"), layer.o_proj.clone());
338        if let Some(q) = &layer.q_norm {
339            params.insert(p("self_attn.query_layernorm.weight"), q.clone());
340        }
341        if let Some(k) = &layer.k_norm {
342            params.insert(p("self_attn.key_layernorm.weight"), k.clone());
343        }
344        match &layer.ffn {
345            LayerFfn::Dense(d) => {
346                params.insert(p("mlp.gate_proj.weight"), d.gate.clone());
347                params.insert(p("mlp.up_proj.weight"), d.up.clone());
348                params.insert(p("mlp.down_proj.weight"), d.down.clone());
349            }
350            LayerFfn::Moe(m) => {
351                params.insert(p("mlp.gate.weight"), m.router.clone());
352                params.insert(p("mlp.gate.expert_bias"), m.expert_bias.clone());
353                params.insert(p("mlp.gate_exps.weight"), m.gate_exps.clone());
354                params.insert(p("mlp.up_exps.weight"), m.up_exps.clone());
355                params.insert(p("mlp.down_exps.weight"), m.down_exps.clone());
356                if let Some(w) = &m.shared_gate {
357                    params.insert(p("mlp.shared_experts.gate_proj.weight"), w.clone());
358                }
359                if let Some(w) = &m.shared_up {
360                    params.insert(p("mlp.shared_experts.up_proj.weight"), w.clone());
361                }
362                if let Some(w) = &m.shared_down {
363                    params.insert(p("mlp.shared_experts.down_proj.weight"), w.clone());
364                }
365            }
366        }
367    }
368}