use anyhow::{Context, Result, anyhow};
use crate::weight_loader::WeightLoader;
/// Weights of a single mixture-of-experts feed-forward layer, held as flat
/// `f32` buffers.
///
/// `load_layer` produces this struct; the shapes quoted on the fields below
/// are the ones it validates before handing the buffers over.
#[derive(Clone, Debug)]
pub struct MoeLayerWeights {
    /// Gate projections of all experts; checked as `[num_experts, n_embd, n_ff]`.
    pub gate: Vec<f32>,
    /// Up projections of all experts; checked as `[num_experts, n_embd, n_ff]`.
    pub up: Vec<f32>,
    /// Down projections of all experts; checked as `[num_experts, n_ff, n_embd]`.
    pub down: Vec<f32>,
    /// Router weights; checked as `[n_embd, num_experts]`.
    pub router: Vec<f32>,
    /// Number of experts stacked in `gate`, `up` and `down`.
    pub num_experts: usize,
    /// Embedding width used in the shape checks above.
    pub n_embd: usize,
    /// Feed-forward width used in the shape checks above.
    pub n_ff: usize,
}
/// Tensor names under which one MoE layer's weights are looked up in a loader.
#[derive(Clone, Debug)]
pub struct MoeLayerKeys {
    /// Name of the router tensor.
    pub router: String,
    /// Name of the stacked gate-projection tensor.
    pub gate: String,
    /// Name of the stacked up-projection tensor.
    pub up: String,
    /// Name of the stacked down-projection tensor.
    pub down: String,
}

impl MoeLayerKeys {
    /// Builds the `blk.{layer_idx}.ffn_*` names for layer `layer_idx`
    /// (the naming the constructor's name attributes to llama.cpp).
    pub fn llama_cpp(layer_idx: usize) -> Self {
        // Every key shares the per-layer prefix; only the suffix varies.
        let key = |suffix: &str| format!("blk.{layer_idx}.{suffix}");
        Self {
            router: key("ffn_gate_inp.weight"),
            gate: key("ffn_gate_exps.weight"),
            up: key("ffn_up_exps.weight"),
            down: key("ffn_down_exps.weight"),
        }
    }

    /// Builds the `model.layers.{layer_idx}.block_sparse_moe.*` names for
    /// layer `layer_idx`.
    ///
    /// NOTE(review): the expert keys name one tensor per projection
    /// (`experts.gate_proj.weight`, ...), i.e. they presume the experts were
    /// already stacked by whatever produced the checkpoint — confirm against
    /// the converter that writes these names.
    pub fn hf_block_sparse(layer_idx: usize) -> Self {
        let key = |suffix: &str| format!("model.layers.{layer_idx}.block_sparse_moe.{suffix}");
        Self {
            router: key("gate.weight"),
            gate: key("experts.gate_proj.weight"),
            up: key("experts.up_proj.weight"),
            down: key("experts.down_proj.weight"),
        }
    }
}
/// Takes the tensor stored under `key` from `loader` and returns its data
/// after confirming it is a `[num_experts, k, n]` expert stack.
///
/// # Errors
///
/// Fails when `loader.take(key)` fails (the error is wrapped with the key as
/// context), when the reported shape differs from `[num_experts, k, n]`, or
/// when the data length is not `num_experts * k * n`.
pub fn load_expert_stack(
    loader: &mut dyn WeightLoader,
    key: &str,
    num_experts: usize,
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    let (values, shape) = loader
        .take(key)
        .with_context(|| format!("MoE expert stack `{key}`"))?;

    // The declared shape must match exactly before the element count is examined.
    let expected = [num_experts, k, n];
    if shape != expected {
        return Err(anyhow!(
            "MoE expert stack `{key}`: expected shape {expected:?}, got {shape:?}"
        ));
    }

    // Guard against a loader whose buffer disagrees with its own shape metadata.
    let expected_len = num_experts * k * n;
    match values.len() {
        got if got == expected_len => Ok(values),
        got => Err(anyhow!(
            "MoE expert stack `{key}`: shape {shape:?} declares {expected_len} elements but loader returned {}",
            got
        )),
    }
}
/// Takes the router tensor stored under `key` from `loader` and returns its
/// data after confirming it has shape `[n_embd, num_experts]`.
///
/// # Errors
///
/// Fails when `loader.take(key)` fails (the error is wrapped with the key as
/// context), when the reported shape differs from `[n_embd, num_experts]`, or
/// when the data length is not `n_embd * num_experts`.
pub fn load_router(
    loader: &mut dyn WeightLoader,
    key: &str,
    n_embd: usize,
    num_experts: usize,
) -> Result<Vec<f32>> {
    let (weights, shape) = loader
        .take(key)
        .with_context(|| format!("MoE router `{key}`"))?;

    let expected = [n_embd, num_experts];
    if shape != expected {
        return Err(anyhow!(
            "MoE router `{key}`: expected shape {expected:?}, got {shape:?}"
        ));
    }

    // The buffer has to hold exactly one value per (embedding, expert) pair.
    let want_len = n_embd * num_experts;
    let got_len = weights.len();
    if got_len == want_len {
        Ok(weights)
    } else {
        Err(anyhow!(
            "MoE router `{key}`: data len {} != n_embd*num_experts ({})",
            got_len,
            want_len
        ))
    }
}
/// Loads and validates all four tensors of one MoE layer from `loader`.
///
/// The tensors are taken in a fixed order — router, gate, up, down — using
/// the names in `keys`. Gate and up stacks are checked as
/// `[num_experts, n_embd, n_ff]`, the down stack as
/// `[num_experts, n_ff, n_embd]`, and the router as `[n_embd, num_experts]`.
///
/// # Errors
///
/// Returns the first error reported by `load_router` or `load_expert_stack`.
/// Tensors taken before the failing one are not handed back to `loader`.
pub fn load_layer(
    loader: &mut dyn WeightLoader,
    keys: &MoeLayerKeys,
    num_experts: usize,
    n_embd: usize,
    n_ff: usize,
) -> Result<MoeLayerWeights> {
    // Field initialisers run top to bottom, which keeps the take order
    // router -> gate -> up -> down.
    Ok(MoeLayerWeights {
        router: load_router(loader, &keys.router, n_embd, num_experts)?,
        gate: load_expert_stack(loader, &keys.gate, num_experts, n_embd, n_ff)?,
        up: load_expert_stack(loader, &keys.up, num_experts, n_embd, n_ff)?,
        down: load_expert_stack(loader, &keys.down, num_experts, n_ff, n_embd)?,
        num_experts,
        n_embd,
        n_ff,
    })
}
/// Concatenates per-expert rank-2 tensors into one `[num_experts, k, n]` stack.
///
/// Each element of `per_expert` is a `(data, shape)` pair. The first expert's
/// shape `[k, n]` is the reference; every expert must report that same shape
/// and carry exactly `k * n` values. The returned buffer holds the experts'
/// data back to back in input order, together with the shape
/// `[num_experts, k, n]`.
///
/// # Errors
///
/// Fails when `per_expert` is empty, when the first shape is not rank-2, when
/// `k * n` does not fit in `usize`, or when any expert's shape or data length
/// disagrees with the first expert.
pub fn stack_expert_tensors(
    per_expert: &[(Vec<f32>, Vec<usize>)],
) -> Result<(Vec<f32>, Vec<usize>)> {
    let first_shape = match per_expert.first() {
        Some((_, shape)) => shape,
        None => return Err(anyhow!("stack_expert_tensors: empty input")),
    };
    let (k, n) = match first_shape.as_slice() {
        &[k, n] => (k, n),
        _ => {
            return Err(anyhow!(
                "stack_expert_tensors: first expert tensor must be rank-2, got {first_shape:?}"
            ))
        }
    };

    // Shape metadata comes from the caller's input; an unchecked product could
    // wrap (release) or panic (debug) and let a bogus shape slip through the
    // length check below.
    let per = k.checked_mul(n).ok_or_else(|| {
        anyhow!("stack_expert_tensors: expert shape {first_shape:?} overflows usize element count")
    })?;

    // Validate everything first so no allocation is sized from unverified metadata.
    for (idx, (data, shape)) in per_expert.iter().enumerate() {
        if shape.as_slice() != [k, n] {
            return Err(anyhow!(
                "stack_expert_tensors: expert {idx} shape {shape:?} != first expert shape {first_shape:?}"
            ));
        }
        if data.len() != per {
            return Err(anyhow!(
                "stack_expert_tensors: expert {idx} data len {} != {per}",
                data.len()
            ));
        }
    }

    let num_experts = per_expert.len();
    // Cannot overflow: every expert was just shown to hold `per` live values.
    let mut out = Vec::with_capacity(num_experts * per);
    for (data, _) in per_expert {
        out.extend_from_slice(data);
    }
    Ok((out, vec![num_experts, k, n]))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::weight_loader::WeightLoader;
    use crate::weight_map::WeightMap;
    use std::collections::HashMap;

    /// In-memory loader for tests: tensors live in a map and each key can be
    /// taken only once.
    struct MapLoader {
        tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
    }

    impl WeightLoader for MapLoader {
        fn len(&self) -> usize {
            self.tensors.len()
        }

        fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
            match self.tensors.remove(key) {
                Some(entry) => Ok(entry),
                None => Err(anyhow!("missing weight: {key}")),
            }
        }

        // No transposition is performed here; this simply forwards to `take`.
        fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
            self.take(key)
        }

        fn remaining_keys(&self) -> Vec<String> {
            self.tensors.keys().map(|name| name.to_owned()).collect()
        }
    }

    /// Deterministic filler: `n` values cycling through multiples of 0.01,
    /// with the cycle offset by `seed`.
    fn synth_data(n: usize, seed: u64) -> Vec<f32> {
        let mut values = Vec::with_capacity(n);
        for i in 0..n {
            let bucket = (i as u64 + seed) % 7;
            values.push(bucket as f32 * 0.01);
        }
        values
    }

    #[test]
    fn load_layer_round_trip() {
        let (num_experts, n_embd, n_ff) = (4, 8, 16);
        let keys = MoeLayerKeys::llama_cpp(0);

        // (tensor name, synth seed, shape); the element count follows from the shape.
        let specs = vec![
            (keys.router.clone(), 1, vec![n_embd, num_experts]),
            (keys.gate.clone(), 2, vec![num_experts, n_embd, n_ff]),
            (keys.up.clone(), 3, vec![num_experts, n_embd, n_ff]),
            (keys.down.clone(), 4, vec![num_experts, n_ff, n_embd]),
        ];
        let tensors: HashMap<_, _> = specs
            .into_iter()
            .map(|(name, seed, shape)| {
                let count: usize = shape.iter().product();
                (name, (synth_data(count, seed), shape))
            })
            .collect();

        let mut loader = MapLoader { tensors };
        let w = load_layer(&mut loader, &keys, num_experts, n_embd, n_ff).expect("load_layer");

        assert_eq!(w.num_experts, num_experts);
        assert_eq!(w.gate.len(), num_experts * n_embd * n_ff);
        assert_eq!(w.up.len(), num_experts * n_embd * n_ff);
        assert_eq!(w.down.len(), num_experts * n_ff * n_embd);
        assert_eq!(w.router.len(), n_embd * num_experts);
    }

    #[test]
    fn shape_mismatch_errors() {
        let key = "blk.0.ffn_gate_exps.weight";
        // Rank-2 tensor where a rank-3 `[4, 8, 2]` stack is requested.
        let tensors = HashMap::from([(key.to_owned(), (synth_data(16, 0), vec![8, 2]))]);
        let mut loader = MapLoader { tensors };

        let err = load_expert_stack(&mut loader, key, 4, 8, 2)
            .expect_err("should error on wrong shape");
        let rendered = format!("{err:#}");
        assert!(rendered.contains("expected shape"));
    }

    #[test]
    fn stack_expert_tensors_basic() {
        let mut per: Vec<(Vec<f32>, Vec<usize>)> = Vec::new();
        for i in 0..3 {
            per.push((vec![i as f32; 6], vec![2, 3]));
        }

        let (stacked, shape) = stack_expert_tensors(&per).expect("stack");
        assert_eq!(shape, vec![3, 2, 3]);
        assert_eq!(stacked.len(), 18);
        // Expert `e` occupies the e-th run of six values, all equal to `e`.
        for (expert, chunk) in stacked.chunks(6).enumerate() {
            assert_eq!(chunk, [expert as f32; 6]);
        }
    }

    #[test]
    fn keys_use_llama_cpp_convention_by_default() {
        let k = MoeLayerKeys::llama_cpp(5);
        assert_eq!(
            [
                k.router.as_str(),
                k.gate.as_str(),
                k.up.as_str(),
                k.down.as_str()
            ],
            [
                "blk.5.ffn_gate_inp.weight",
                "blk.5.ffn_gate_exps.weight",
                "blk.5.ffn_up_exps.weight",
                "blk.5.ffn_down_exps.weight",
            ]
        );
    }

    // NOTE(review): appears to exist only so the `WeightMap` import above is
    // referenced somewhere — confirm before removing either.
    #[allow(dead_code)]
    fn _kept(_m: WeightMap) {}
}