use std::path::Path;
use std::sync::atomic::{AtomicU64, Ordering};
use candle_core::quantized::GgmlDType;
use candle_core::{Device, Result as CandleResult};
use ferrum_kernels::backend::cpu::CpuBackend;
use ferrum_kernels::backend::{Backend, GgufQuantType};
use ferrum_kernels::Linear;
use ferrum_quantization::gguf::GgufFile;
use ferrum_quantization::{DenseLinear, QuantLinear};
use ferrum_types::{FerrumError, Result};
use crate::moe::router::RouterOutput;
// Profiling counters for `moe_forward`. They are only incremented when the
// `FERRUM_MOE_PROFILE` environment variable is set (see `moe_profile_enabled`).
// Each `*_US` static accumulates elapsed wall-clock microseconds, and the matching
// `*_CALLS` static counts how many operations were recorded.

/// Total microseconds spent in `moe_forward`'s leading `B::sync` call.
pub static MOE_SYNC_US: AtomicU64 = AtomicU64::new(0);
/// Number of leading syncs recorded in [`MOE_SYNC_US`].
pub static MOE_SYNC_CALLS: AtomicU64 = AtomicU64::new(0);
/// Total microseconds spent in per-expert fused gate/up projections.
pub static MOE_GEMV_GATE_UP_US: AtomicU64 = AtomicU64::new(0);
/// Number of gate/up projections recorded in [`MOE_GEMV_GATE_UP_US`].
pub static MOE_GEMV_GATE_UP_CALLS: AtomicU64 = AtomicU64::new(0);
/// Total microseconds spent in `B::fused_silu_mul_split`.
pub static MOE_SILU_US: AtomicU64 = AtomicU64::new(0);
/// Number of `fused_silu_mul_split` calls recorded in [`MOE_SILU_US`].
pub static MOE_SILU_CALLS: AtomicU64 = AtomicU64::new(0);
/// Total microseconds spent in per-expert down projections.
pub static MOE_GEMV_DOWN_US: AtomicU64 = AtomicU64::new(0);
/// Number of down projections recorded in [`MOE_GEMV_DOWN_US`].
pub static MOE_GEMV_DOWN_CALLS: AtomicU64 = AtomicU64::new(0);
/// Total microseconds spent in `B::scaled_add_inplace` (weighted accumulation).
pub static MOE_SCALED_ADD_US: AtomicU64 = AtomicU64::new(0);
/// Number of scaled-add calls recorded in [`MOE_SCALED_ADD_US`].
pub static MOE_SCALED_ADD_CALLS: AtomicU64 = AtomicU64::new(0);
/// Total microseconds spent in `B::copy_slice` calls (token in, zeroing, token out).
pub static MOE_COPY_US: AtomicU64 = AtomicU64::new(0);
/// Number of `copy_slice` calls recorded in [`MOE_COPY_US`].
pub static MOE_COPY_CALLS: AtomicU64 = AtomicU64::new(0);
/// Total microseconds spent reading router logits to the host and running top-k routing.
pub static MOE_HOST_TOPK_US: AtomicU64 = AtomicU64::new(0);
/// Number of routing passes recorded in [`MOE_HOST_TOPK_US`].
pub static MOE_HOST_TOPK_CALLS: AtomicU64 = AtomicU64::new(0);
/// Reports whether MoE phase profiling is switched on.
///
/// Profiling is on when `FERRUM_MOE_PROFILE` is set to any valid-Unicode value.
/// The variable is re-read on every call.
fn moe_profile_enabled() -> bool {
    matches!(std::env::var("FERRUM_MOE_PROFILE"), Ok(_))
}
/// Per-layer collection of MoE expert weights on backend `B`.
///
/// Two representations can be present:
/// - per-expert layers in `gate_up` / `down`, indexed by expert id, which
///   `moe_forward` and `moe_forward_cpu` dispatch through;
/// - whole-tensor "stacked" quantised stores built by `B::load_quant_experts`.
///
/// `from_dense_stacks` fills only the per-expert vectors and leaves the stacked stores `None`.
/// The quantised loader keeps whichever stacked stores the backend accepted. It builds the
/// per-expert vectors only when at least one of the three stacks failed to load. When all
/// three load, `gate_up` and `down` are left empty.
pub struct ExpertStack<B: Backend> {
    /// Per-expert fused gate+up projection, built with `2 * expert_intermediate` rows
    /// (gate rows, then up rows) of width `hidden_size`.
    pub gate_up: Vec<Box<dyn Linear<B>>>,
    /// Per-expert down projection, built with `hidden_size` rows of width `expert_intermediate`.
    pub down: Vec<Box<dyn Linear<B>>>,
    /// Stacked quantised `ffn_gate_exps` for all experts, if the backend could load it.
    pub gate_stacked: Option<B::QuantStore>,
    /// Stacked quantised `ffn_up_exps` for all experts, if the backend could load it.
    pub up_stacked: Option<B::QuantStore>,
    /// Stacked quantised `ffn_down_exps` for all experts, if the backend could load it.
    pub down_stacked: Option<B::QuantStore>,
}
impl<B: Backend> ExpertStack<B> {
    /// Builds per-expert dense layers from fp32 weight stacks.
    ///
    /// Each stack holds all experts back to back:
    /// - `gate_stack` and `up_stack` hold `num_experts * expert_intermediate * hidden_size` values;
    /// - `down_stack` holds `num_experts * hidden_size * expert_intermediate` values.
    ///
    /// For each expert, the gate slice followed by the up slice is passed to
    /// `DenseLinear::from_rows` with `(2 * expert_intermediate, hidden_size)`. The down slice
    /// is passed with `(hidden_size, expert_intermediate)`. The stacked quantised stores are
    /// left as `None`.
    ///
    /// # Errors
    /// Returns an error if any stack length differs from the sizes above.
    pub fn from_dense_stacks(
        gate_stack: &[f32],
        up_stack: &[f32],
        down_stack: &[f32],
        num_experts: usize,
        hidden_size: usize,
        expert_intermediate: usize,
    ) -> Result<Self> {
        let gate_up_len = expert_intermediate * hidden_size;
        let down_len = hidden_size * expert_intermediate;
        check_size(gate_stack.len(), num_experts * gate_up_len, "gate_stack")?;
        check_size(up_stack.len(), num_experts * gate_up_len, "up_stack")?;
        check_size(down_stack.len(), num_experts * down_len, "down_stack")?;

        let mut gate_up: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
        let mut down: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
        for expert in 0..num_experts {
            // Gate rows first, then up rows, so that one GEMV yields both halves.
            let gu = expert * gate_up_len..(expert + 1) * gate_up_len;
            let mut fused = Vec::with_capacity(2 * gate_up_len);
            fused.extend_from_slice(&gate_stack[gu.clone()]);
            fused.extend_from_slice(&up_stack[gu]);
            gate_up.push(Box::new(DenseLinear::<B>::from_rows(
                &fused,
                2 * expert_intermediate,
                hidden_size,
            )));

            let d = expert * down_len..(expert + 1) * down_len;
            down.push(Box::new(DenseLinear::<B>::from_rows(
                &down_stack[d],
                hidden_size,
                expert_intermediate,
            )));
        }

        Ok(Self {
            gate_up,
            down,
            gate_stacked: None,
            up_stacked: None,
            down_stacked: None,
        })
    }

    /// Loads layer `layer_idx`'s experts from the GGUF tensors
    /// `blk.{layer_idx}.ffn_{gate,up,down}_exps.weight`.
    ///
    /// The quantised path (`try_load_quantised`) is tried first. If it declines, all three
    /// tensors are dequantised to fp32 on the CPU and passed to [`Self::from_dense_stacks`].
    /// With `FERRUM_MOE_LOAD_TRACE` set, the chosen path is printed to stderr.
    ///
    /// # Errors
    /// Propagates errors from the quantised probe, tensor reading/dequantisation, and the
    /// size checks in `from_dense_stacks`.
    pub fn load_from_gguf(
        gguf: &GgufFile,
        layer_idx: usize,
        num_experts: usize,
        hidden_size: usize,
        expert_intermediate: usize,
    ) -> Result<Self> {
        let quantised = Self::try_load_quantised(
            gguf,
            layer_idx,
            num_experts,
            hidden_size,
            expert_intermediate,
        )?;
        let trace = std::env::var("FERRUM_MOE_LOAD_TRACE").is_ok();
        if let Some(stack) = quantised {
            if trace {
                eprintln!("[moe-load] layer {layer_idx} → quantised expert path");
            }
            return Ok(stack);
        }
        if trace {
            eprintln!("[moe-load] layer {layer_idx} → eager fp32 dense fallback ⚠");
        }

        let cpu = Device::Cpu;
        let dequant = |suffix: &str| {
            read_dequant_flat(gguf, &format!("blk.{layer_idx}.{suffix}.weight"), &cpu)
        };
        let gate = dequant("ffn_gate_exps")?;
        let up = dequant("ffn_up_exps")?;
        let down = dequant("ffn_down_exps")?;
        Self::from_dense_stacks(
            &gate,
            &up,
            &down,
            num_experts,
            hidden_size,
            expert_intermediate,
        )
    }

    /// Tries to build the expert stack directly from Q4_K / Q6_K GGUF bytes.
    ///
    /// First asks the backend for whole-tensor stacked stores (`B::load_quant_experts`).
    /// Stacked loading is best effort: any store the backend rejects becomes `None`. If
    /// all three stacks load, the per-expert vectors stay empty. Otherwise, every expert
    /// is also built as a `QuantLinear`: a fused gate+up layer and a down layer.
    ///
    /// Returns `Ok(None)`, meaning "use the dense fallback", when any of the three tensors is
    /// not Q4_K/Q6_K, or when the backend refuses a per-expert `QuantLinear`.
    ///
    /// # Errors
    /// Returns an error for missing tensor metadata or byte slices, for a per-expert element
    /// count that is not a multiple of 256, or for tensor byte lengths that don't match
    /// `num_experts` and the given dimensions.
    fn try_load_quantised(
        gguf: &GgufFile,
        layer_idx: usize,
        num_experts: usize,
        hidden_size: usize,
        expert_intermediate: usize,
    ) -> Result<Option<Self>> {
        let gate_name = format!("blk.{layer_idx}.ffn_gate_exps.weight");
        let up_name = format!("blk.{layer_idx}.ffn_up_exps.weight");
        let down_name = format!("blk.{layer_idx}.ffn_down_exps.weight");

        // All three tensors must use a supported k-quant; otherwise defer to the dense path.
        let gate_kind = match quant_kind(gguf, &gate_name)? {
            Some(k) => k,
            None => return Ok(None),
        };
        let up_kind = match quant_kind(gguf, &up_name)? {
            Some(k) => k,
            None => return Ok(None),
        };
        let down_kind = match quant_kind(gguf, &down_name)? {
            Some(k) => k,
            None => return Ok(None),
        };

        let gate_bytes = gguf.tensor_byte_slice(&gate_name).ok_or_else(|| {
            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{gate_name}'"))
        })?;
        let up_bytes = gguf.tensor_byte_slice(&up_name).ok_or_else(|| {
            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{up_name}'"))
        })?;
        let down_bytes = gguf.tensor_byte_slice(&down_name).ok_or_else(|| {
            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{down_name}'"))
        })?;

        // Byte size of a single expert's slice within each stacked tensor.
        let gate_per = block_bytes_for(
            gate_kind,
            expert_intermediate * hidden_size,
            "ffn_gate_exps",
        )?;
        let up_per = block_bytes_for(up_kind, expert_intermediate * hidden_size, "ffn_up_exps")?;
        let down_per = block_bytes_for(
            down_kind,
            hidden_size * expert_intermediate,
            "ffn_down_exps",
        )?;
        check_size(
            gate_bytes.len(),
            num_experts * gate_per,
            "ffn_gate_exps bytes",
        )?;
        check_size(up_bytes.len(), num_experts * up_per, "ffn_up_exps bytes")?;
        check_size(
            down_bytes.len(),
            num_experts * down_per,
            "ffn_down_exps bytes",
        )?;

        // Best effort: a backend that cannot build a stacked store yields `None` here,
        // and the per-expert path below takes over.
        let gate_stacked = B::load_quant_experts(
            gate_kind,
            gate_bytes,
            num_experts,
            expert_intermediate,
            hidden_size,
        )
        .ok();
        let up_stacked = B::load_quant_experts(
            up_kind,
            up_bytes,
            num_experts,
            expert_intermediate,
            hidden_size,
        )
        .ok();
        let down_stacked = B::load_quant_experts(
            down_kind,
            down_bytes,
            num_experts,
            hidden_size,
            expert_intermediate,
        )
        .ok();
        let stacked_complete =
            gate_stacked.is_some() && up_stacked.is_some() && down_stacked.is_some();
        if !stacked_complete && std::env::var("FERRUM_MOE_LOAD_TRACE").is_ok() {
            // Make the otherwise-silent stacked-load failure visible when tracing.
            eprintln!(
                "[moe-load] layer {layer_idx}: stacked experts incomplete (gate={}, up={}, down={}) → trying per-expert QuantLinear",
                gate_stacked.is_some(),
                up_stacked.is_some(),
                down_stacked.is_some()
            );
        }

        let (gate_up, down) = if stacked_complete {
            (Vec::new(), Vec::new())
        } else {
            let mut gate_up: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
            let mut down: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
            for e in 0..num_experts {
                let g_slice = &gate_bytes[e * gate_per..(e + 1) * gate_per];
                let u_slice = &up_bytes[e * up_per..(e + 1) * up_per];
                let d_slice = &down_bytes[e * down_per..(e + 1) * down_per];
                // Gate rows then up rows, fused into one layer per expert.
                let parts: [(GgufQuantType, &[u8], usize); 2] = [
                    (gate_kind, g_slice, expert_intermediate),
                    (up_kind, u_slice, expert_intermediate),
                ];
                let gate_up_e = match QuantLinear::<B>::from_gguf_fused(&parts, hidden_size) {
                    Ok(q) => q,
                    Err(_) => return Ok(None),
                };
                gate_up.push(Box::new(gate_up_e) as Box<dyn Linear<B>>);
                let down_e = match QuantLinear::<B>::from_gguf_bytes(
                    down_kind,
                    d_slice,
                    hidden_size,
                    expert_intermediate,
                ) {
                    Ok(q) => q,
                    Err(_) => return Ok(None),
                };
                down.push(Box::new(down_e) as Box<dyn Linear<B>>);
            }
            (gate_up, down)
        };

        Ok(Some(Self {
            gate_up,
            down,
            gate_stacked,
            up_stacked,
            down_stacked,
        }))
    }

    /// Opens the GGUF file at `path` and loads layer `layer_idx`'s experts via
    /// [`Self::load_from_gguf`].
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened (a candle error converted to
    /// `FerrumError`), or any error from `load_from_gguf`.
    pub fn open_and_load(
        path: impl AsRef<Path>,
        layer_idx: usize,
        num_experts: usize,
        hidden_size: usize,
        expert_intermediate: usize,
    ) -> Result<Self> {
        GgufFile::open(path)
            .map_err(candle_to_ferrum)
            .and_then(|file| {
                Self::load_from_gguf(
                    &file,
                    layer_idx,
                    num_experts,
                    hidden_size,
                    expert_intermediate,
                )
            })
    }

    /// Number of experts that have per-expert layers, i.e. `gate_up.len()`.
    ///
    /// A stack whose three stacked quantised stores all loaded has empty per-expert
    /// vectors, so it reports 0 here. In debug builds this also asserts that `gate_up`
    /// and `down` have the same length.
    pub fn num_experts(&self) -> usize {
        let count = self.gate_up.len();
        debug_assert_eq!(
            count,
            self.down.len(),
            "ExpertStack: gate_up and down disagree on expert count"
        );
        count
    }
}
#[allow(clippy::too_many_arguments)]
pub fn moe_forward<B: Backend>(
ctx: &mut B::Context,
x: &B::Buffer,
router_logits: &B::Buffer,
out: &mut B::Buffer,
batch: usize,
hidden_size: usize,
expert_intermediate: usize,
num_experts: usize,
top_k: usize,
norm_topk_prob: bool,
experts: &ExpertStack<B>,
x_single: &mut B::Buffer,
acc_buf: &mut B::Buffer,
gate_up_buf: &mut B::Buffer,
silu_buf: &mut B::Buffer,
down_buf: &mut B::Buffer,
zero_hidden: &B::Buffer,
) -> Result<()> {
let n_experts = experts.num_experts();
if n_experts != num_experts {
return Err(FerrumError::model(format!(
"moe_forward: experts.num_experts() = {n_experts} != cfg.num_experts = {num_experts}"
)));
}
let prof = moe_profile_enabled();
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::sync(ctx);
if let Some(t) = t0 {
MOE_SYNC_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_SYNC_CALLS.fetch_add(1, Ordering::Relaxed);
}
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
let logits_host = B::to_vec(router_logits, batch * num_experts);
let route_out =
crate::moe::router::route(&logits_host, batch, num_experts, top_k, norm_topk_prob);
if let Some(t) = t0 {
MOE_HOST_TOPK_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_HOST_TOPK_CALLS.fetch_add(1, Ordering::Relaxed);
}
for b in 0..batch {
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::copy_slice(ctx, x, b * hidden_size, x_single, 0, hidden_size);
B::copy_slice(ctx, zero_hidden, 0, acc_buf, 0, hidden_size);
if let Some(t) = t0 {
MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_COPY_CALLS.fetch_add(2, Ordering::Relaxed);
}
for k in 0..top_k {
let pair = b * top_k + k;
let expert_id = route_out.expert_ids[pair] as usize;
let weight = route_out.expert_weights[pair];
if expert_id >= num_experts {
return Err(FerrumError::model(format!(
"moe_forward: routed expert {expert_id} >= num_experts {num_experts}"
)));
}
let t0 = if prof {
B::sync(ctx);
Some(std::time::Instant::now())
} else {
None
};
experts.gate_up[expert_id].forward(ctx, x_single, gate_up_buf, 1);
if let Some(t) = t0 {
B::sync(ctx);
MOE_GEMV_GATE_UP_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_GEMV_GATE_UP_CALLS.fetch_add(1, Ordering::Relaxed);
}
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::fused_silu_mul_split(ctx, gate_up_buf, silu_buf, 1, expert_intermediate);
if let Some(t) = t0 {
B::sync(ctx);
MOE_SILU_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_SILU_CALLS.fetch_add(1, Ordering::Relaxed);
}
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
experts.down[expert_id].forward(ctx, silu_buf, down_buf, 1);
if let Some(t) = t0 {
B::sync(ctx);
MOE_GEMV_DOWN_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_GEMV_DOWN_CALLS.fetch_add(1, Ordering::Relaxed);
}
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::scaled_add_inplace(ctx, acc_buf, down_buf, weight, hidden_size);
if let Some(t) = t0 {
B::sync(ctx);
MOE_SCALED_ADD_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_SCALED_ADD_CALLS.fetch_add(1, Ordering::Relaxed);
}
}
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::copy_slice(ctx, acc_buf, 0, out, b * hidden_size, hidden_size);
if let Some(t) = t0 {
MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_COPY_CALLS.fetch_add(1, Ordering::Relaxed);
}
}
Ok(())
}
/// Host-side MoE forward on the CPU backend, driven by a precomputed [`RouterOutput`].
///
/// `x` holds `batch` rows of `hidden_size` values. For every row `b` and each of its `top_k`
/// routed `(expert_id, weight)` pairs (at index `b * top_k + k`), the row goes through the
/// expert's `gate_up` layer, `fused_silu_mul_split`, and the expert's `down` layer. Then
/// `weight * result` is added into row `b` of `out`.
///
/// `out` is cleared and resized to `batch * hidden_size` zeros before accumulation, so any
/// previous contents are discarded. This happens only after all inputs are validated.
///
/// # Errors
/// - `x.len() != batch * hidden_size`.
/// - `router.expert_ids` or `router.expert_weights` does not have `batch * top_k` entries.
/// - `experts.down` and `experts.gate_up` have different lengths.
/// - The router selects an expert id `>= experts.num_experts()`.
pub fn moe_forward_cpu(
    x: &[f32],
    batch: usize,
    hidden_size: usize,
    expert_intermediate: usize,
    top_k: usize,
    router: &RouterOutput,
    experts: &ExpertStack<CpuBackend>,
    out: &mut Vec<f32>,
) -> Result<()> {
    let n_experts = experts.num_experts();
    if x.len() != batch * hidden_size {
        return Err(FerrumError::model(format!(
            "moe_forward_cpu: x len {} doesn't match batch*hidden = {}*{} = {}",
            x.len(),
            batch,
            hidden_size,
            batch * hidden_size
        )));
    }
    if router.expert_ids.len() != batch * top_k {
        return Err(FerrumError::model(format!(
            "moe_forward_cpu: router has {} expert_ids but expected batch*top_k = {}*{} = {}",
            router.expert_ids.len(),
            batch,
            top_k,
            batch * top_k
        )));
    }
    // Weights are indexed with the same pair index as the ids; a shorter weight vector
    // would otherwise panic inside the loop.
    if router.expert_weights.len() != batch * top_k {
        return Err(FerrumError::model(format!(
            "moe_forward_cpu: router has {} expert_weights but expected batch*top_k = {}*{} = {}",
            router.expert_weights.len(),
            batch,
            top_k,
            batch * top_k
        )));
    }
    // `num_experts()` only debug-asserts this. Check it in release builds as well so that
    // `experts.down[expert_id]` cannot panic.
    if experts.down.len() != n_experts {
        return Err(FerrumError::model(format!(
            "moe_forward_cpu: experts.down has {} entries but experts.gate_up has {n_experts}",
            experts.down.len()
        )));
    }

    out.clear();
    out.resize(batch * hidden_size, 0.0);
    let mut ctx = <CpuBackend as Backend>::new_context();
    // Scratch buffers, reused across all tokens and experts.
    let mut x_b: Vec<f32> = vec![0.0; hidden_size];
    let mut gate_up_buf: Vec<f32> = vec![0.0; 2 * expert_intermediate];
    let mut silu_mul_buf: Vec<f32> = vec![0.0; expert_intermediate];
    let mut down_out: Vec<f32> = vec![0.0; hidden_size];
    for b in 0..batch {
        let row = b * hidden_size..(b + 1) * hidden_size;
        x_b.copy_from_slice(&x[row.clone()]);
        for k in 0..top_k {
            let pair_idx = b * top_k + k;
            let expert_id = router.expert_ids[pair_idx] as usize;
            let weight = router.expert_weights[pair_idx];
            if expert_id >= n_experts {
                return Err(FerrumError::model(format!(
                    "moe_forward_cpu: router selected expert {expert_id} >= num_experts {n_experts}"
                )));
            }
            experts.gate_up[expert_id].forward(&mut ctx, &x_b, &mut gate_up_buf, 1);
            <CpuBackend as Backend>::fused_silu_mul_split(
                &mut ctx,
                &gate_up_buf,
                &mut silu_mul_buf,
                1,
                expert_intermediate,
            );
            experts.down[expert_id].forward(&mut ctx, &silu_mul_buf, &mut down_out, 1);
            for (o, d) in out[row.clone()].iter_mut().zip(down_out.iter()) {
                *o += weight * *d;
            }
        }
    }
    Ok(())
}
/// Returns `Ok(())` when `actual == expected`.
///
/// Otherwise returns a model error that names `label` and reports both sizes.
fn check_size(actual: usize, expected: usize, label: &str) -> Result<()> {
    if actual == expected {
        Ok(())
    } else {
        Err(FerrumError::model(format!(
            "ExpertStack: {label} size mismatch (got {actual}, expected {expected})"
        )))
    }
}
/// Maps the GGML dtype of tensor `name` to a supported k-quant flavour.
///
/// Returns `Ok(Some(Q4K | Q6K))` for those two dtypes and `Ok(None)` for any other dtype.
///
/// # Errors
/// Returns an error if `gguf` has no metadata for `name`.
fn quant_kind(gguf: &GgufFile, name: &str) -> Result<Option<GgufQuantType>> {
    gguf.tensor_info(name)
        .map(|info| match info.ggml_dtype {
            GgmlDType::Q4K => Some(GgufQuantType::Q4K),
            GgmlDType::Q6K => Some(GgufQuantType::Q6K),
            _ => None,
        })
        .ok_or_else(|| {
            FerrumError::model(format!("ExpertStack: tensor info missing for '{name}'"))
        })
}
/// Byte size of `n_elems` weights stored as k-quant `kind`.
///
/// Uses 256-weight super-blocks: 144 bytes each for Q4_K and 210 bytes each for Q6_K.
///
/// # Errors
/// Returns an error if `n_elems` is not a multiple of 256 (this check runs first), or if
/// `kind` is neither Q4K nor Q6K.
fn block_bytes_for(kind: GgufQuantType, n_elems: usize, label: &str) -> Result<usize> {
    const QK_K: usize = 256;
    let (n_blocks, leftover) = (n_elems / QK_K, n_elems % QK_K);
    if leftover != 0 {
        return Err(FerrumError::model(format!(
            "ExpertStack {label}: per-expert element count {n_elems} not a multiple of {QK_K}"
        )));
    }
    let bytes_per_block: usize = match kind {
        GgufQuantType::Q4K => 144,
        GgufQuantType::Q6K => 210,
        other => {
            return Err(FerrumError::model(format!(
                "ExpertStack {label}: unsupported k-quant flavour {other:?}"
            )))
        }
    };
    Ok(n_blocks * bytes_per_block)
}
/// Reads tensor `name` from `gguf`, dequantises it on `device`, and returns all of its
/// values as one flat `Vec<f32>`.
///
/// # Errors
/// Any candle failure along the way (read, dequantise, flatten, or host copy), converted
/// with `candle_to_ferrum`.
fn read_dequant_flat(gguf: &GgufFile, name: &str, device: &Device) -> Result<Vec<f32>> {
    gguf.read_tensor(name, device)
        .and_then(|quantised| quantised.dequantize(device))
        .and_then(|dense| dense.flatten_all())
        .and_then(|flat| flat.to_vec1::<f32>())
        .map_err(candle_to_ferrum)
}
/// Wraps a candle error as a model-level `FerrumError`, prefixing its message with `candle: `.
fn candle_to_ferrum(e: candle_core::Error) -> FerrumError {
    let message = format!("candle: {}", e);
    FerrumError::model(message)
}
// `CandleResult` is imported at the top of the file but not otherwise used here.
// Presumably this private alias exists only so that import is referenced and doesn't
// trigger an `unused_imports` warning.
#[allow(dead_code)]
type _CandleResult<T> = CandleResult<T>;