use crate::riir::backend::cpu::cpu_matvec::{
bf16_matvec_cpu, dequant_matvec_4bit_bytes_cpu, CpuMatvecError,
};
use crate::riir::io::embedding::bf16_to_f32;
use crate::riir::io::expert_io::{ExpertFiles, ExpertIoError};
use crate::riir::moe::mlp_cpu::{shared_expert_swiglu_cpu, MlpForwardError};
use crate::riir::moe::moe_router::{noaux_tc_router_cpu, ExpertBuckets, MoeRouterError};
use crate::riir::variants::{Variant, GROUP_SIZE, VARIANT};
use crate::riir::io::weight_file::WeightFile;
#[derive(Debug, thiserror::Error)]
pub enum MoeForwardError {
#[error("hidden buffer length {got} != hidden_dim ({expected})")]
HiddenLen { got: usize, expected: usize },
#[error("output buffer length {got} != hidden_dim ({expected})")]
OutLen { got: usize, expected: usize },
#[error("missing tensor '{name}'")]
MissingTensor { name: String },
#[error(
"router-gate bias '{name}' size {got} bytes, expected {expected}"
)]
BiasSize {
name: String,
got: usize,
expected: usize,
},
#[error("4-bit matvec error in MoE: {0}")]
Matvec(#[from] CpuMatvecError),
#[error("router error in MoE: {0}")]
Router(#[from] MoeRouterError),
#[error("shared-expert MLP error in MoE: {0}")]
Mlp(#[from] MlpForwardError),
#[error("expert I/O error in MoE: {0}")]
Io(#[from] ExpertIoError),
}
/// Runs one mixture-of-experts layer on the CPU for a single hidden-state vector.
///
/// Model dimensions come from the compile-time `VARIANT`. The steps are:
///
/// 1. Check that `hidden` and `out` both hold `hidden_dim` elements.
/// 2. Compute router logits with the BF16 gate matrix
///    `model.layers.{layer_idx}.mlp.gate.weight`.
/// 3. Decode the BF16 `e_score_correction_bias` tensor and let
///    `noaux_tc_router_cpu` pick `num_experts_per_tok` experts and weights.
/// 4. For each selected expert, read its packed 4-bit blob from
///    `expert_files`, run the expert SwiGLU, and accumulate
///    `weight * expert_output` into `out` (which is zeroed first).
/// 5. Add the shared-expert SwiGLU output to `out`.
///
/// # Errors
///
/// * `HiddenLen` / `OutLen` — `hidden` or `out` has the wrong length.
/// * `MissingTensor` — the gate weight or bias tensor is absent from `wf`.
/// * `BiasSize` — the gate weight or the bias has an unexpected byte size
///   (the variant is shared by both tensors).
/// * `Matvec`, `Router`, `Io`, `Mlp` — propagated from the called helpers.
///
/// On error `out` may already have been partially written.
///
/// # Panics
///
/// Panics if the gate weight or bias bytes do not start on a 2-byte boundary
/// (see `bytemuck_u16`).
pub fn deepseek_moe_cpu(
    wf: &WeightFile,
    expert_files: &ExpertFiles,
    layer_idx: usize,
    hidden: &[f32],
    out: &mut [f32],
) -> Result<(), MoeForwardError> {
    let variant = VARIANT;
    let dim = variant.hidden_dim;
    let n_experts = variant.num_experts;
    let top_k = variant.num_experts_per_tok;

    if hidden.len() != dim {
        return Err(MoeForwardError::HiddenLen {
            got: hidden.len(),
            expected: dim,
        });
    }
    if out.len() != dim {
        return Err(MoeForwardError::OutLen {
            got: out.len(),
            expected: dim,
        });
    }

    // --- Router logits: BF16 gate matrix times the hidden vector. ---
    let gate_w_name = format!("model.layers.{layer_idx}.mlp.gate.weight");
    let gate_bytes = match wf.tensor_bytes(&gate_w_name) {
        Some(bytes) => bytes,
        None => return Err(MoeForwardError::MissingTensor { name: gate_w_name }),
    };
    // Two bytes per BF16 element.
    let gate_bytes_wanted = n_experts * dim * 2;
    if gate_bytes.len() != gate_bytes_wanted {
        return Err(MoeForwardError::BiasSize {
            name: gate_w_name,
            got: gate_bytes.len(),
            expected: gate_bytes_wanted,
        });
    }
    let gate_words = bytemuck_u16(gate_bytes);
    let mut logits = vec![0.0f32; n_experts];
    bf16_matvec_cpu(gate_words, dim, n_experts, hidden, &mut logits)?;

    // --- Score-correction bias, widened from BF16 to f32. ---
    let bias_name = format!(
        "model.layers.{layer_idx}.mlp.gate.e_score_correction_bias"
    );
    let bias_bytes = match wf.tensor_bytes(&bias_name) {
        Some(bytes) => bytes,
        None => return Err(MoeForwardError::MissingTensor { name: bias_name }),
    };
    let bias_bytes_wanted = n_experts * 2;
    if bias_bytes.len() != bias_bytes_wanted {
        return Err(MoeForwardError::BiasSize {
            name: bias_name,
            got: bias_bytes.len(),
            expected: bias_bytes_wanted,
        });
    }
    let bias: Vec<f32> = bytemuck_u16(bias_bytes)
        .iter()
        .map(|bits| bf16_to_f32(*bits))
        .collect();

    // --- Expert selection. ---
    let mut chosen = vec![0i32; top_k];
    let mut chosen_weights = vec![0.0f32; top_k];
    noaux_tc_router_cpu(
        &mut logits,
        &bias,
        variant.n_group,
        variant.topk_group,
        top_k,
        variant.routed_scaling_factor,
        &mut chosen,
        &mut chosen_weights,
    )?;

    // --- Routed experts: out = sum_k weight_k * expert_k(hidden). ---
    out.fill(0.0);
    // One blob buffer and one output buffer are reused across all experts.
    let mut blob = vec![0u8; variant.expert_size_4bit()];
    let mut expert_out = vec![0.0f32; dim];
    for (&expert, &gate_weight) in chosen.iter().zip(chosen_weights.iter()) {
        expert_files.read_expert(layer_idx, expert as usize, &mut blob)?;
        run_packed_expert_swiglu(&variant, &blob, hidden, &mut expert_out)?;
        for (acc, &y) in out.iter_mut().zip(expert_out.iter()) {
            *acc = gate_weight.mul_add(y, *acc);
        }
    }

    // --- Shared expert is added unweighted on top of the routed sum. ---
    let mut shared = vec![0.0f32; dim];
    shared_expert_swiglu_cpu(wf, layer_idx, hidden, &mut shared)?;
    for (acc, &s) in out.iter_mut().zip(shared.iter()) {
        *acc += s;
    }

    Ok(())
}
/// Evaluates one packed 4-bit expert as a SwiGLU MLP:
/// `out = down( silu(gate(hidden)) * up(hidden) )`.
///
/// `blob` holds the expert's nine sections (weights, scales and biases for
/// the gate, up and down projections) at the offsets reported by `v`.
/// `hidden` is the input vector and `out` receives the down-projection
/// result; both are passed to the matvec kernel with `v.hidden_dim` as the
/// matching dimension.
///
/// # Errors
///
/// Propagates any `CpuMatvecError` from the three dequantizing matvecs.
///
/// # Panics
///
/// Panics if `blob` is too short to contain any of the nine sections. In
/// debug builds it also panics if either dimension is not a multiple of
/// `GROUP_SIZE`.
fn run_packed_expert_swiglu(
    v: &Variant,
    blob: &[u8],
    hidden: &[f32],
    out: &mut [f32],
) -> Result<(), MoeForwardError> {
    /// Borrows `len` bytes of `blob` starting at `offset`.
    fn section(blob: &[u8], offset: usize, len: usize) -> &[u8] {
        &blob[offset..offset + len]
    }

    let dim = v.hidden_dim;
    let inter = v.moe_intermediate;
    let weight_len = v.expert_weight_bytes_4bit();
    let scale_len = v.expert_scale_bytes();
    debug_assert_eq!(inter % GROUP_SIZE, 0);
    debug_assert_eq!(dim % GROUP_SIZE, 0);

    // Scales and biases share the same byte length.
    let gate_w = section(blob, v.gate_w_off_4bit(), weight_len);
    let gate_s = section(blob, v.gate_s_off_4bit(), scale_len);
    let gate_b = section(blob, v.gate_b_off_4bit(), scale_len);
    let up_w = section(blob, v.up_w_off_4bit(), weight_len);
    let up_s = section(blob, v.up_s_off_4bit(), scale_len);
    let up_b = section(blob, v.up_b_off_4bit(), scale_len);
    let down_w = section(blob, v.down_w_off_4bit(), weight_len);
    let down_s = section(blob, v.down_s_off_4bit(), scale_len);
    let down_b = section(blob, v.down_b_off_4bit(), scale_len);

    let mut act = vec![0.0f32; inter];
    let mut up = vec![0.0f32; inter];
    dequant_matvec_4bit_bytes_cpu(gate_w, gate_s, gate_b, dim, inter, hidden, &mut act)?;
    dequant_matvec_4bit_bytes_cpu(up_w, up_s, up_b, dim, inter, hidden, &mut up)?;

    // In-place SwiGLU: act <- silu(act) * up, with silu(x) = x / (1 + e^-x).
    for (a, &u) in act.iter_mut().zip(up.iter()) {
        let x = *a;
        let silu = x / (1.0 + (-x).exp());
        *a = silu * u;
    }

    dequant_matvec_4bit_bytes_cpu(down_w, down_s, down_b, inter, dim, &act, out)?;
    Ok(())
}
/// Runs every bucketed expert over its assigned rows and scatter-adds the
/// weighted results into per-token output rows.
///
/// `buckets` groups slots by expert: bucket `b` covers slots
/// `offsets[b]..offsets[b + 1]` and is evaluated with `expert_blobs[b]`.
/// For each slot `s`, row `s` of `bucket_input` (`hidden_dim` floats) is fed
/// through the packed expert SwiGLU and `weights[s] * result` is added to row
/// `token_idx[s]` of `out_sum`. `out_sum` is accumulated into, never cleared.
///
/// # Errors
///
/// * `MissingTensor` — `expert_blobs` and `buckets.expert_ids` differ in
///   length, or `buckets.offsets` is not exactly one longer than
///   `expert_ids` (the message is carried in `name`).
/// * `HiddenLen` — `bucket_input` is not `total_slots * hidden_dim` long,
///   where `total_slots` is the last entry of `offsets`.
/// * `OutLen` — `token_idx` or `weights` does not have `total_slots` entries
///   (`got` is the length of whichever array mismatched), or a token index
///   addresses a row that lies outside `out_sum` (`got` is `out_sum.len()`,
///   `expected` the minimum length required).
/// * `Matvec` — propagated from the expert evaluation.
///
/// On error `out_sum` may already contain contributions from earlier slots.
///
/// # Panics
///
/// Panics if an expert blob is too short for the packed layout, or if
/// `offsets` is not non-decreasing so that some slot index exceeds
/// `total_slots`.
pub fn moe_permute_fuse_cpu(
    v: &Variant,
    expert_blobs: &[&[u8]],
    bucket_input: &[f32],
    buckets: &ExpertBuckets,
    out_sum: &mut [f32],
) -> Result<(), MoeForwardError> {
    let hidden_dim = v.hidden_dim;
    let num_buckets = buckets.expert_ids.len();

    if num_buckets != expert_blobs.len() {
        return Err(MoeForwardError::MissingTensor {
            name: format!(
                "expert_blobs[len={}] does not match buckets.expert_ids[len={}]",
                expert_blobs.len(),
                num_buckets
            ),
        });
    }
    if buckets.offsets.len() != num_buckets + 1 {
        return Err(MoeForwardError::MissingTensor {
            name: format!(
                "buckets.offsets[len={}] must be expert_ids[len={}] + 1",
                buckets.offsets.len(),
                num_buckets
            ),
        });
    }

    // `offsets` has at least one entry here, so the fallback is never taken.
    let total_slots = *buckets.offsets.last().unwrap_or(&0) as usize;
    if bucket_input.len() != total_slots * hidden_dim {
        return Err(MoeForwardError::HiddenLen {
            got: bucket_input.len(),
            expected: total_slots * hidden_dim,
        });
    }
    // Checked separately so `got` always names the array that is wrong.
    if buckets.token_idx.len() != total_slots {
        return Err(MoeForwardError::OutLen {
            got: buckets.token_idx.len(),
            expected: total_slots,
        });
    }
    if buckets.weights.len() != total_slots {
        return Err(MoeForwardError::OutLen {
            got: buckets.weights.len(),
            expected: total_slots,
        });
    }

    let out_len = out_sum.len();
    // Scratch buffer reused for every slot's expert output.
    let mut tmp_out = vec![0.0f32; hidden_dim];

    // The length checks above make blobs and offset windows line up 1:1.
    for (&blob, span) in expert_blobs.iter().zip(buckets.offsets.windows(2)) {
        let lo = span[0] as usize;
        let hi = span[1] as usize;
        for s in lo..hi {
            let x = &bucket_input[s * hidden_dim..(s + 1) * hidden_dim];
            run_packed_expert_swiglu(v, blob, x, &mut tmp_out)?;

            let t = buckets.token_idx[s] as usize;
            let w = buckets.weights[s];

            // Token indices come from bucket data: report an out-of-range row
            // as an error instead of panicking on the slice.
            let start = t.checked_mul(hidden_dim);
            let end = start.and_then(|a| a.checked_add(hidden_dim));
            let row = match (start, end) {
                (Some(a), Some(b)) => out_sum.get_mut(a..b),
                _ => None,
            };
            let out_token = row.ok_or(MoeForwardError::OutLen {
                got: out_len,
                expected: end.unwrap_or(usize::MAX),
            })?;

            for (acc, &y) in out_token.iter_mut().zip(tmp_out.iter()) {
                *acc = y.mul_add(w, *acc);
            }
        }
    }

    Ok(())
}
/// Reinterprets a byte slice as a slice of native-endian `u16` words without
/// copying. Used in this file to view raw BF16 tensor bytes as bit patterns.
///
/// An empty input always yields an empty output.
///
/// # Panics
///
/// Panics if `bytes` is non-empty and either does not start on a 2-byte
/// boundary or has an odd length. The message reports the number of stray
/// leading (`head`) and trailing (`tail`) bytes.
fn bytemuck_u16(bytes: &[u8]) -> &[u16] {
    const WORD: usize = std::mem::size_of::<u16>();
    const ALIGN: usize = std::mem::align_of::<u16>();
    // Compile-time guard: the arithmetic below assumes 2-byte, 2-aligned words.
    const _: () = assert!(WORD == 2 && ALIGN == 2);

    // An empty slice may carry a dangling, odd pointer; never cast it.
    if bytes.is_empty() {
        return &[];
    }

    // Alignment and length are checked explicitly rather than via
    // `align_to`, whose documentation allows it to return a non-maximal
    // middle slice and warns that correctness must not depend on the split.
    let head = bytes.as_ptr() as usize % ALIGN;
    let tail = (bytes.len() - head) % WORD;
    assert!(
        head == 0 && tail == 0,
        "BF16 tensor not aligned to 2-byte boundary (head={}, tail={})",
        head,
        tail
    );

    // SAFETY: the pointer comes from a live, non-empty `&[u8]`, so it is
    // non-null and valid for `bytes.len()` bytes of reads; it was just
    // checked to be aligned for `u16`; `bytes.len()` is a multiple of two, so
    // `len / WORD` words cover exactly the same memory; every bit pattern is
    // a valid `u16`; and the returned slice borrows for the same lifetime as
    // `bytes`, so the memory cannot be mutated or freed while it is in use.
    unsafe { std::slice::from_raw_parts(bytes.as_ptr().cast::<u16>(), bytes.len() / WORD) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test for layer 3 of the real model: feeds a deterministic
    /// sine-wave hidden vector through `deepseek_moe_cpu` and checks that the
    /// output is finite, non-zero and of plausible magnitude.
    ///
    /// Ignored by default because it needs the weight files on an external
    /// volume; only compiled with the `model-cogito-v2-671b` feature.
    #[cfg(feature = "model-cogito-v2-671b")]
    #[test]
    #[ignore = "needs Cogito-V2 weights and packed_experts/ on /Volumes/Temp Backup"]
    fn moe_layer3_smoke() {
        use std::path::Path;

        let weights_bin = Path::new(
            "/Volumes/Temp Backup/models/blallama/cogito-v2-671b/artifacts/model_weights.bin",
        );
        let weights_manifest = Path::new(
            "/Volumes/Temp Backup/models/blallama/cogito-v2-671b/artifacts/model_weights.json",
        );
        let experts_root = Path::new("/Volumes/Temp Backup/models/blallama/cogito-v2-671b/root");

        let wf = WeightFile::open(weights_bin, weights_manifest).expect("open weights");
        let ef = ExpertFiles::open(experts_root).expect("open experts");

        // Deterministic, non-trivial input: hidden[i] = sin(0.001 * i).
        let dim = VARIANT.hidden_dim;
        let hidden: Vec<f32> = (0..dim).map(|i| ((i as f32) * 0.001).sin()).collect();
        let mut out = vec![0.0f32; dim];

        deepseek_moe_cpu(&wf, &ef, 3, &hidden, &mut out).expect("MoE layer 3");

        let first_bad = out.iter().position(|x| !x.is_finite());
        assert!(
            first_bad.is_none(),
            "non-finite output at index {:?}",
            first_bad,
        );

        let max_abs = out.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        assert!(
            max_abs > 0.0,
            "all-zero MoE output — likely router or expert wiring bug"
        );
        assert!(
            max_abs < 1e6,
            "MoE output magnitude {max_abs} suspiciously large"
        );
    }
}