use std::path::Path;
use std::sync::{
atomic::{AtomicU64, Ordering},
OnceLock,
};
use candle_core::quantized::GgmlDType;
use candle_core::{Device, Result as CandleResult};
use ferrum_kernels::backend::cpu::CpuBackend;
use ferrum_kernels::backend::{
Backend, BackendMoeFused, BackendPagedKv, BackendQuantGguf, BackendQuantMarlin, GgufQuantType,
LlmBackend, QuantLlmBackend,
};
use ferrum_kernels::{Linear, StackedExpertGgufLinear};
use ferrum_quantization::gguf::GgufFile;
use ferrum_quantization::{DenseLinear, QuantLinear};
use ferrum_types::{FerrumError, Result};
use crate::moe::router::RouterOutput;
// Profiling accumulators for the MoE dispatch paths. `*_US` statics sum elapsed
// microseconds and `*_CALLS` statics count samples; every update in this file
// uses `Ordering::Relaxed`.
//
// --- Updated by `moe_forward` when profiling is enabled ---
/// Time spent in the `B::sync` issued before the router logits are read on the host.
pub static MOE_SYNC_US: AtomicU64 = AtomicU64::new(0);
/// Number of samples added to [`MOE_SYNC_US`].
pub static MOE_SYNC_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent in the per-expert fused gate/up `forward` call.
pub static MOE_GEMV_GATE_UP_US: AtomicU64 = AtomicU64::new(0);
/// Number of samples added to [`MOE_GEMV_GATE_UP_US`].
pub static MOE_GEMV_GATE_UP_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent in `B::fused_silu_mul_split`.
pub static MOE_SILU_US: AtomicU64 = AtomicU64::new(0);
/// Number of samples added to [`MOE_SILU_US`].
pub static MOE_SILU_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent in the per-expert down-projection `forward` call.
pub static MOE_GEMV_DOWN_US: AtomicU64 = AtomicU64::new(0);
/// Number of samples added to [`MOE_GEMV_DOWN_US`].
pub static MOE_GEMV_DOWN_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent in `B::scaled_add_inplace`.
pub static MOE_SCALED_ADD_US: AtomicU64 = AtomicU64::new(0);
/// Number of samples added to [`MOE_SCALED_ADD_US`].
pub static MOE_SCALED_ADD_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent issuing `B::copy_slice` calls (measured without a trailing sync).
pub static MOE_COPY_US: AtomicU64 = AtomicU64::new(0);
/// Number of individual `B::copy_slice` calls covered by [`MOE_COPY_US`].
pub static MOE_COPY_CALLS: AtomicU64 = AtomicU64::new(0);
/// Time spent copying router logits to the host and routing them there.
pub static MOE_HOST_TOPK_US: AtomicU64 = AtomicU64::new(0);
/// Number of samples added to [`MOE_HOST_TOPK_US`].
pub static MOE_HOST_TOPK_CALLS: AtomicU64 = AtomicU64::new(0);
// --- Updated by `moe_forward_bucketed` when profiling is enabled ---
/// Time spent in the `B::sync` of the host-routing fallback.
pub static MOE_BUCKET_SYNC_US: AtomicU64 = AtomicU64::new(0);
/// Time spent in `B::to_vec` copying router logits to the host.
pub static MOE_BUCKET_D2H_US: AtomicU64 = AtomicU64::new(0);
/// Time spent routing: either the host `route_into` call, or the whole
/// `B::try_gpu_route_topk_into_host` call when that succeeds.
pub static MOE_BUCKET_ROUTE_US: AtomicU64 = AtomicU64::new(0);
/// Time spent rebuilding the `MoeBucketPlan`.
pub static MOE_BUCKET_PLAN_US: AtomicU64 = AtomicU64::new(0);
/// Time spent gathering token rows into the packed input buffer.
pub static MOE_BUCKET_GATHER_US: AtomicU64 = AtomicU64::new(0);
/// Time spent in the first (gate/up) GEMM phase.
pub static MOE_BUCKET_GEMM1_US: AtomicU64 = AtomicU64::new(0);
/// NOTE(review): presumably the bucketed SiLU phase; its update site is not
/// visible in this part of the file — confirm.
pub static MOE_BUCKET_SILU_US: AtomicU64 = AtomicU64::new(0);
/// NOTE(review): presumably the bucketed down-projection GEMM phase; its update
/// site is not visible in this part of the file — confirm.
pub static MOE_BUCKET_GEMM3_US: AtomicU64 = AtomicU64::new(0);
/// NOTE(review): presumably the bucketed combine/scatter phase; its update site
/// is not visible in this part of the file — confirm.
pub static MOE_BUCKET_COMBINE_US: AtomicU64 = AtomicU64::new(0);
/// Incremented once per `moe_forward_bucketed` call while profiling is enabled.
pub static MOE_BUCKET_LAYER_CALLS: AtomicU64 = AtomicU64::new(0);
/// Runtime switches for MoE dispatch, parsed from `FERRUM_*` environment variables.
#[derive(Debug, Clone, PartialEq, Eq)]
struct MoeDispatchRuntimeConfig {
    /// Set when `FERRUM_MOE_PROFILE` is present (any value).
    moe_profile: bool,
    /// Set when `FERRUM_DECODE_OP_PROFILE` is present (any value).
    decode_op_profile: bool,
    /// `FERRUM_VLLM_MOE_ZERO_WS == "1"`.
    vllm_moe_zero_ws: bool,
    /// `FERRUM_VLLM_MOE_PAIR_IDS == "1"`.
    vllm_moe_pair_ids: bool,
    /// Set when `FERRUM_MOE_LOAD_TRACE` is present (any value).
    moe_load_trace: bool,
    /// Validated `FERRUM_MOE_BLOCK_SIZE` override, if any.
    moe_block_size: Option<usize>,
    /// Validated `FERRUM_MOE_LARGE_M_BLOCK_SIZE` override, if any.
    moe_large_m_block_size: Option<usize>,
    /// `FERRUM_MOE_LARGE_M_MIN_PAIRS`; 1024 when unset or unparsable.
    moe_large_m_min_pairs: usize,
    /// `FERRUM_VLLM_MOE == "1"`.
    vllm_moe: bool,
    /// `FERRUM_MOE_HOST_ROUTE == "1"`.
    moe_host_route: bool,
}

impl Default for MoeDispatchRuntimeConfig {
    /// Everything disabled, no block-size overrides, large-M threshold of 1024 pairs.
    fn default() -> Self {
        Self {
            moe_profile: false,
            decode_op_profile: false,
            vllm_moe_zero_ws: false,
            vllm_moe_pair_ids: false,
            moe_load_trace: false,
            moe_block_size: None,
            moe_large_m_block_size: None,
            moe_large_m_min_pairs: 1024,
            vllm_moe: false,
            moe_host_route: false,
        }
    }
}

impl MoeDispatchRuntimeConfig {
    /// Builds the configuration from the current process environment.
    ///
    /// Uses `std::env::vars_os` rather than `std::env::vars`: the latter panics
    /// while iterating if *any* variable in the environment is not valid
    /// Unicode, which would turn an unrelated variable into a crash here.
    /// Variables whose name is not valid Unicode cannot match a `FERRUM_*`
    /// name and are skipped; values are converted lossily, so presence-only
    /// flags still register.
    fn from_env() -> Self {
        let vars = std::env::vars_os().filter_map(|(name, value)| {
            let name = name.into_string().ok()?;
            Some((name, value.to_string_lossy().into_owned()))
        });
        Self::from_env_vars(vars)
    }

    /// Builds the configuration from an explicit list of `(name, value)` pairs.
    ///
    /// Unknown names are ignored. If a name appears more than once, the last
    /// occurrence wins. The three profile/trace flags are enabled by mere
    /// presence of the variable, whatever its value; the remaining boolean
    /// flags require the exact value `"1"`.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut config = Self::default();
        for (name, value) in vars {
            let value = value.as_ref();
            let is_one = value == "1";
            match name.as_ref() {
                // Presence-only switches.
                "FERRUM_MOE_PROFILE" => config.moe_profile = true,
                "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
                "FERRUM_MOE_LOAD_TRACE" => config.moe_load_trace = true,
                // Switches that must be exactly "1".
                "FERRUM_VLLM_MOE_ZERO_WS" => config.vllm_moe_zero_ws = is_one,
                "FERRUM_VLLM_MOE_PAIR_IDS" => config.vllm_moe_pair_ids = is_one,
                "FERRUM_VLLM_MOE" => config.vllm_moe = is_one,
                "FERRUM_MOE_HOST_ROUTE" => config.moe_host_route = is_one,
                // Numeric overrides; invalid block sizes become `None`.
                "FERRUM_MOE_BLOCK_SIZE" => {
                    config.moe_block_size = parse_moe_block_size_value(value);
                }
                "FERRUM_MOE_LARGE_M_BLOCK_SIZE" => {
                    config.moe_large_m_block_size = parse_moe_block_size_value(value);
                }
                "FERRUM_MOE_LARGE_M_MIN_PAIRS" => {
                    config.moe_large_m_min_pairs = value.parse::<usize>().unwrap_or(1024);
                }
                _ => {}
            }
        }
        config
    }
}

/// Parses a block-size override, accepting only 8, 16, 32, 48 or 64.
///
/// Returns `None` for anything that is not a plain decimal `usize` or is not
/// one of the accepted sizes.
fn parse_moe_block_size_value(value: &str) -> Option<usize> {
    value
        .parse::<usize>()
        .ok()
        .filter(|bs| matches!(*bs, 8 | 16 | 32 | 48 | 64))
}
/// Returns the process-wide MoE dispatch configuration.
///
/// The environment is read on the first call only; later calls return the
/// same cached value.
fn moe_dispatch_runtime_config() -> &'static MoeDispatchRuntimeConfig {
    static CACHED_CONFIG: OnceLock<MoeDispatchRuntimeConfig> = OnceLock::new();
    let cached: &'static MoeDispatchRuntimeConfig =
        CACHED_CONFIG.get_or_init(MoeDispatchRuntimeConfig::from_env);
    cached
}
/// Reports whether MoE profiling is switched on in the cached runtime configuration.
fn moe_profile_enabled() -> bool {
    let config = moe_dispatch_runtime_config();
    config.moe_profile
}
/// The expert weights of one MoE layer, held in up to three representations:
/// per-expert linears, stacked GGUF tensors, and Marlin stacks.
pub struct ExpertStack<B: QuantLlmBackend + BackendMoeFused> {
/// One linear per expert holding the gate rows followed by the up rows
/// (`2 * expert_intermediate` output rows). Left empty when the loader found
/// a complete set of stacked GGUF tensors.
pub gate_up: Vec<Box<dyn Linear<B>>>,
/// One down-projection linear per expert; empty under the same condition as `gate_up`.
pub down: Vec<Box<dyn Linear<B>>>,
/// Stacked gate weights for all experts, when the backend could load them.
pub gate_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
/// Stacked up weights for all experts, when the backend could load them.
pub up_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
/// Stacked down weights for all experts, when the backend could load them.
pub down_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
/// Marlin gate/up stack; the loaders in this file always leave it `None`.
pub gate_up_marlin_stack: Option<std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>>,
/// Marlin down stack; the loaders in this file always leave it `None`.
pub down_marlin_stack: Option<std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>>,
}
/// Thin dispatch wrappers over the stacked expert weights.
///
/// Every `gemv_*` / `gemm_*` method returns `FerrumError::unsupported` when the
/// stacked tensor it needs has not been loaded, and otherwise forwards its
/// arguments to the corresponding `StackedExpertGgufLinear` method.
impl<B: QuantLlmBackend + BackendMoeFused> ExpertStack<B> {
    /// Returns the Marlin gate/up stack, if one is loaded.
    ///
    /// `_expert_idx` is ignored: the same handle is returned for every index.
    pub fn gate_up_stacked_store(
        &self,
        _expert_idx: usize,
    ) -> Option<&std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>> {
        match &self.gate_up_marlin_stack {
            Some(stack) => Some(stack),
            None => None,
        }
    }

    /// Returns the Marlin down stack, if one is loaded.
    ///
    /// `_expert_idx` is ignored: the same handle is returned for every index.
    pub fn down_stacked_store(
        &self,
        _expert_idx: usize,
    ) -> Option<&std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>> {
        match &self.down_marlin_stack {
            Some(stack) => Some(stack),
            None => None,
        }
    }

    /// Calls `gemv_moe_id` on the stacked gate weights with a final argument of 0.
    pub fn gemv_gate(
        &self,
        ctx: &mut B::Context,
        input: &B::Buffer,
        ids: &B::Buffer,
        out: &mut B::Buffer,
        top_k: usize,
    ) -> Result<()> {
        let Some(gate) = self.gate_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_gate: gate_stacked not loaded",
            ));
        };
        gate.gemv_moe_id(ctx, input, ids, out, top_k, 0)
    }

    /// Calls `gemv_moe_id` on the stacked up weights with a final argument of 0.
    pub fn gemv_up(
        &self,
        ctx: &mut B::Context,
        input: &B::Buffer,
        ids: &B::Buffer,
        out: &mut B::Buffer,
        top_k: usize,
    ) -> Result<()> {
        let Some(up) = self.up_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_up: up_stacked not loaded",
            ));
        };
        up.gemv_moe_id(ctx, input, ids, out, top_k, 0)
    }

    /// Calls `gemv_moe_id` on the stacked down weights, passing
    /// `expert_intermediate` as the final argument.
    pub fn gemv_down(
        &self,
        ctx: &mut B::Context,
        input: &B::Buffer,
        ids: &B::Buffer,
        out: &mut B::Buffer,
        top_k: usize,
        expert_intermediate: usize,
    ) -> Result<()> {
        let Some(down) = self.down_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_down: down_stacked not loaded",
            ));
        };
        down.gemv_moe_id(ctx, input, ids, out, top_k, expert_intermediate)
    }

    /// Calls the fused gate/up/SiLU GEMV on the stacked gate weights, handing
    /// it the stacked up weights. Both stacks must be loaded.
    pub fn gemv_gate_up_silu_fused(
        &self,
        ctx: &mut B::Context,
        input: &B::Buffer,
        ids: &B::Buffer,
        out_silu_stacked: &mut B::Buffer,
        top_k: usize,
    ) -> Result<()> {
        let Some(gate) = self.gate_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_gate_up_silu_fused: gate_stacked not loaded",
            ));
        };
        let Some(up) = self.up_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_gate_up_silu_fused: up_stacked not loaded",
            ));
        };
        gate.gemv_moe_id_gate_up_silu(ctx, input, up, ids, out_silu_stacked, top_k)
    }

    /// Runs the id-routed GEMM on the stacked gate weights.
    ///
    /// Uses `gemm_moe_id_indirect` when `args_buf` is supplied and
    /// `gemm_moe_id` otherwise; in both cases the argument preceding `top_k`
    /// is the constant 1.
    #[allow(clippy::too_many_arguments)]
    pub fn gemm_gate(
        &self,
        ctx: &mut B::Context,
        src1: &B::Buffer,
        ids: &B::Buffer,
        tpe: &B::Buffer,
        dst: &mut B::Buffer,
        args_buf: Option<&B::Buffer>,
        top_k: usize,
        max_per_expert: usize,
        tokens: usize,
    ) -> Result<()> {
        let Some(gate) = self.gate_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemm_gate: gate_stacked not loaded",
            ));
        };
        if let Some(args) = args_buf {
            gate.gemm_moe_id_indirect(
                ctx,
                src1,
                ids,
                tpe,
                dst,
                args,
                1,
                top_k,
                max_per_expert,
                tokens,
            )
        } else {
            gate.gemm_moe_id(ctx, src1, ids, tpe, dst, 1, top_k, max_per_expert, tokens)
        }
    }

    /// Runs the id-routed GEMM on the stacked up weights.
    ///
    /// Uses `gemm_moe_id_indirect` when `args_buf` is supplied and
    /// `gemm_moe_id` otherwise; in both cases the argument preceding `top_k`
    /// is the constant 1.
    #[allow(clippy::too_many_arguments)]
    pub fn gemm_up(
        &self,
        ctx: &mut B::Context,
        src1: &B::Buffer,
        ids: &B::Buffer,
        tpe: &B::Buffer,
        dst: &mut B::Buffer,
        args_buf: Option<&B::Buffer>,
        top_k: usize,
        max_per_expert: usize,
        tokens: usize,
    ) -> Result<()> {
        let Some(up) = self.up_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemm_up: up_stacked not loaded",
            ));
        };
        if let Some(args) = args_buf {
            up.gemm_moe_id_indirect(
                ctx,
                src1,
                ids,
                tpe,
                dst,
                args,
                1,
                top_k,
                max_per_expert,
                tokens,
            )
        } else {
            up.gemm_moe_id(ctx, src1, ids, tpe, dst, 1, top_k, max_per_expert, tokens)
        }
    }

    /// Runs the id-routed GEMM on the stacked down weights.
    ///
    /// Unlike the gate/up variants, `top_k` is passed in both of the two
    /// positions before `max_per_expert`.
    #[allow(clippy::too_many_arguments)]
    pub fn gemm_down(
        &self,
        ctx: &mut B::Context,
        src1: &B::Buffer,
        ids: &B::Buffer,
        tpe: &B::Buffer,
        dst: &mut B::Buffer,
        args_buf: Option<&B::Buffer>,
        top_k: usize,
        max_per_expert: usize,
        tokens: usize,
    ) -> Result<()> {
        let Some(down) = self.down_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemm_down: down_stacked not loaded",
            ));
        };
        if let Some(args) = args_buf {
            down.gemm_moe_id_indirect(
                ctx,
                src1,
                ids,
                tpe,
                dst,
                args,
                top_k,
                top_k,
                max_per_expert,
                tokens,
            )
        } else {
            down.gemm_moe_id(
                ctx,
                src1,
                ids,
                tpe,
                dst,
                top_k,
                top_k,
                max_per_expert,
                tokens,
            )
        }
    }

    /// Calls `gemv_moe_id_batched` on the stacked gate weights.
    #[allow(clippy::too_many_arguments)]
    pub fn gemv_gate_batched(
        &self,
        ctx: &mut B::Context,
        input: &B::Buffer,
        ids: &B::Buffer,
        dst: &mut B::Buffer,
        m: usize,
        top_k: usize,
        src1_outer_stride: usize,
        src1_inner_stride: usize,
    ) -> Result<()> {
        let Some(gate) = self.gate_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_gate_batched: gate_stacked not loaded",
            ));
        };
        gate.gemv_moe_id_batched(
            ctx,
            input,
            ids,
            dst,
            m,
            top_k,
            src1_outer_stride,
            src1_inner_stride,
        )
    }

    /// Calls `gemv_moe_id_batched` on the stacked up weights.
    #[allow(clippy::too_many_arguments)]
    pub fn gemv_up_batched(
        &self,
        ctx: &mut B::Context,
        input: &B::Buffer,
        ids: &B::Buffer,
        dst: &mut B::Buffer,
        m: usize,
        top_k: usize,
        src1_outer_stride: usize,
        src1_inner_stride: usize,
    ) -> Result<()> {
        let Some(up) = self.up_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_up_batched: up_stacked not loaded",
            ));
        };
        up.gemv_moe_id_batched(
            ctx,
            input,
            ids,
            dst,
            m,
            top_k,
            src1_outer_stride,
            src1_inner_stride,
        )
    }

    /// Calls `gemv_moe_id_batched` on the stacked down weights.
    #[allow(clippy::too_many_arguments)]
    pub fn gemv_down_batched(
        &self,
        ctx: &mut B::Context,
        input: &B::Buffer,
        ids: &B::Buffer,
        dst: &mut B::Buffer,
        m: usize,
        top_k: usize,
        src1_outer_stride: usize,
        src1_inner_stride: usize,
    ) -> Result<()> {
        let Some(down) = self.down_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_down_batched: down_stacked not loaded",
            ));
        };
        down.gemv_moe_id_batched(
            ctx,
            input,
            ids,
            dst,
            m,
            top_k,
            src1_outer_stride,
            src1_inner_stride,
        )
    }

    /// Calls the batched fused gate/up/SiLU GEMV on the stacked gate weights,
    /// handing it the stacked up weights. Both stacks must be loaded.
    #[allow(clippy::too_many_arguments)]
    pub fn gemv_gate_up_silu_batched_fused(
        &self,
        ctx: &mut B::Context,
        input: &B::Buffer,
        ids: &B::Buffer,
        silu_out: &mut B::Buffer,
        m: usize,
        top_k: usize,
        src1_outer_stride: usize,
        src1_inner_stride: usize,
    ) -> Result<()> {
        let Some(gate) = self.gate_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_gate_up_silu_batched_fused: gate_stacked not loaded",
            ));
        };
        let Some(up) = self.up_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_gate_up_silu_batched_fused: up_stacked not loaded",
            ));
        };
        gate.gemv_moe_id_gate_up_silu_batched(
            ctx,
            input,
            up,
            ids,
            silu_out,
            m,
            top_k,
            src1_outer_stride,
            src1_inner_stride,
        )
    }

    /// Calls `gemv_moe_id_offset` on the stacked gate weights.
    #[allow(clippy::too_many_arguments)]
    pub fn gemv_gate_offset(
        &self,
        ctx: &mut B::Context,
        src1: &B::Buffer,
        src1_offset: usize,
        ids: &B::Buffer,
        ids_offset: usize,
        dst: &mut B::Buffer,
        top_k: usize,
        src1_stride: usize,
    ) -> Result<()> {
        let Some(gate) = self.gate_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_gate_offset: gate_stacked not loaded",
            ));
        };
        gate.gemv_moe_id_offset(
            ctx,
            src1,
            src1_offset,
            ids,
            ids_offset,
            dst,
            top_k,
            src1_stride,
        )
    }

    /// Calls `gemv_moe_id_offset` on the stacked up weights.
    #[allow(clippy::too_many_arguments)]
    pub fn gemv_up_offset(
        &self,
        ctx: &mut B::Context,
        src1: &B::Buffer,
        src1_offset: usize,
        ids: &B::Buffer,
        ids_offset: usize,
        dst: &mut B::Buffer,
        top_k: usize,
        src1_stride: usize,
    ) -> Result<()> {
        let Some(up) = self.up_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_up_offset: up_stacked not loaded",
            ));
        };
        up.gemv_moe_id_offset(
            ctx,
            src1,
            src1_offset,
            ids,
            ids_offset,
            dst,
            top_k,
            src1_stride,
        )
    }

    /// Calls `gemv_moe_id_offset` on the stacked down weights.
    #[allow(clippy::too_many_arguments)]
    pub fn gemv_down_offset(
        &self,
        ctx: &mut B::Context,
        src1: &B::Buffer,
        src1_offset: usize,
        ids: &B::Buffer,
        ids_offset: usize,
        dst: &mut B::Buffer,
        top_k: usize,
        src1_stride: usize,
    ) -> Result<()> {
        let Some(down) = self.down_stacked.as_deref() else {
            return Err(FerrumError::unsupported(
                "ExpertStack::gemv_down_offset: down_stacked not loaded",
            ));
        };
        down.gemv_moe_id_offset(
            ctx,
            src1,
            src1_offset,
            ids,
            ids_offset,
            dst,
            top_k,
            src1_stride,
        )
    }
}
/// Constructors and loaders for [`ExpertStack`].
impl<B: QuantLlmBackend + BackendMoeFused> ExpertStack<B> {
    /// Builds per-expert dense linears from three flat `f32` stacks.
    ///
    /// Each stack must hold exactly `num_experts` consecutive expert matrices
    /// of `expert_intermediate * hidden_size` elements; `check_size` rejects
    /// any other length. For every expert the gate rows and the up rows are
    /// concatenated into one `2 * expert_intermediate` x `hidden_size` linear,
    /// and the down matrix becomes a `hidden_size` x `expert_intermediate`
    /// linear. No stacked or Marlin representation is created.
    ///
    /// # Errors
    /// Propagates the error from `check_size` on a length mismatch.
    pub fn from_dense_stacks(
        gate_stack: &[f32],
        up_stack: &[f32],
        down_stack: &[f32],
        num_experts: usize,
        hidden_size: usize,
        expert_intermediate: usize,
    ) -> Result<Self> {
        let gate_up_per_expert = expert_intermediate * hidden_size;
        let down_per_expert = hidden_size * expert_intermediate;
        check_size(
            gate_stack.len(),
            num_experts * gate_up_per_expert,
            "gate_stack",
        )?;
        check_size(up_stack.len(), num_experts * gate_up_per_expert, "up_stack")?;
        check_size(
            down_stack.len(),
            num_experts * down_per_expert,
            "down_stack",
        )?;

        let mut gate_up: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
        let mut down: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
        for e in 0..num_experts {
            let g_off = e * gate_up_per_expert;
            let g_slice = &gate_stack[g_off..g_off + gate_up_per_expert];
            let u_slice = &up_stack[g_off..g_off + gate_up_per_expert];
            // Gate rows first, then up rows, in a single fused weight matrix.
            let mut fused = Vec::with_capacity(2 * gate_up_per_expert);
            fused.extend_from_slice(g_slice);
            fused.extend_from_slice(u_slice);
            gate_up.push(Box::new(DenseLinear::<B>::from_rows(
                &fused,
                2 * expert_intermediate,
                hidden_size,
            )));

            let d_off = e * down_per_expert;
            let d_slice = &down_stack[d_off..d_off + down_per_expert];
            down.push(Box::new(DenseLinear::<B>::from_rows(
                d_slice,
                hidden_size,
                expert_intermediate,
            )));
        }

        Ok(Self {
            gate_up,
            down,
            gate_stacked: None,
            up_stacked: None,
            down_stacked: None,
            gate_up_marlin_stack: None,
            down_marlin_stack: None,
        })
    }

    /// Loads the experts of layer `layer_idx` from an open GGUF file.
    ///
    /// The quantised path is tried first. If it reports that it cannot handle
    /// the tensors (`Ok(None)`), the three `blk.{layer_idx}.ffn_*_exps.weight`
    /// tensors are read through `read_dequant_flat` and handed to
    /// [`Self::from_dense_stacks`]. When `FERRUM_MOE_LOAD_TRACE` is set, the
    /// chosen path is printed to stderr.
    ///
    /// # Errors
    /// Propagates errors from the quantised loader, from `read_dequant_flat`,
    /// and from `from_dense_stacks`.
    pub fn load_from_gguf(
        gguf: &GgufFile,
        layer_idx: usize,
        num_experts: usize,
        hidden_size: usize,
        expert_intermediate: usize,
    ) -> Result<Self> {
        let runtime_config = moe_dispatch_runtime_config();
        if let Some(quant) = Self::try_load_quantised(
            gguf,
            layer_idx,
            num_experts,
            hidden_size,
            expert_intermediate,
        )? {
            if runtime_config.moe_load_trace {
                eprintln!("[moe-load] layer {layer_idx} → quantised expert path");
            }
            return Ok(quant);
        }
        if runtime_config.moe_load_trace {
            eprintln!("[moe-load] layer {layer_idx} → eager fp32 dense fallback ⚠");
        }

        let device = Device::Cpu;
        let gate = read_dequant_flat(
            gguf,
            &format!("blk.{layer_idx}.ffn_gate_exps.weight"),
            &device,
        )?;
        let up = read_dequant_flat(
            gguf,
            &format!("blk.{layer_idx}.ffn_up_exps.weight"),
            &device,
        )?;
        let down = read_dequant_flat(
            gguf,
            &format!("blk.{layer_idx}.ffn_down_exps.weight"),
            &device,
        )?;
        Self::from_dense_stacks(
            &gate,
            &up,
            &down,
            num_experts,
            hidden_size,
            expert_intermediate,
        )
    }

    /// Tries to load the layer's experts without dequantising them.
    ///
    /// Returns `Ok(None)` when `quant_kind` reports no usable kind for any of
    /// the three tensors, or when building a per-expert `QuantLinear` fails;
    /// the caller then falls back to the dense path.
    ///
    /// Stacked tensors are loaded on a best-effort basis through
    /// `B::load_quant_experts`. Per-expert linears are only built when at
    /// least one of the three stacked loads failed; otherwise `gate_up` and
    /// `down` stay empty.
    ///
    /// # Errors
    /// Fails when a tensor's bytes cannot be obtained, when `block_bytes_for`
    /// rejects a tensor, or when a byte length does not match
    /// `num_experts * per-expert size`.
    fn try_load_quantised(
        gguf: &GgufFile,
        layer_idx: usize,
        num_experts: usize,
        hidden_size: usize,
        expert_intermediate: usize,
    ) -> Result<Option<Self>> {
        let gate_name = format!("blk.{layer_idx}.ffn_gate_exps.weight");
        let up_name = format!("blk.{layer_idx}.ffn_up_exps.weight");
        let down_name = format!("blk.{layer_idx}.ffn_down_exps.weight");

        let Some(gate_kind) = quant_kind(gguf, &gate_name)? else {
            return Ok(None);
        };
        let Some(up_kind) = quant_kind(gguf, &up_name)? else {
            return Ok(None);
        };
        let Some(down_kind) = quant_kind(gguf, &down_name)? else {
            return Ok(None);
        };

        let gate_bytes = gguf.tensor_byte_slice(&gate_name).ok_or_else(|| {
            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{gate_name}'"))
        })?;
        let up_bytes = gguf.tensor_byte_slice(&up_name).ok_or_else(|| {
            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{up_name}'"))
        })?;
        let down_bytes = gguf.tensor_byte_slice(&down_name).ok_or_else(|| {
            FerrumError::model(format!("MoE: tensor_byte_slice failed for '{down_name}'"))
        })?;

        // Byte size of a single expert's matrix for each tensor.
        let gate_per = block_bytes_for(
            gate_kind,
            expert_intermediate * hidden_size,
            "ffn_gate_exps",
        )?;
        let up_per = block_bytes_for(up_kind, expert_intermediate * hidden_size, "ffn_up_exps")?;
        let down_per = block_bytes_for(
            down_kind,
            hidden_size * expert_intermediate,
            "ffn_down_exps",
        )?;
        check_size(
            gate_bytes.len(),
            num_experts * gate_per,
            "ffn_gate_exps bytes",
        )?;
        check_size(up_bytes.len(), num_experts * up_per, "ffn_up_exps bytes")?;
        check_size(
            down_bytes.len(),
            num_experts * down_per,
            "ffn_down_exps bytes",
        )?;

        // A failed stacked load is deliberately tolerated: the per-expert
        // linears below take over when any of the three is missing.
        let gate_stacked = B::load_quant_experts(
            gate_kind,
            gate_bytes,
            num_experts,
            expert_intermediate,
            hidden_size,
        )
        .ok();
        let up_stacked = B::load_quant_experts(
            up_kind,
            up_bytes,
            num_experts,
            expert_intermediate,
            hidden_size,
        )
        .ok();
        let down_stacked = B::load_quant_experts(
            down_kind,
            down_bytes,
            num_experts,
            hidden_size,
            expert_intermediate,
        )
        .ok();

        let stacked_complete =
            gate_stacked.is_some() && up_stacked.is_some() && down_stacked.is_some();
        let (gate_up, down) = if stacked_complete {
            (Vec::new(), Vec::new())
        } else {
            let mut gate_up: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
            let mut down: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
            for e in 0..num_experts {
                let g_slice = &gate_bytes[e * gate_per..(e + 1) * gate_per];
                let u_slice = &up_bytes[e * up_per..(e + 1) * up_per];
                let d_slice = &down_bytes[e * down_per..(e + 1) * down_per];
                let parts: [(GgufQuantType, &[u8], usize); 2] = [
                    (gate_kind, g_slice, expert_intermediate),
                    (up_kind, u_slice, expert_intermediate),
                ];
                // Any per-expert construction failure abandons the quantised
                // path so the caller can use the dense fallback instead.
                let Ok(gate_up_e) = QuantLinear::<B>::from_gguf_fused(&parts, hidden_size) else {
                    return Ok(None);
                };
                gate_up.push(Box::new(gate_up_e));
                let Ok(down_e) = QuantLinear::<B>::from_gguf_bytes(
                    down_kind,
                    d_slice,
                    hidden_size,
                    expert_intermediate,
                ) else {
                    return Ok(None);
                };
                down.push(Box::new(down_e));
            }
            (gate_up, down)
        };

        Ok(Some(Self {
            gate_up,
            down,
            gate_stacked,
            up_stacked,
            down_stacked,
            gate_up_marlin_stack: None,
            down_marlin_stack: None,
        }))
    }

    /// Opens the GGUF file at `path` and loads one layer's experts from it.
    ///
    /// # Errors
    /// Returns the open error converted through `candle_to_ferrum`, or any
    /// error from [`Self::load_from_gguf`].
    pub fn open_and_load(
        path: impl AsRef<Path>,
        layer_idx: usize,
        num_experts: usize,
        hidden_size: usize,
        expert_intermediate: usize,
    ) -> Result<Self> {
        let gguf = GgufFile::open(path).map_err(candle_to_ferrum)?;
        Self::load_from_gguf(
            &gguf,
            layer_idx,
            num_experts,
            hidden_size,
            expert_intermediate,
        )
    }

    /// Number of per-expert linears held in `gate_up`.
    ///
    /// Debug builds assert that `gate_up` and `down` have the same length.
    ///
    /// NOTE(review): when the quantised loader finds a complete set of stacked
    /// tensors it leaves `gate_up` and `down` empty, so this returns 0 for such
    /// a stack even though experts are loaded. Confirm that callers comparing
    /// this against the configured expert count never see that kind of stack.
    pub fn num_experts(&self) -> usize {
        debug_assert_eq!(
            self.gate_up.len(),
            self.down.len(),
            "ExpertStack: gate_up and down disagree on expert count"
        );
        self.gate_up.len()
    }
}
/// Arguments for [`moe_forward`], the per-token, per-expert MoE path.
pub struct MoeForwardParams<'a, B: QuantLlmBackend + BackendMoeFused> {
/// Backend execution context.
pub ctx: &'a mut B::Context,
/// Input activations; token `b` is read from offset `b * hidden_size`.
pub x: &'a B::Buffer,
/// Router logits; `batch * num_experts` values are copied to the host.
pub router_logits: &'a B::Buffer,
/// Output activations; token `b` is written at offset `b * hidden_size`.
pub out: &'a mut B::Buffer,
/// Number of tokens to process.
pub batch: usize,
/// Number of elements copied per token row.
pub hidden_size: usize,
/// Size passed to `B::fused_silu_mul_split` for the gate/up split.
pub expert_intermediate: usize,
/// Expected expert count; must equal `experts.num_experts()`.
pub num_experts: usize,
/// Number of experts evaluated per token.
pub top_k: usize,
/// Forwarded to the router unchanged.
pub norm_topk_prob: bool,
/// Expert weights; only the per-expert `gate_up` / `down` linears are used.
pub experts: &'a ExpertStack<B>,
/// Scratch receiving the current token's `hidden_size` input values.
pub x_single: &'a mut B::Buffer,
/// Scratch accumulating the weighted expert outputs of the current token.
pub acc_buf: &'a mut B::Buffer,
/// Scratch receiving the output of the fused gate/up linear.
pub gate_up_buf: &'a mut B::Buffer,
/// Scratch receiving the output of `B::fused_silu_mul_split`.
pub silu_buf: &'a mut B::Buffer,
/// Scratch receiving the output of the down linear.
pub down_buf: &'a mut B::Buffer,
/// Source copied into `acc_buf` at the start of every token.
/// NOTE(review): presumably `hidden_size` zeros — confirm at the call site.
pub zero_hidden: &'a B::Buffer,
}
pub fn moe_forward<B: QuantLlmBackend + BackendMoeFused>(
params: MoeForwardParams<'_, B>,
) -> Result<()> {
let MoeForwardParams {
ctx,
x,
router_logits,
out,
batch,
hidden_size,
expert_intermediate,
num_experts,
top_k,
norm_topk_prob,
experts,
x_single,
acc_buf,
gate_up_buf,
silu_buf,
down_buf,
zero_hidden,
} = params;
let n_experts = experts.num_experts();
if n_experts != num_experts {
return Err(FerrumError::model(format!(
"moe_forward: experts.num_experts() = {n_experts} != cfg.num_experts = {num_experts}"
)));
}
let prof = moe_profile_enabled();
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::sync(ctx);
if let Some(t) = t0 {
MOE_SYNC_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_SYNC_CALLS.fetch_add(1, Ordering::Relaxed);
}
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
let logits_host = B::to_vec(router_logits, batch * num_experts);
let route_out =
crate::moe::router::route(&logits_host, batch, num_experts, top_k, norm_topk_prob);
if let Some(t) = t0 {
MOE_HOST_TOPK_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_HOST_TOPK_CALLS.fetch_add(1, Ordering::Relaxed);
}
for b in 0..batch {
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::copy_slice(ctx, x, b * hidden_size, x_single, 0, hidden_size);
B::copy_slice(ctx, zero_hidden, 0, acc_buf, 0, hidden_size);
if let Some(t) = t0 {
MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_COPY_CALLS.fetch_add(2, Ordering::Relaxed);
}
for k in 0..top_k {
let pair = b * top_k + k;
let expert_id = route_out.expert_ids[pair] as usize;
let weight = route_out.expert_weights[pair];
if expert_id >= num_experts {
return Err(FerrumError::model(format!(
"moe_forward: routed expert {expert_id} >= num_experts {num_experts}"
)));
}
let t0 = if prof {
B::sync(ctx);
Some(std::time::Instant::now())
} else {
None
};
experts.gate_up[expert_id].forward(ctx, x_single, gate_up_buf, 1);
if let Some(t) = t0 {
B::sync(ctx);
MOE_GEMV_GATE_UP_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_GEMV_GATE_UP_CALLS.fetch_add(1, Ordering::Relaxed);
}
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::fused_silu_mul_split(ctx, gate_up_buf, silu_buf, 1, expert_intermediate);
if let Some(t) = t0 {
B::sync(ctx);
MOE_SILU_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_SILU_CALLS.fetch_add(1, Ordering::Relaxed);
}
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
experts.down[expert_id].forward(ctx, silu_buf, down_buf, 1);
if let Some(t) = t0 {
B::sync(ctx);
MOE_GEMV_DOWN_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_GEMV_DOWN_CALLS.fetch_add(1, Ordering::Relaxed);
}
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::scaled_add_inplace(ctx, acc_buf, down_buf, weight, hidden_size);
if let Some(t) = t0 {
B::sync(ctx);
MOE_SCALED_ADD_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_SCALED_ADD_CALLS.fetch_add(1, Ordering::Relaxed);
}
}
let t0 = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::copy_slice(ctx, acc_buf, 0, out, b * hidden_size, hidden_size);
if let Some(t) = t0 {
MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
MOE_COPY_CALLS.fetch_add(1, Ordering::Relaxed);
}
}
Ok(())
}
/// Upper bound on the MoE block size: equals the largest value accepted for the
/// block-size overrides and the largest automatic candidate in this file.
/// NOTE(review): presumably what callers use to size worst-case scratch — confirm.
pub const MOE_BLOCK_SIZE_MAX: usize = 64;
/// Chooses the MoE block size using the process-wide runtime configuration.
///
/// Convenience wrapper around [`pick_moe_block_size_with_config`].
fn pick_moe_block_size(
    plan: Option<&MoeBucketPlan>,
    num_experts: usize,
    use_device_route: bool,
    total_pairs: usize,
) -> usize {
    let config = moe_dispatch_runtime_config();
    pick_moe_block_size_with_config(config, plan, num_experts, use_device_route, total_pairs)
}
/// Chooses the MoE block size for one dispatch.
///
/// Decision order:
/// 1. An explicit `moe_block_size` override always wins.
/// 2. On the device-routing path, the large-M override is used when
///    `total_pairs` reaches `moe_large_m_min_pairs`; otherwise 16.
/// 3. Without a plan, or when the plan holds no pairs, 16.
/// 4. Otherwise the first of 64, 32, 16 whose padded row total (each expert's
///    row count rounded up to a multiple of the block size) stays within 1.30x
///    the real row total; 16 if none qualifies.
fn pick_moe_block_size_with_config(
    config: &MoeDispatchRuntimeConfig,
    plan: Option<&MoeBucketPlan>,
    num_experts: usize,
    use_device_route: bool,
    total_pairs: usize,
) -> usize {
    const CANDIDATES: &[usize] = &[64, 32, 16];
    const PADDING_BUDGET: f64 = 1.30;
    const FALLBACK_BLOCK: usize = 16;

    if let Some(forced) = config.moe_block_size {
        return forced;
    }

    if use_device_route {
        return match config.moe_large_m_block_size {
            Some(large) if total_pairs >= config.moe_large_m_min_pairs => large,
            _ => FALLBACK_BLOCK,
        };
    }

    let plan = match plan {
        Some(p) => p,
        None => return FALLBACK_BLOCK,
    };

    // Rows routed to each expert, taken from consecutive offset differences.
    let rows_per_expert: Vec<usize> = (0..num_experts)
        .map(|e| plan.expert_offsets[e + 1] - plan.expert_offsets[e])
        .collect();
    let real_rows: usize = rows_per_expert.iter().sum();
    if real_rows == 0 {
        return FALLBACK_BLOCK;
    }

    let allowed_rows = (real_rows as f64) * PADDING_BUDGET;
    CANDIDATES
        .iter()
        .copied()
        .find(|&bs| {
            let padded_rows: usize = rows_per_expert
                .iter()
                .map(|&rows| rows.div_ceil(bs) * bs)
                .sum();
            (padded_rows as f64) <= allowed_rows
        })
        .unwrap_or(FALLBACK_BLOCK)
}
/// Host-side grouping of `(token, k)` routing pairs by expert.
pub struct MoeBucketPlan {
/// `num_experts + 1` prefix sums: expert `e` owns packed slots
/// `expert_offsets[e]..expert_offsets[e + 1]`.
pub expert_offsets: Vec<usize>,
/// For each packed slot, the index of the token occupying it.
pub packed_token_idx: Vec<u32>,
/// For each flat pair `b * top_k + k`, the packed slot assigned to it
/// (initialised to -1 before assignment).
pub pairs_by_token: Vec<i32>,
/// Copy of the router's pair weights, in flat pair order.
pub pair_weights: Vec<f32>,
// Per-expert write cursors reused across rebuilds to avoid reallocating.
cursors: Vec<usize>,
}
impl MoeBucketPlan {
    /// Creates a plan with every vector empty.
    pub fn empty() -> Self {
        let expert_offsets = Vec::new();
        let packed_token_idx = Vec::new();
        let pairs_by_token = Vec::new();
        let pair_weights = Vec::new();
        let cursors = Vec::new();
        Self {
            expert_offsets,
            packed_token_idx,
            pairs_by_token,
            pair_weights,
            cursors,
        }
    }

    /// Builds a fresh plan from a router output.
    pub fn build(route: &RouterOutput, batch: usize, num_experts: usize, top_k: usize) -> Self {
        let mut plan = Self::empty();
        plan.rebuild_into(route, batch, num_experts, top_k);
        plan
    }

    /// Recomputes the plan in place, reusing the existing allocations.
    ///
    /// Counts the pairs per expert, turns the counts into prefix-sum offsets,
    /// then walks the pairs in flat order and gives each one the next free
    /// slot of its expert. Debug builds assert that `route` holds exactly
    /// `batch * top_k` ids and weights.
    ///
    /// # Panics
    /// Indexing panics if an expert id is `>= num_experts` or if `route` holds
    /// fewer than `batch * top_k` ids.
    pub fn rebuild_into(
        &mut self,
        route: &RouterOutput,
        batch: usize,
        num_experts: usize,
        top_k: usize,
    ) {
        debug_assert_eq!(route.expert_ids.len(), batch * top_k);
        debug_assert_eq!(route.expert_weights.len(), batch * top_k);
        let total_pairs = batch * top_k;

        self.expert_offsets.clear();
        self.expert_offsets.resize(num_experts + 1, 0);
        self.packed_token_idx.clear();
        self.packed_token_idx.resize(total_pairs, 0);
        self.pairs_by_token.clear();
        self.pairs_by_token.resize(total_pairs, -1);

        // Histogram shifted by one so the prefix sum below yields start offsets.
        for &expert in route.expert_ids.iter() {
            let bucket = expert as usize + 1;
            self.expert_offsets[bucket] += 1;
        }

        // In-place inclusive prefix sum over entries 1..=num_experts; entry 0
        // is never incremented and stays 0.
        let mut running = 0usize;
        for entry in self.expert_offsets.iter_mut().skip(1) {
            running += *entry;
            *entry = running;
        }

        self.cursors.clear();
        self.cursors
            .extend(self.expert_offsets[..num_experts].iter().copied());

        // `total_pairs` is 0 whenever `top_k` is 0, so the division is safe.
        for pair_flat in 0..total_pairs {
            let token = pair_flat / top_k;
            let expert = route.expert_ids[pair_flat] as usize;
            let slot = self.cursors[expert];
            self.cursors[expert] = slot + 1;
            self.packed_token_idx[slot] = token as u32;
            self.pairs_by_token[pair_flat] = slot as i32;
        }

        self.pair_weights.clear();
        self.pair_weights
            .extend(route.expert_weights.iter().copied());
    }
}
/// Reusable host-side scratch for routing in `moe_forward_bucketed`.
pub struct MoeRouteScratch {
/// Router result (expert ids and weights) for the current call.
pub output: RouterOutput,
/// Scratch vector handed to `router::route_into`.
pub probs: Vec<f32>,
/// Bucket plan rebuilt from `output` on every host-routed call.
pub plan: MoeBucketPlan,
}
impl MoeRouteScratch {
    /// Creates scratch with an empty router output, no probabilities and an empty plan.
    pub fn new() -> Self {
        let output = RouterOutput::empty();
        let probs = Vec::new();
        let plan = MoeBucketPlan::empty();
        Self {
            output,
            probs,
            plan,
        }
    }
}
impl Default for MoeRouteScratch {
fn default() -> Self {
Self::new()
}
}
/// Device buffers used when routing and bucketing run through the backend
/// instead of on the host.
pub struct DeviceRouteScratch<'a, B: crate::moe::dispatch::Backend> {
/// Passed to `B::route_topk_softmax` alongside `pair_weights`, then read by
/// the pair-building and block-alignment calls.
pub selected_ids: &'a mut B::Buffer,
/// Passed to `B::route_topk_softmax` alongside `selected_ids`.
pub pair_weights: &'a mut B::Buffer,
/// Passed to `B::moe_build_pairs_by_token`.
pub pairs_by_token: &'a mut B::Buffer,
/// Passed to `B::moe_build_pairs_by_token`; later used as the index buffer
/// for the device-side gather of `x`.
pub packed_token_idx: &'a mut B::Buffer,
/// Passed to `B::moe_build_pairs_by_token`.
pub expert_offsets: &'a mut B::Buffer,
/// Passed to `B::moe_align_block_size[_pair_ids]`, then to the GEMM phase.
pub sorted_tokens: &'a mut B::Buffer,
/// Passed to `B::moe_align_block_size[_pair_ids]`, then to the GEMM phase.
pub block_ids: &'a mut B::Buffer,
/// Passed to `B::moe_align_block_size[_pair_ids]`, then to the GEMM phase.
pub total_post_pad: &'a mut B::Buffer,
}
/// Arguments for `moe_forward_bucketed`, the expert-bucketed MoE path.
pub struct MoeForwardBucketedParams<'a, B: QuantLlmBackend + BackendMoeFused> {
/// Backend execution context.
pub ctx: &'a mut B::Context,
/// Input activations for the batch.
pub x: &'a B::Buffer,
/// Router logits (`batch * num_experts` values are read on the host fallback).
pub router_logits: &'a B::Buffer,
/// Output activations.
pub out: &'a mut B::Buffer,
/// Number of tokens to process.
pub batch: usize,
/// Row width used when gathering token rows.
pub hidden_size: usize,
/// Per-expert intermediate width; the gate/up output is twice this.
pub expert_intermediate: usize,
/// Expected expert count; must equal `experts.num_experts()`.
pub num_experts: usize,
/// Number of experts selected per token.
pub top_k: usize,
/// Forwarded to the router unchanged.
pub norm_topk_prob: bool,
/// Expert weights; this path requires the Marlin gate/up stack.
pub experts: &'a ExpertStack<B>,
/// Destination of the gather of token rows in packed order.
pub x_packed: &'a mut B::Buffer,
/// Destination of the first (gate/up) GEMM phase.
pub gate_up_packed: &'a mut B::Buffer,
/// NOTE(review): presumably the packed activation output; its use is not
/// visible in this part of the file — confirm.
pub silu_packed: &'a mut B::Buffer,
/// NOTE(review): presumably the packed down-projection output; its use is not
/// visible in this part of the file — confirm.
pub down_packed: &'a mut B::Buffer,
/// Host scratch for the router output and the bucket plan.
pub route_scratch: &'a mut MoeRouteScratch,
/// Device routing buffers; device routing is used only when this is `Some`,
/// `FERRUM_VLLM_MOE` is "1" and `FERRUM_MOE_HOST_ROUTE` is not "1".
pub device_route: Option<DeviceRouteScratch<'a, B>>,
}
pub fn moe_forward_bucketed<B: QuantLlmBackend + BackendMoeFused>(
params: MoeForwardBucketedParams<'_, B>,
) -> Result<()> {
let MoeForwardBucketedParams {
ctx,
x,
router_logits,
out,
batch,
hidden_size,
expert_intermediate,
num_experts,
top_k,
norm_topk_prob,
experts,
x_packed,
gate_up_packed,
silu_packed,
down_packed,
route_scratch,
device_route,
} = params;
if experts.num_experts() != num_experts {
return Err(FerrumError::model(format!(
"moe_forward_bucketed: experts {} != num_experts {num_experts}",
experts.num_experts()
)));
}
let runtime_config = moe_dispatch_runtime_config();
let prof = runtime_config.moe_profile || runtime_config.decode_op_profile;
if prof {
MOE_BUCKET_LAYER_CALLS.fetch_add(1, Ordering::Relaxed);
}
let use_vllm_moe = runtime_config.vllm_moe;
let use_device_route = device_route.is_some() && use_vllm_moe && !runtime_config.moe_host_route;
let use_vllm_pair_ids = use_device_route && runtime_config.vllm_moe_pair_ids;
let mut dr_kept: Option<DeviceRouteScratch<'_, B>> = if use_device_route {
let dr = device_route.expect("device_route is Some when use_device_route");
B::route_topk_softmax(
ctx,
router_logits,
dr.selected_ids,
dr.pair_weights,
batch,
num_experts,
top_k,
norm_topk_prob,
)?;
if !use_vllm_pair_ids {
B::moe_build_pairs_by_token(
ctx,
dr.selected_ids,
dr.pairs_by_token,
dr.packed_token_idx,
dr.expert_offsets,
batch * top_k,
num_experts,
top_k,
)?;
}
Some(dr)
} else {
None
};
let plan: Option<&crate::moe::MoeBucketPlan> = if !use_device_route {
let t_route_total = if prof {
Some(std::time::Instant::now())
} else {
None
};
let gpu_route = B::try_gpu_route_topk_into_host(
ctx,
router_logits,
&mut route_scratch.output.expert_ids,
&mut route_scratch.output.expert_weights,
batch,
num_experts,
top_k,
norm_topk_prob,
);
if gpu_route.is_err() {
let t_sync = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::sync(ctx);
if let Some(t) = t_sync {
MOE_BUCKET_SYNC_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
}
let t_d2h = if prof {
Some(std::time::Instant::now())
} else {
None
};
let logits_host = B::to_vec(router_logits, batch * num_experts);
if let Some(t) = t_d2h {
MOE_BUCKET_D2H_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
}
let t_route = if prof {
Some(std::time::Instant::now())
} else {
None
};
crate::moe::router::route_into(
&logits_host,
batch,
num_experts,
top_k,
norm_topk_prob,
&mut route_scratch.output,
&mut route_scratch.probs,
);
if let Some(t) = t_route {
MOE_BUCKET_ROUTE_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
}
} else if let Some(t) = t_route_total {
MOE_BUCKET_ROUTE_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
}
let t_plan = if prof {
Some(std::time::Instant::now())
} else {
None
};
route_scratch
.plan
.rebuild_into(&route_scratch.output, batch, num_experts, top_k);
if let Some(t) = t_plan {
MOE_BUCKET_PLAN_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
}
Some(&route_scratch.plan)
} else {
None
};
if !use_vllm_pair_ids {
let t_gather = if prof {
Some(std::time::Instant::now())
} else {
None
};
if let Some(ref dr) = dr_kept {
B::embedding_lookup_dev(
ctx,
x,
dr.packed_token_idx,
x_packed,
batch * top_k,
hidden_size,
);
} else {
let plan = plan.expect("plan is Some when !use_device_route");
B::embedding_lookup(ctx, x, &plan.packed_token_idx, x_packed, hidden_size);
}
if let Some(t) = t_gather {
B::sync(ctx);
MOE_BUCKET_GATHER_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
}
}
let gate_up_dim_per_expert = 2 * expert_intermediate;
let down_n_per_expert = hidden_size;
let gu_store = experts.gate_up_stacked_store(0).ok_or_else(|| {
FerrumError::model(
"moe_forward_bucketed requires stacked gate_up store \
(load via Qwen3MoeModel::new_safetensors)",
)
})?;
let zero_marlin_workspace = !use_vllm_moe || runtime_config.vllm_moe_zero_ws;
if zero_marlin_workspace {
let _ = gu_store.zero_workspace(ctx);
}
let total_pairs_active = batch * top_k;
let max_block_size: usize = 64;
let moe_block_size: usize = pick_moe_block_size_with_config(
runtime_config,
plan,
num_experts,
use_device_route,
total_pairs_active,
);
debug_assert!(
moe_block_size <= max_block_size,
"moe_block_size {moe_block_size} exceeds scratch worst-case {max_block_size}"
);
let sorted_max_size = batch * top_k + num_experts * moe_block_size;
let vllm_routing_owned: Option<ferrum_kernels::backend::MoeRouting<B>> =
if use_vllm_moe && !use_device_route {
let plan = plan.expect("plan is Some when host vllm builder runs");
let mut padded_offsets = Vec::with_capacity(num_experts + 1);
let mut acc = 0usize;
for e in 0..num_experts {
padded_offsets.push(acc);
let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
let pe = m_e.div_ceil(moe_block_size) * moe_block_size;
acc += pe;
}
padded_offsets.push(acc);
let total_padded = acc;
let total_blocks = total_padded / moe_block_size;
let sentinel = total_pairs_active as i32;
let mut sorted_token_ids = vec![sentinel; total_padded];
let mut expert_ids = vec![0i32; total_blocks];
for e in 0..num_experts {
let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
if m_e == 0 {
continue;
}
let p_off = padded_offsets[e];
let real_off = plan.expert_offsets[e];
for i in 0..m_e {
sorted_token_ids[p_off + i] = (real_off + i) as i32;
}
let blocks_for_e = (padded_offsets[e + 1] - p_off) / moe_block_size;
let block_start = p_off / moe_block_size;
for b in 0..blocks_for_e {
expert_ids[block_start + b] = e as i32;
}
}
let num_tokens_past_padded = vec![total_padded as i32];
Some(B::upload_moe_routing(
ctx,
&sorted_token_ids,
&expert_ids,
&num_tokens_past_padded,
)?)
} else {
None
};
if use_device_route {
let dr = dr_kept
.as_mut()
.expect("dr_kept is Some when use_device_route");
if use_vllm_pair_ids {
B::moe_align_block_size_pair_ids(
ctx,
dr.selected_ids,
dr.sorted_tokens,
dr.block_ids,
dr.total_post_pad,
batch * top_k,
num_experts,
moe_block_size,
sorted_max_size,
)?;
} else {
B::moe_align_block_size(
ctx,
dr.selected_ids,
dr.sorted_tokens,
dr.block_ids,
dr.total_post_pad,
batch * top_k,
num_experts,
moe_block_size,
sorted_max_size,
)?;
}
}
let vllm_refs: Option<(&B::Buffer, &B::Buffer, &B::Buffer)> = if use_device_route {
let dr = dr_kept
.as_ref()
.expect("dr_kept is Some when use_device_route");
Some((&*dr.sorted_tokens, &*dr.block_ids, &*dr.total_post_pad))
} else if let Some(r) = vllm_routing_owned.as_ref() {
Some((
&r.sorted_token_ids,
&r.expert_ids,
&r.num_tokens_past_padded,
))
} else {
None
};
let phase1_dispatches: Vec<(usize, usize, usize, usize)> = if vllm_refs.is_none() {
let plan = plan.expect("plan is Some when batched GEMM path runs");
let mut v: Vec<(usize, usize, usize, usize)> = Vec::with_capacity(num_experts);
for e in 0..num_experts {
let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
if m_e == 0 {
continue;
}
let pair_off = plan.expert_offsets[e];
v.push((e, pair_off, pair_off, m_e));
}
v.sort_by(|a, b| b.3.cmp(&a.3).then_with(|| a.0.cmp(&b.0)));
v
} else {
Vec::new()
};
let t_gemm1 = if prof {
Some(std::time::Instant::now())
} else {
None
};
if let Some((sorted_tokens, block_ids, total_post_pad)) = vllm_refs {
if use_vllm_pair_ids {
gu_store.gemm_phase_vllm(
ctx,
x,
sorted_tokens,
block_ids,
total_post_pad,
gate_up_packed,
batch,
moe_block_size,
top_k,
)?;
} else {
gu_store.gemm_phase_vllm(
ctx,
x_packed,
sorted_tokens,
block_ids,
total_post_pad,
gate_up_packed,
total_pairs_active,
moe_block_size,
1, )?;
}
} else {
gu_store.gemm_phase_batched(
ctx,
x_packed,
&phase1_dispatches,
gate_up_packed,
hidden_size,
)?;
}
if let Some(t) = t_gemm1 {
B::sync(ctx);
MOE_BUCKET_GEMM1_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
}
let total_pairs_active = batch * top_k;
let t_silu = if prof {
Some(std::time::Instant::now())
} else {
None
};
B::fused_silu_mul_split(
ctx,
gate_up_packed,
silu_packed,
total_pairs_active,
expert_intermediate,
);
if let Some(t) = t_silu {
B::sync(ctx);
MOE_BUCKET_SILU_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
}
let d_store = experts.down_stacked_store(0).ok_or_else(|| {
FerrumError::model(
"moe_forward_bucketed requires stacked down store \
(load via Qwen3MoeModel::new_safetensors)",
)
})?;
if zero_marlin_workspace {
let _ = d_store.zero_workspace(ctx);
}
let phase3_dispatches: Vec<(usize, usize, usize, usize)> = if vllm_refs.is_none() {
let plan = plan.expect("plan is Some when batched GEMM path runs");
let mut v: Vec<(usize, usize, usize, usize)> = Vec::with_capacity(num_experts);
for e in 0..num_experts {
let m_e = plan.expert_offsets[e + 1] - plan.expert_offsets[e];
if m_e == 0 {
continue;
}
let pair_off = plan.expert_offsets[e];
v.push((e, pair_off, pair_off, m_e));
}
v.sort_by(|a, b| b.3.cmp(&a.3).then_with(|| a.0.cmp(&b.0)));
v
} else {
Vec::new()
};
let t_gemm3 = if prof {
Some(std::time::Instant::now())
} else {
None
};
if let Some((sorted_tokens, block_ids, total_post_pad)) = vllm_refs {
d_store.gemm_phase_vllm(
ctx,
silu_packed,
sorted_tokens,
block_ids,
total_post_pad,
down_packed,
total_pairs_active,
moe_block_size,
1,
)?;
} else {
d_store.gemm_phase_batched(
ctx,
silu_packed,
&phase3_dispatches,
down_packed,
expert_intermediate,
)?;
}
if let Some(t) = t_gemm3 {
B::sync(ctx);
MOE_BUCKET_GEMM3_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
}
let total_pairs = batch * top_k;
let t_comb = if prof {
Some(std::time::Instant::now())
} else {
None
};
if use_vllm_pair_ids {
let dr = dr_kept
.as_ref()
.expect("dr_kept is Some when use_vllm_pair_ids");
B::weighted_sum_batched(
ctx,
down_packed,
dr.pair_weights,
out,
batch,
top_k,
hidden_size,
)?;
} else {
let (pairs_ref, weights_ref);
let _pairs_owned;
let _weights_owned;
if let Some(ref dr) = dr_kept {
pairs_ref = &*dr.pairs_by_token;
weights_ref = &*dr.pair_weights;
} else {
let plan = plan.expect("plan is Some when host moe_combine runs");
_pairs_owned = B::from_slice_typed::<i32>(&plan.pairs_by_token);
_weights_owned = B::from_slice_typed::<f32>(&plan.pair_weights);
pairs_ref = &_pairs_owned;
weights_ref = &_weights_owned;
}
B::moe_combine(
ctx,
down_packed,
pairs_ref,
weights_ref,
out,
batch,
hidden_size,
top_k,
total_pairs,
);
}
if let Some(t) = t_comb {
B::sync(ctx);
MOE_BUCKET_COMBINE_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
}
Ok(())
}
/// CPU forward pass over the routed experts, one (token, expert) pair at a time.
///
/// For each token row `b` of `x` and each of its `top_k` routed experts, the row
/// is pushed through the expert's `gate_up` projection, `fused_silu_mul_split`,
/// and the expert's `down` projection; the result is accumulated into row `b`
/// of `out`, scaled by the router weight of that pair.
///
/// # Arguments
/// * `x` - flattened input, `batch * hidden_size` values, one row per token.
/// * `batch` - number of token rows in `x`.
/// * `hidden_size` - width of each input and output row.
/// * `expert_intermediate` - per-expert intermediate width (the gate_up scratch
///   buffer holds `2 * expert_intermediate` values).
/// * `top_k` - number of experts routed per token.
/// * `router` - routing decision; `expert_ids` and `expert_weights` are read at
///   index `b * top_k + k`.
/// * `experts` - expert weights, indexed by expert id.
/// * `out` - cleared, resized to `batch * hidden_size`, zero-filled, then
///   accumulated into.
///
/// # Errors
/// Returns a `FerrumError::model` when `x` does not hold `batch * hidden_size`
/// values, when `router.expert_ids` does not hold `batch * top_k` entries, when
/// `router.expert_weights` holds fewer than `batch * top_k` entries, or when a
/// routed expert id is not below `experts.num_experts()`. In the last case `out`
/// may already contain partial accumulations.
pub fn moe_forward_cpu(
    x: &[f32],
    batch: usize,
    hidden_size: usize,
    expert_intermediate: usize,
    top_k: usize,
    router: &RouterOutput,
    experts: &ExpertStack<CpuBackend>,
    out: &mut Vec<f32>,
) -> Result<()> {
    let n_experts = experts.num_experts();
    let total_elems = batch * hidden_size;
    let total_pairs = batch * top_k;
    if x.len() != total_elems {
        return Err(FerrumError::model(format!(
            "moe_forward_cpu: x len {} doesn't match batch*hidden = {}*{} = {}",
            x.len(),
            batch,
            hidden_size,
            total_elems
        )));
    }
    if router.expert_ids.len() != total_pairs {
        return Err(FerrumError::model(format!(
            "moe_forward_cpu: router has {} expert_ids but expected batch*top_k = {}*{} = {}",
            router.expert_ids.len(),
            batch,
            top_k,
            total_pairs
        )));
    }
    // The weights are indexed in lock-step with the ids below; reject a short
    // vector here instead of panicking on an out-of-bounds index mid-loop.
    if router.expert_weights.len() < total_pairs {
        return Err(FerrumError::model(format!(
            "moe_forward_cpu: router has {} expert_weights but needs at least batch*top_k = {}*{} = {}",
            router.expert_weights.len(),
            batch,
            top_k,
            total_pairs
        )));
    }

    out.clear();
    out.resize(total_elems, 0.0);

    let mut ctx = <CpuBackend as Backend>::new_context();
    // Scratch buffers reused for every (token, expert) pair.
    let mut x_b: Vec<f32> = vec![0.0; hidden_size];
    let mut gate_up_buf: Vec<f32> = vec![0.0; 2 * expert_intermediate];
    let mut silu_mul_buf: Vec<f32> = vec![0.0; expert_intermediate];
    let mut down_out: Vec<f32> = vec![0.0; hidden_size];

    for b in 0..batch {
        let row = b * hidden_size..(b + 1) * hidden_size;
        x_b.copy_from_slice(&x[row.clone()]);
        for k in 0..top_k {
            let pair_idx = b * top_k + k;
            let expert_id = router.expert_ids[pair_idx] as usize;
            if expert_id >= n_experts {
                return Err(FerrumError::model(format!(
                    "moe_forward_cpu: router selected expert {expert_id} >= num_experts {n_experts}"
                )));
            }
            let weight = router.expert_weights[pair_idx];

            experts.gate_up[expert_id].forward(&mut ctx, &x_b, &mut gate_up_buf, 1);
            <CpuBackend as Backend>::fused_silu_mul_split(
                &mut ctx,
                &gate_up_buf,
                &mut silu_mul_buf,
                1,
                expert_intermediate,
            );
            experts.down[expert_id].forward(&mut ctx, &silu_mul_buf, &mut down_out, 1);

            // Weighted accumulation of this expert's output into the token's row.
            let out_row = &mut out[row.clone()];
            for (o, d) in out_row.iter_mut().zip(down_out.iter()) {
                *o += weight * *d;
            }
        }
    }
    Ok(())
}
/// Verifies that the size labelled `label` is exactly `expected`.
///
/// Returns `Ok(())` when `actual == expected`; otherwise returns a
/// `FerrumError::model` whose message names the label and both values.
fn check_size(actual: usize, expected: usize, label: &str) -> Result<()> {
    if actual == expected {
        Ok(())
    } else {
        Err(FerrumError::model(format!(
            "ExpertStack: {label} size mismatch (got {actual}, expected {expected})"
        )))
    }
}
/// Looks up tensor `name` in `gguf` and maps its GGML dtype to a
/// [`GgufQuantType`].
///
/// Returns `Ok(Some(..))` for `Q4K` and `Q6K`, `Ok(None)` for every other
/// dtype, and a `FerrumError::model` when the file has no tensor info for
/// `name`.
fn quant_kind(gguf: &GgufFile, name: &str) -> Result<Option<GgufQuantType>> {
    let Some(info) = gguf.tensor_info(name) else {
        return Err(FerrumError::model(format!(
            "ExpertStack: tensor info missing for '{name}'"
        )));
    };
    let kind = match info.ggml_dtype {
        GgmlDType::Q4K => Some(GgufQuantType::Q4K),
        GgmlDType::Q6K => Some(GgufQuantType::Q6K),
        _ => None,
    };
    Ok(kind)
}
/// Computes the byte size of `n_elems` k-quantized elements of flavour `kind`.
///
/// Elements are grouped in blocks of 256; a `Q4K` block is counted as 144
/// bytes and a `Q6K` block as 210 bytes.
///
/// # Errors
/// Returns a `FerrumError::model` (mentioning `label`) when `n_elems` is not a
/// multiple of 256, or — checked second — when `kind` is neither `Q4K` nor
/// `Q6K`.
fn block_bytes_for(kind: GgufQuantType, n_elems: usize, label: &str) -> Result<usize> {
    // Number of elements covered by one k-quant block.
    const QK_K: usize = 256;

    let (whole_blocks, leftover) = (n_elems / QK_K, n_elems % QK_K);
    if leftover != 0 {
        return Err(FerrumError::model(format!(
            "ExpertStack {label}: per-expert element count {n_elems} not a multiple of {QK_K}"
        )));
    }

    let bytes_per_block = match kind {
        GgufQuantType::Q4K => 144,
        GgufQuantType::Q6K => 210,
        other => {
            return Err(FerrumError::model(format!(
                "ExpertStack {label}: unsupported k-quant flavour {other:?}"
            )));
        }
    };
    Ok(whole_blocks * bytes_per_block)
}
/// Reads tensor `name` from `gguf` on `device`, dequantizes it, flattens it to
/// one dimension, and returns its values as a `Vec<f32>`.
///
/// Any candle error raised along the way is converted with
/// [`candle_to_ferrum`].
fn read_dequant_flat(gguf: &GgufFile, name: &str, device: &Device) -> Result<Vec<f32>> {
    gguf.read_tensor(name, device)
        .and_then(|quantized| quantized.dequantize(device))
        .and_then(|dense| dense.flatten_all())
        .and_then(|flat| flat.to_vec1::<f32>())
        .map_err(candle_to_ferrum)
}
/// Wraps a candle error in a `FerrumError::model`, prefixing its display text
/// with `"candle: "`.
fn candle_to_ferrum(e: candle_core::Error) -> FerrumError {
    let message = format!("candle: {e}");
    FerrumError::model(message)
}
// NOTE(review): this alias appears to be unused (hence the `dead_code` allow);
// presumably it exists only so the `CandleResult` import at the top of the file
// stays referenced — confirm before removing either one.
#[allow(dead_code)]
type _CandleResult<T> = CandleResult<T>;
#[cfg(test)]
mod tests {
    use super::{pick_moe_block_size_with_config, MoeDispatchRuntimeConfig};

    /// Supplies every startup knob at once and checks each parsed field.
    /// Note that `FERRUM_MOE_PROFILE=0` and an empty `FERRUM_MOE_LOAD_TRACE`
    /// are both expected to leave their flags enabled.
    #[test]
    fn moe_dispatch_runtime_config_parses_m3_startup_knobs() {
        let env = [
            ("FERRUM_MOE_PROFILE", "0"),
            ("FERRUM_DECODE_OP_PROFILE", "true"),
            ("FERRUM_VLLM_MOE_ZERO_WS", "1"),
            ("FERRUM_VLLM_MOE_PAIR_IDS", "1"),
            ("FERRUM_MOE_LOAD_TRACE", ""),
            ("FERRUM_MOE_BLOCK_SIZE", "8"),
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "64"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "2048"),
            ("FERRUM_VLLM_MOE", "1"),
            ("FERRUM_MOE_HOST_ROUTE", "1"),
        ];
        let cfg = MoeDispatchRuntimeConfig::from_env_vars(env);

        // Boolean switches.
        assert!(cfg.moe_profile);
        assert!(cfg.decode_op_profile);
        assert!(cfg.vllm_moe_zero_ws);
        assert!(cfg.vllm_moe_pair_ids);
        assert!(cfg.moe_load_trace);

        // Numeric knobs.
        assert_eq!(cfg.moe_block_size, Some(8));
        assert_eq!(cfg.moe_large_m_block_size, Some(64));
        assert_eq!(cfg.moe_large_m_min_pairs, 2048);

        // Routing / kernel selection switches.
        assert!(cfg.vllm_moe);
        assert!(cfg.moe_host_route);
    }

    /// Block sizes of 12 and 128 are expected to be rejected (`None`), an
    /// unparsable min-pairs value is expected to yield 1024, and the values
    /// `"true"` / `"0"` are expected to leave the zero-workspace and
    /// host-route switches off.
    #[test]
    fn moe_dispatch_runtime_config_bounds_invalid_block_values() {
        let env = [
            ("FERRUM_MOE_BLOCK_SIZE", "12"),
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "128"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "bad"),
            ("FERRUM_VLLM_MOE_ZERO_WS", "true"),
            ("FERRUM_MOE_HOST_ROUTE", "0"),
        ];
        let cfg = MoeDispatchRuntimeConfig::from_env_vars(env);

        assert_eq!(cfg.moe_block_size, None);
        assert_eq!(cfg.moe_large_m_block_size, None);
        assert_eq!(cfg.moe_large_m_min_pairs, 1024);
        assert!(!cfg.vllm_moe_zero_ws);
        assert!(!cfg.moe_host_route);
    }

    /// With a large-M block size of 64 and a 1024-pair threshold, 256 pairs
    /// are expected to pick block size 16 and 1024 pairs to pick 64.
    #[test]
    fn device_route_large_m_block_size_is_thresholded() {
        let env = [
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "64"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "1024"),
        ];
        let cfg = MoeDispatchRuntimeConfig::from_env_vars(env);

        let below_threshold = pick_moe_block_size_with_config(&cfg, None, 128, true, 256);
        assert_eq!(below_threshold, 16);

        let at_threshold = pick_moe_block_size_with_config(&cfg, None, 128, true, 1024);
        assert_eq!(at_threshold, 64);
    }

    /// An explicit `FERRUM_MOE_BLOCK_SIZE` is expected to take precedence over
    /// the large-M block size even when the pair count exceeds the threshold.
    #[test]
    fn global_moe_block_size_override_still_wins() {
        let env = [
            ("FERRUM_MOE_BLOCK_SIZE", "32"),
            ("FERRUM_MOE_LARGE_M_BLOCK_SIZE", "64"),
            ("FERRUM_MOE_LARGE_M_MIN_PAIRS", "1024"),
        ];
        let cfg = MoeDispatchRuntimeConfig::from_env_vars(env);

        let picked = pick_moe_block_size_with_config(&cfg, None, 128, true, 2048);
        assert_eq!(picked, 32);
    }
}