Skip to main content

rlx_models_core/
moe_weights.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! GGUF MoE expert-stack loader.
17//!
18//! Companion to `rlx_flow::blocks::MoeFfnStage`. Loads the per-layer
19//! `ffn_{gate,up,down}_exps.weight` stacked tensors that llama.cpp's
20//! GGUF converters ship for Mixtral / Qwen3-MoE / Gemma 4 MoE / etc.,
21//! and validates the `[num_experts, k, n]` shape contract the MoE
22//! block expects.
23//!
24//! Per-family code paths (e.g. `rlx-qwen35::weights`) keep their own
25//! loaders that also handle GGUF K-quant packed slabs; this module is
26//! the **f32-dequant generic** path that's portable across families
27//! that don't (yet) need packed routing.
28
29use anyhow::{Context, Result, anyhow};
30
31use crate::weight_loader::WeightLoader;
32
/// Stacked MoE feed-forward weights for a single layer.
///
/// Every buffer is a flat `f32` vector. The expert stacks follow the
/// shape contract of [`rlx_flow::blocks::MoeFfnStage`] (and the
/// underlying `Op::GroupedMatMul`):
///
/// * `gate` / `up`: `[num_experts, n_embd, n_ff]`
/// * `down`:       `[num_experts, n_ff, n_embd]`
#[derive(Debug, Clone)]
pub struct MoeLayerWeights {
    /// Gate projection stack, `[num_experts, n_embd, n_ff]`.
    pub gate: Vec<f32>,
    /// Up projection stack, `[num_experts, n_embd, n_ff]`.
    pub up: Vec<f32>,
    /// Down projection stack, `[num_experts, n_ff, n_embd]`.
    pub down: Vec<f32>,
    /// Router weight for this layer.
    pub router: Vec<f32>,
    /// Number of experts stacked along the outermost dimension.
    pub num_experts: usize,
    /// Embedding width (`n_embd` in the shapes above).
    pub n_embd: usize,
    /// Feed-forward width (`n_ff` in the shapes above).
    pub n_ff: usize,
}
49
/// Tensor names for the four weights of one MoE layer.
///
/// [`MoeLayerKeys::llama_cpp`] produces the names used by llama.cpp's
/// `qwen2moe` / `qwen3moe` / `gemma4moe` converters:
///
/// * router: `blk.{layer}.ffn_gate_inp.weight`
/// * gate:   `blk.{layer}.ffn_gate_exps.weight`
/// * up:     `blk.{layer}.ffn_up_exps.weight`
/// * down:   `blk.{layer}.ffn_down_exps.weight`
#[derive(Debug, Clone)]
pub struct MoeLayerKeys {
    /// Name of the router weight tensor.
    pub router: String,
    /// Name of the stacked gate-projection tensor.
    pub gate: String,
    /// Name of the stacked up-projection tensor.
    pub up: String,
    /// Name of the stacked down-projection tensor.
    pub down: String,
}

impl MoeLayerKeys {
    /// Keys for layer `layer_idx` in the llama.cpp naming convention.
    pub fn llama_cpp(layer_idx: usize) -> Self {
        // Every name shares the `blk.{layer}.` prefix and `.weight` suffix.
        let name = |stem: &str| format!("blk.{layer_idx}.{stem}.weight");
        Self {
            router: name("ffn_gate_inp"),
            gate: name("ffn_gate_exps"),
            up: name("ffn_up_exps"),
            down: name("ffn_down_exps"),
        }
    }

    /// Keys for layer `layer_idx` in the HuggingFace convention
    /// (`model.layers.{i}.block_sparse_moe.*` / `mlp.experts.*`).
    ///
    /// GGUF files on disk use the llama.cpp names, so unless the loader
    /// carries an MoE-specific HF→GGUF tensor-name resolver, callers
    /// usually want [`Self::llama_cpp`] instead.
    pub fn hf_block_sparse(layer_idx: usize) -> Self {
        let name =
            |leaf: &str| format!("model.layers.{layer_idx}.block_sparse_moe.{leaf}.weight");
        Self {
            router: name("gate"),
            // HF checkpoints store each expert separately, while this
            // loader expects the stacked variant. Callers holding only
            // per-expert HF tensors should pre-stack them with
            // `stack_expert_tensors` first.
            gate: name("experts.gate_proj"),
            up: name("experts.up_proj"),
            down: name("experts.down_proj"),
        }
    }
}
95
96/// Load `[num_experts, k, n]` f32 expert stack from a loader, verifying
97/// shape.
98pub fn load_expert_stack(
99    loader: &mut dyn WeightLoader,
100    key: &str,
101    num_experts: usize,
102    k: usize,
103    n: usize,
104) -> Result<Vec<f32>> {
105    let (data, shape) = loader
106        .take(key)
107        .with_context(|| format!("MoE expert stack `{key}`"))?;
108    let expected = vec![num_experts, k, n];
109    if shape != expected {
110        return Err(anyhow!(
111            "MoE expert stack `{key}`: expected shape {expected:?}, got {shape:?}"
112        ));
113    }
114    let expected_len = num_experts * k * n;
115    if data.len() != expected_len {
116        return Err(anyhow!(
117            "MoE expert stack `{key}`: shape {shape:?} declares \
118             {expected_len} elements but loader returned {}",
119            data.len()
120        ));
121    }
122    Ok(data)
123}
124
125/// Load router weight `[n_embd, num_experts]` f32.
126pub fn load_router(
127    loader: &mut dyn WeightLoader,
128    key: &str,
129    n_embd: usize,
130    num_experts: usize,
131) -> Result<Vec<f32>> {
132    let (data, shape) = loader
133        .take(key)
134        .with_context(|| format!("MoE router `{key}`"))?;
135    let expected = vec![n_embd, num_experts];
136    if shape != expected {
137        return Err(anyhow!(
138            "MoE router `{key}`: expected shape {expected:?}, got {shape:?}"
139        ));
140    }
141    if data.len() != n_embd * num_experts {
142        return Err(anyhow!(
143            "MoE router `{key}`: data len {} != n_embd*num_experts ({})",
144            data.len(),
145            n_embd * num_experts
146        ));
147    }
148    Ok(data)
149}
150
151/// Convenience: load all 4 tensors for one MoE layer at once.
152pub fn load_layer(
153    loader: &mut dyn WeightLoader,
154    keys: &MoeLayerKeys,
155    num_experts: usize,
156    n_embd: usize,
157    n_ff: usize,
158) -> Result<MoeLayerWeights> {
159    let router = load_router(loader, &keys.router, n_embd, num_experts)?;
160    let gate = load_expert_stack(loader, &keys.gate, num_experts, n_embd, n_ff)?;
161    let up = load_expert_stack(loader, &keys.up, num_experts, n_embd, n_ff)?;
162    let down = load_expert_stack(loader, &keys.down, num_experts, n_ff, n_embd)?;
163    Ok(MoeLayerWeights {
164        gate,
165        up,
166        down,
167        router,
168        num_experts,
169        n_embd,
170        n_ff,
171    })
172}
173
174/// Stack `num_experts` rank-2 per-expert tensors into one contiguous
175/// `[num_experts, k, n]` slab in GroupedMatMul layout (expert dim
176/// outermost). Used by HF-style checkpoints that ship per-expert
177/// tensors separately and need to be packed for the MoE block.
178pub fn stack_expert_tensors(
179    per_expert: &[(Vec<f32>, Vec<usize>)],
180) -> Result<(Vec<f32>, Vec<usize>)> {
181    let num_experts = per_expert.len();
182    if num_experts == 0 {
183        return Err(anyhow!("stack_expert_tensors: empty input"));
184    }
185    let first_shape = &per_expert[0].1;
186    if first_shape.len() != 2 {
187        return Err(anyhow!(
188            "stack_expert_tensors: first expert tensor must be rank-2, got {first_shape:?}"
189        ));
190    }
191    let k = first_shape[0];
192    let n = first_shape[1];
193    let per = k * n;
194    let mut out = Vec::with_capacity(num_experts * per);
195    for (idx, (data, shape)) in per_expert.iter().enumerate() {
196        if shape.as_slice() != [k, n] {
197            return Err(anyhow!(
198                "stack_expert_tensors: expert {idx} shape {shape:?} != first expert shape {first_shape:?}"
199            ));
200        }
201        if data.len() != per {
202            return Err(anyhow!(
203                "stack_expert_tensors: expert {idx} data len {} != {per}",
204                data.len()
205            ));
206        }
207        out.extend_from_slice(data);
208    }
209    Ok((out, vec![num_experts, k, n]))
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::weight_loader::WeightLoader;
    use std::collections::HashMap;

    /// In-memory `WeightLoader` for tests: a `HashMap` from tensor name
    /// to `(data, shape)`. `take` removes the entry it returns.
    struct MapLoader {
        tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
    }

    impl MapLoader {
        /// Build a loader holding exactly the given `(name, data, shape)` entries.
        fn from_entries(entries: Vec<(&str, Vec<f32>, Vec<usize>)>) -> Self {
            let tensors = entries
                .into_iter()
                .map(|(name, data, shape)| (name.to_string(), (data, shape)))
                .collect();
            Self { tensors }
        }
    }

    impl WeightLoader for MapLoader {
        fn len(&self) -> usize {
            self.tensors.len()
        }
        fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
            self.tensors
                .remove(key)
                .ok_or_else(|| anyhow!("missing weight: {key}"))
        }
        fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
            self.take(key)
        }
        fn remaining_keys(&self) -> Vec<String> {
            self.tensors.keys().cloned().collect()
        }
    }

    /// Deterministic filler values; `seed` shifts the repeating pattern.
    fn synth_data(n: usize, seed: u64) -> Vec<f32> {
        (0..n)
            .map(|i| ((i as u64 + seed) % 7) as f32 * 0.01)
            .collect()
    }

    #[test]
    fn load_layer_round_trip() {
        let num_experts = 4;
        let n_embd = 8;
        let n_ff = 16;
        let keys = MoeLayerKeys::llama_cpp(0);

        let mut loader = MapLoader::from_entries(vec![
            (
                keys.router.as_str(),
                synth_data(n_embd * num_experts, 1),
                vec![n_embd, num_experts],
            ),
            (
                keys.gate.as_str(),
                synth_data(num_experts * n_embd * n_ff, 2),
                vec![num_experts, n_embd, n_ff],
            ),
            (
                keys.up.as_str(),
                synth_data(num_experts * n_embd * n_ff, 3),
                vec![num_experts, n_embd, n_ff],
            ),
            (
                keys.down.as_str(),
                synth_data(num_experts * n_ff * n_embd, 4),
                vec![num_experts, n_ff, n_embd],
            ),
        ]);

        let w = load_layer(&mut loader, &keys, num_experts, n_embd, n_ff).expect("load_layer");
        assert_eq!(w.num_experts, num_experts);
        assert_eq!(w.n_embd, n_embd);
        assert_eq!(w.n_ff, n_ff);
        assert_eq!(w.gate.len(), num_experts * n_embd * n_ff);
        assert_eq!(w.up.len(), num_experts * n_embd * n_ff);
        assert_eq!(w.down.len(), num_experts * n_ff * n_embd);
        assert_eq!(w.router.len(), n_embd * num_experts);
        // All four tensors were taken out of the loader.
        assert!(loader.remaining_keys().is_empty());
    }

    #[test]
    fn shape_mismatch_errors() {
        // Wrong shape: missing num_experts dim
        let mut loader = MapLoader::from_entries(vec![(
            "blk.0.ffn_gate_exps.weight",
            synth_data(16, 0),
            vec![8, 2],
        )]);
        let err = load_expert_stack(&mut loader, "blk.0.ffn_gate_exps.weight", 4, 8, 2)
            .expect_err("should error on wrong shape");
        assert!(format!("{err:#}").contains("expected shape"));
    }

    #[test]
    fn data_len_mismatch_errors() {
        // Shape is right but the buffer is shorter than the shape declares.
        let mut loader = MapLoader::from_entries(vec![(
            "blk.0.ffn_gate_exps.weight",
            synth_data(10, 0),
            vec![4, 8, 2],
        )]);
        let err = load_expert_stack(&mut loader, "blk.0.ffn_gate_exps.weight", 4, 8, 2)
            .expect_err("should error on short buffer");
        assert!(format!("{err:#}").contains("loader returned 10"));
    }

    #[test]
    fn missing_tensor_reports_key() {
        let mut loader = MapLoader::from_entries(Vec::new());
        let err = load_expert_stack(&mut loader, "blk.0.ffn_up_exps.weight", 4, 8, 2)
            .expect_err("should error on missing tensor");
        let msg = format!("{err:#}");
        assert!(msg.contains("MoE expert stack `blk.0.ffn_up_exps.weight`"));
        assert!(msg.contains("missing weight"));
    }

    #[test]
    fn router_shape_mismatch_errors() {
        // Transposed router: [num_experts, n_embd] instead of [n_embd, num_experts].
        let mut loader = MapLoader::from_entries(vec![(
            "blk.0.ffn_gate_inp.weight",
            synth_data(32, 0),
            vec![4, 8],
        )]);
        let err = load_router(&mut loader, "blk.0.ffn_gate_inp.weight", 8, 4)
            .expect_err("should error on transposed router");
        assert!(format!("{err:#}").contains("expected shape"));
    }

    #[test]
    fn stack_expert_tensors_basic() {
        let per: Vec<(Vec<f32>, Vec<usize>)> =
            (0..3).map(|i| (vec![i as f32; 6], vec![2, 3])).collect();
        let (stacked, shape) = stack_expert_tensors(&per).expect("stack");
        assert_eq!(shape, vec![3, 2, 3]);
        assert_eq!(stacked.len(), 18);
        assert_eq!(&stacked[..6], &[0.0; 6]);
        assert_eq!(&stacked[6..12], &[1.0; 6]);
        assert_eq!(&stacked[12..18], &[2.0; 6]);
    }

    #[test]
    fn stack_expert_tensors_rejects_bad_input() {
        let err = stack_expert_tensors(&[]).expect_err("empty input must fail");
        assert!(format!("{err:#}").contains("empty input"));

        let rank1 = vec![(vec![0.0; 6], vec![6])];
        let err = stack_expert_tensors(&rank1).expect_err("rank-1 must fail");
        assert!(format!("{err:#}").contains("rank-2"));

        let mixed = vec![(vec![0.0; 6], vec![2, 3]), (vec![0.0; 6], vec![3, 2])];
        let err = stack_expert_tensors(&mixed).expect_err("mixed shapes must fail");
        assert!(format!("{err:#}").contains("expert 1 shape"));

        let short = vec![(vec![0.0; 5], vec![2, 3])];
        let err = stack_expert_tensors(&short).expect_err("short data must fail");
        assert!(format!("{err:#}").contains("expert 0 data len 5 != 6"));
    }

    #[test]
    fn keys_use_llama_cpp_convention_by_default() {
        let k = MoeLayerKeys::llama_cpp(5);
        assert_eq!(k.router, "blk.5.ffn_gate_inp.weight");
        assert_eq!(k.gate, "blk.5.ffn_gate_exps.weight");
        assert_eq!(k.up, "blk.5.ffn_up_exps.weight");
        assert_eq!(k.down, "blk.5.ffn_down_exps.weight");
    }
}