1use anyhow::{Context, Result, anyhow};
30
31use crate::weight_loader::WeightLoader;
32
/// Weights for a single mixture-of-experts (MoE) feed-forward layer.
///
/// Every tensor is stored as a flat `f32` buffer. When a value is produced by
/// [`load_layer`], the loader-reported shapes have been validated against the
/// dimensions recorded in `num_experts`, `n_embd` and `n_ff` (see the
/// per-field notes below). Constructing the struct by hand performs no such
/// validation.
///
/// `PartialEq` is derived so that loaded layers can be compared directly
/// (for example with `assert_eq!` in round-trip tests).
#[derive(Debug, Clone, PartialEq)]
pub struct MoeLayerWeights {
    /// Stacked expert "gate" tensor; `load_layer` validates its shape as
    /// `[num_experts, n_embd, n_ff]`.
    pub gate: Vec<f32>,
    /// Stacked expert "up" tensor; `load_layer` validates its shape as
    /// `[num_experts, n_embd, n_ff]`.
    pub up: Vec<f32>,
    /// Stacked expert "down" tensor; `load_layer` validates its shape as
    /// `[num_experts, n_ff, n_embd]`.
    pub down: Vec<f32>,
    /// Router tensor; `load_layer` validates its shape as
    /// `[n_embd, num_experts]`.
    pub router: Vec<f32>,
    /// Number of experts stacked in `gate`, `up` and `down`.
    pub num_experts: usize,
    /// Embedding dimension used when validating the tensors.
    pub n_embd: usize,
    /// Feed-forward dimension used when validating the tensors.
    pub n_ff: usize,
}
49
/// Names of the four tensors that make up one MoE layer inside a weight store.
///
/// The struct is a plain bundle of lookup keys; it derives `PartialEq`, `Eq`
/// and `Hash` so key sets can be compared or used in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MoeLayerKeys {
    /// Key of the router tensor.
    pub router: String,
    /// Key of the stacked expert "gate" tensor.
    pub gate: String,
    /// Key of the stacked expert "up" tensor.
    pub up: String,
    /// Key of the stacked expert "down" tensor.
    pub down: String,
}

impl MoeLayerKeys {
    /// Builds the keys for layer `layer_idx` under the `blk.{layer_idx}` prefix.
    ///
    /// NOTE(review): the suffixes are presumably the llama.cpp / GGUF tensor
    /// names (as the constructor name suggests) — confirm against the files
    /// being loaded.
    pub fn llama_cpp(layer_idx: usize) -> Self {
        let p = format!("blk.{layer_idx}");
        Self {
            router: format!("{p}.ffn_gate_inp.weight"),
            gate: format!("{p}.ffn_gate_exps.weight"),
            up: format!("{p}.ffn_up_exps.weight"),
            down: format!("{p}.ffn_down_exps.weight"),
        }
    }

    /// Builds the keys for layer `layer_idx` under the
    /// `model.layers.{layer_idx}.block_sparse_moe` prefix.
    ///
    /// The expert keys name a single stacked tensor per projection
    /// (`experts.gate_proj.weight`, ...), not one tensor per expert.
    /// NOTE(review): verify that the checkpoint actually stores (or the
    /// loader synthesizes) stacked tensors under these names.
    pub fn hf_block_sparse(layer_idx: usize) -> Self {
        let p = format!("model.layers.{layer_idx}.block_sparse_moe");
        Self {
            router: format!("{p}.gate.weight"),
            gate: format!("{p}.experts.gate_proj.weight"),
            up: format!("{p}.experts.up_proj.weight"),
            down: format!("{p}.experts.down_proj.weight"),
        }
    }
}
95
96pub fn load_expert_stack(
99 loader: &mut dyn WeightLoader,
100 key: &str,
101 num_experts: usize,
102 k: usize,
103 n: usize,
104) -> Result<Vec<f32>> {
105 let (data, shape) = loader
106 .take(key)
107 .with_context(|| format!("MoE expert stack `{key}`"))?;
108 let expected = vec![num_experts, k, n];
109 if shape != expected {
110 return Err(anyhow!(
111 "MoE expert stack `{key}`: expected shape {expected:?}, got {shape:?}"
112 ));
113 }
114 let expected_len = num_experts * k * n;
115 if data.len() != expected_len {
116 return Err(anyhow!(
117 "MoE expert stack `{key}`: shape {shape:?} declares \
118 {expected_len} elements but loader returned {}",
119 data.len()
120 ));
121 }
122 Ok(data)
123}
124
125pub fn load_router(
127 loader: &mut dyn WeightLoader,
128 key: &str,
129 n_embd: usize,
130 num_experts: usize,
131) -> Result<Vec<f32>> {
132 let (data, shape) = loader
133 .take(key)
134 .with_context(|| format!("MoE router `{key}`"))?;
135 let expected = vec![n_embd, num_experts];
136 if shape != expected {
137 return Err(anyhow!(
138 "MoE router `{key}`: expected shape {expected:?}, got {shape:?}"
139 ));
140 }
141 if data.len() != n_embd * num_experts {
142 return Err(anyhow!(
143 "MoE router `{key}`: data len {} != n_embd*num_experts ({})",
144 data.len(),
145 n_embd * num_experts
146 ));
147 }
148 Ok(data)
149}
150
151pub fn load_layer(
153 loader: &mut dyn WeightLoader,
154 keys: &MoeLayerKeys,
155 num_experts: usize,
156 n_embd: usize,
157 n_ff: usize,
158) -> Result<MoeLayerWeights> {
159 let router = load_router(loader, &keys.router, n_embd, num_experts)?;
160 let gate = load_expert_stack(loader, &keys.gate, num_experts, n_embd, n_ff)?;
161 let up = load_expert_stack(loader, &keys.up, num_experts, n_embd, n_ff)?;
162 let down = load_expert_stack(loader, &keys.down, num_experts, n_ff, n_embd)?;
163 Ok(MoeLayerWeights {
164 gate,
165 up,
166 down,
167 router,
168 num_experts,
169 n_embd,
170 n_ff,
171 })
172}
173
174pub fn stack_expert_tensors(
179 per_expert: &[(Vec<f32>, Vec<usize>)],
180) -> Result<(Vec<f32>, Vec<usize>)> {
181 let num_experts = per_expert.len();
182 if num_experts == 0 {
183 return Err(anyhow!("stack_expert_tensors: empty input"));
184 }
185 let first_shape = &per_expert[0].1;
186 if first_shape.len() != 2 {
187 return Err(anyhow!(
188 "stack_expert_tensors: first expert tensor must be rank-2, got {first_shape:?}"
189 ));
190 }
191 let k = first_shape[0];
192 let n = first_shape[1];
193 let per = k * n;
194 let mut out = Vec::with_capacity(num_experts * per);
195 for (idx, (data, shape)) in per_expert.iter().enumerate() {
196 if shape.as_slice() != [k, n] {
197 return Err(anyhow!(
198 "stack_expert_tensors: expert {idx} shape {shape:?} != first expert shape {first_shape:?}"
199 ));
200 }
201 if data.len() != per {
202 return Err(anyhow!(
203 "stack_expert_tensors: expert {idx} data len {} != {per}",
204 data.len()
205 ));
206 }
207 out.extend_from_slice(data);
208 }
209 Ok((out, vec![num_experts, k, n]))
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::weight_loader::WeightLoader;
    use crate::weight_map::WeightMap;
    use std::collections::HashMap;

    /// Minimal in-memory `WeightLoader`: each tensor is a `(data, shape)` pair
    /// keyed by name and removed from the map when taken.
    struct MapLoader {
        tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
    }

    impl WeightLoader for MapLoader {
        fn len(&self) -> usize {
            self.tensors.len()
        }

        fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
            match self.tensors.remove(key) {
                Some(tensor) => Ok(tensor),
                None => Err(anyhow!("missing weight: {key}")),
            }
        }

        // This fake does not transpose; it simply forwards to `take`.
        fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
            self.take(key)
        }

        fn remaining_keys(&self) -> Vec<String> {
            self.tensors.keys().map(|name| name.to_owned()).collect()
        }
    }

    /// Deterministic filler: element `i` is `((i + seed) % 7) * 0.01`.
    fn synth_data(n: usize, seed: u64) -> Vec<f32> {
        let mut values = Vec::with_capacity(n);
        for i in 0..n {
            let residue = (i as u64 + seed) % 7;
            values.push(residue as f32 * 0.01);
        }
        values
    }

    #[test]
    fn load_layer_round_trip() {
        let num_experts = 4;
        let n_embd = 8;
        let n_ff = 16;
        let keys = MoeLayerKeys::llama_cpp(0);

        let router_len = n_embd * num_experts;
        let up_gate_len = num_experts * n_embd * n_ff;
        let down_len = num_experts * n_ff * n_embd;

        let tensors = HashMap::from([
            (
                keys.router.clone(),
                (synth_data(router_len, 1), vec![n_embd, num_experts]),
            ),
            (
                keys.gate.clone(),
                (synth_data(up_gate_len, 2), vec![num_experts, n_embd, n_ff]),
            ),
            (
                keys.up.clone(),
                (synth_data(up_gate_len, 3), vec![num_experts, n_embd, n_ff]),
            ),
            (
                keys.down.clone(),
                (synth_data(down_len, 4), vec![num_experts, n_ff, n_embd]),
            ),
        ]);

        let mut loader = MapLoader { tensors };
        let w = load_layer(&mut loader, &keys, num_experts, n_embd, n_ff).expect("load_layer");
        assert_eq!(w.num_experts, num_experts);
        assert_eq!(w.gate.len(), up_gate_len);
        assert_eq!(w.up.len(), up_gate_len);
        assert_eq!(w.down.len(), down_len);
        assert_eq!(w.router.len(), router_len);
    }

    #[test]
    fn shape_mismatch_errors() {
        let key = "blk.0.ffn_gate_exps.weight";
        // Stored as rank-2 `[8, 2]`, while the call below asks for `[4, 8, 2]`.
        let tensors = HashMap::from([(key.to_string(), (synth_data(16, 0), vec![8, 2]))]);
        let mut loader = MapLoader { tensors };

        let err = load_expert_stack(&mut loader, key, 4, 8, 2)
            .expect_err("should error on wrong shape");
        let rendered = format!("{err:#}");
        assert!(rendered.contains("expected shape"));
    }

    #[test]
    fn stack_expert_tensors_basic() {
        let mut per: Vec<(Vec<f32>, Vec<usize>)> = Vec::new();
        for expert in 0..3 {
            per.push((vec![expert as f32; 6], vec![2, 3]));
        }

        let (stacked, shape) = stack_expert_tensors(&per).expect("stack");
        assert_eq!(shape, vec![3, 2, 3]);
        assert_eq!(stacked.len(), 18);
        // Expert `e` was filled with the value `e`, so each 6-element chunk
        // must come back intact and in input order.
        for (expert, chunk) in stacked.chunks(6).enumerate() {
            assert_eq!(chunk, &[expert as f32; 6]);
        }
    }

    #[test]
    fn keys_use_llama_cpp_convention_by_default() {
        let k = MoeLayerKeys::llama_cpp(5);
        let expected = [
            (&k.router, "blk.5.ffn_gate_inp.weight"),
            (&k.gate, "blk.5.ffn_gate_exps.weight"),
            (&k.up, "blk.5.ffn_up_exps.weight"),
            (&k.down, "blk.5.ffn_down_exps.weight"),
        ];
        for (actual, want) in expected {
            assert_eq!(actual, want);
        }
    }

    // Never called; it only references `WeightMap` so the import above is used.
    #[allow(dead_code)]
    fn _kept(_m: WeightMap) {}
}