use std::collections::HashMap;
use std::sync::atomic::AtomicU64;
use std::sync::OnceLock;
use ferrum_kernels::backend::{Backend, KvCache};
use ferrum_quantization::WeightLoader;
use ferrum_types::{FerrumError, Result};
use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
use crate::models::llama_family::{LlamaFamilyConfig, LlamaFamilyLayer, RopeCache};
use crate::moe::{moe_forward, ExpertStack};
use crate::moe_config::Qwen3MoeConfig;
// ---- Per-stage profiling accumulators -------------------------------------------------
// Every counter below is a plain relaxed `AtomicU64`: either microseconds or call counts.
// The profiling reports in `prefill_internal`, `decode_internal` and
// `decode_batch_internal` read and reset them.

// Coarse attention / MoE totals shared by prefill and single-token decode.
static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_TIME_US: AtomicU64 = AtomicU64::new(0);
static MOE_CALLS: AtomicU64 = AtomicU64::new(0);

// Single-token decode breakdown (MoE sub-stages plus the embed / final-norm / LM-head tail).
static DEC_ROUTE_US: AtomicU64 = AtomicU64::new(0);
static DEC_GATE_US: AtomicU64 = AtomicU64::new(0);
static DEC_UP_US: AtomicU64 = AtomicU64::new(0);
static DEC_SILU_US: AtomicU64 = AtomicU64::new(0);
static DEC_DOWN_US: AtomicU64 = AtomicU64::new(0);
static DEC_WSUM_US: AtomicU64 = AtomicU64::new(0);
static DEC_EMBED_US: AtomicU64 = AtomicU64::new(0);
static DEC_FINAL_NORM_US: AtomicU64 = AtomicU64::new(0);
static DEC_LM_HEAD_US: AtomicU64 = AtomicU64::new(0);

// Batched-prefill MoE breakdown: each stage has a time and a call counter.
static MOE_PREFILL_HOST_TOPK_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_HOST_TOPK_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_GATE_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_GATE_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_UP_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_UP_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_SILU_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_SILU_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_DOWN_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_DOWN_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_WSUM_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_WSUM_CALLS: AtomicU64 = AtomicU64::new(0);

// Batched-decode MoE sub-stage timers.
static MOE_BATCHED_DECODE_ROUTE_US: AtomicU64 = AtomicU64::new(0);
static MOE_BATCHED_DECODE_GATE_US: AtomicU64 = AtomicU64::new(0);
static MOE_BATCHED_DECODE_UP_US: AtomicU64 = AtomicU64::new(0);
static MOE_BATCHED_DECODE_SILU_US: AtomicU64 = AtomicU64::new(0);
static MOE_BATCHED_DECODE_DOWN_US: AtomicU64 = AtomicU64::new(0);
static MOE_BATCHED_DECODE_WSUM_US: AtomicU64 = AtomicU64::new(0);

// Batched-decode per-layer split (dense work, per-item attention, MoE) and layer count.
static BD_DENSE_US: AtomicU64 = AtomicU64::new(0);
static BD_ATTN_PERITEM_US: AtomicU64 = AtomicU64::new(0);
static BD_MOE_US: AtomicU64 = AtomicU64::new(0);
static BD_LAYER_CALLS: AtomicU64 = AtomicU64::new(0);
/// Mixture-of-experts weights for one decoder layer.
///
/// `Qwen3MoeModel::new` checks that the router maps `hidden_size` inputs to
/// `num_experts` outputs.
pub struct Qwen3MoeLayerState<B: Backend> {
/// Router projection producing one logit per expert for each token.
pub router: Box<dyn ferrum_quantization::Linear<B>>,
/// Expert weights for this layer, loaded from the GGUF file.
pub experts: ExpertStack<B>,
}
/// Pre-allocated device buffers reused across forward passes.
///
/// Token-dimensioned buffers are sized for `max_tokens` rows (see `Qwen3MoeScratch::alloc`);
/// `Qwen3MoeModel::ensure_scratch` reallocates when a larger batch arrives.
pub struct Qwen3MoeScratch<B: Backend> {
/// Residual stream (max_tokens × hidden). Held in an `Option` so a forward pass can
/// `take()` it while also borrowing other scratch fields, then restore it at the end.
pub residual: Option<B::Buffer>,
// ---- attention stage ----
/// RMSNorm output feeding the QKV projection and (after attention) the router.
pub norm_out: B::Buffer,
/// Fused QKV projection output (max_tokens × (q_dim + 2·kv_dim)).
pub qkv_out: B::Buffer,
// Unfused fallback path: token-major Q/K/V split out of `qkv_out`.
pub q_buf: B::Buffer,
pub k_buf: B::Buffer,
pub v_buf: B::Buffer,
// Head-major Q/K/V after norm + RoPE (heads × max_tokens × head_dim).
pub q_head_major: B::Buffer,
pub k_head_major: B::Buffer,
pub v_head_major: B::Buffer,
/// Attention output in head-major layout.
pub attn_head_major_out: B::Buffer,
/// Attention output transposed back to token-major (used when tokens > 1).
pub attn_flat: B::Buffer,
/// Output projection result, folded into the residual by `fused_add_rms_norm`.
pub o_proj_out: B::Buffer,
// ---- MoE stage ----
/// Router logits (max_tokens × num_experts).
pub router_logits: B::Buffer,
// Single-row working buffers for the generic `moe_forward` path.
pub gate_up_buf: B::Buffer,
pub silu_buf: B::Buffer,
pub down_buf: B::Buffer,
pub x_single: B::Buffer,
pub acc_buf: B::Buffer,
/// MoE output added into the residual when the decode fast path is not taken.
pub moe_out: B::Buffer,
/// `hidden_size` zeros handed to `moe_forward`.
pub zero_hidden: B::Buffer,
// Stacked-expert buffers (max_tokens × num_experts_per_tok × intermediate / hidden).
pub gate_out_stacked: B::Buffer,
pub up_out_stacked: B::Buffer,
pub silu_stacked: B::Buffer,
pub down_out_stacked: B::Buffer,
// Routing index/weight/argument buffers (i32 or f32, zero-initialised). Exact per-kernel
// semantics live in the MoE impl functions — NOTE(review): document there.
pub ids_buf: B::Buffer,
pub weights_buf: B::Buffer,
pub selected_ids_buf: B::Buffer,
pub gate_up_args_buf: B::Buffer,
pub down_args_buf: B::Buffer,
pub ids_2d: B::Buffer,
pub tpe_buf: B::Buffer,
pub weights_2d: B::Buffer,
// ---- LM head ----
/// Last token's hidden state (prefill) before the final norm.
pub last_hidden: B::Buffer,
/// Final-normed hidden state for a single token.
pub last_normed: B::Buffer,
/// Single-token logits (vocab).
pub logits: B::Buffer,
/// Batched logits (max_tokens × vocab). NOTE(review): this is also allocated at full size
/// after long prefills, where it dominates scratch memory.
pub batch_logits: B::Buffer,
// ---- batched-decode per-item buffers (allocated by `enable_batched_decode_scratch`) ----
pub q_single: Option<B::Buffer>,
pub k_single: Option<B::Buffer>,
pub v_single: Option<B::Buffer>,
pub q_head_major_single: Option<B::Buffer>,
pub k_head_major_single: Option<B::Buffer>,
pub v_head_major_single: Option<B::Buffer>,
pub attn_head_major_single: Option<B::Buffer>,
// ---- paged batched attention buffers (allocated by `enable_paged_batch`) ----
pub paged_batch_q: Option<B::Buffer>,
pub paged_batch_o: Option<B::Buffer>,
pub paged_batch_block_tables: Option<B::Buffer>,
pub paged_batch_context_lens: Option<B::Buffer>,
/// Row stride of `paged_batch_block_tables`; 0 until `enable_paged_batch` runs.
pub paged_max_blocks_per_seq: usize,
/// Number of token rows the token-dimensioned buffers can hold.
pub max_tokens: usize,
}
impl<B: Backend> Qwen3MoeScratch<B> {
    /// Allocates every scratch buffer sized for up to `max_tokens` rows.
    ///
    /// The batched-decode (`*_single`) and paged-batch buffers start out as `None`; they
    /// are created on demand by `enable_batched_decode_scratch` / `enable_paged_batch`.
    fn alloc(cfg: &Qwen3MoeConfig, max_tokens: usize) -> Self {
        let base = &cfg.base;
        let hidden = base.hidden_size;
        let head_dim = base.head_dim;
        let n_heads = base.num_heads;
        let n_kv_heads = base.num_kv_heads;
        let q_width = n_heads * head_dim;
        let kv_width = n_kv_heads * head_dim;
        let rows = max_tokens;
        let inter = cfg.expert_intermediate_size;
        let experts = cfg.num_experts;
        let top_k = cfg.num_experts_per_tok;
        let vocab = base.vocab_size;
        // One slot per (token, selected expert) pair.
        let routed = rows * top_k;

        Self {
            residual: Some(B::alloc(rows * hidden)),
            norm_out: B::alloc(rows * hidden),
            qkv_out: B::alloc(rows * (q_width + 2 * kv_width)),
            q_buf: B::alloc(rows * q_width),
            k_buf: B::alloc(rows * kv_width),
            v_buf: B::alloc(rows * kv_width),
            q_head_major: B::alloc(n_heads * rows * head_dim),
            k_head_major: B::alloc(n_kv_heads * rows * head_dim),
            v_head_major: B::alloc(n_kv_heads * rows * head_dim),
            attn_head_major_out: B::alloc(n_heads * rows * head_dim),
            attn_flat: B::alloc(rows * q_width),
            o_proj_out: B::alloc(rows * hidden),
            router_logits: B::alloc(rows * experts),
            gate_up_buf: B::alloc(2 * inter),
            silu_buf: B::alloc(inter),
            down_buf: B::alloc(hidden),
            x_single: B::alloc(hidden),
            acc_buf: B::alloc(hidden),
            moe_out: B::alloc(rows * hidden),
            zero_hidden: B::from_slice(&vec![0.0f32; hidden]),
            gate_out_stacked: B::alloc(routed * inter),
            up_out_stacked: B::alloc(routed * inter),
            silu_stacked: B::alloc(routed * inter),
            down_out_stacked: B::alloc(routed * hidden),
            ids_buf: B::from_slice_i32(&vec![0i32; top_k]),
            weights_buf: B::from_slice(&vec![0.0f32; top_k]),
            selected_ids_buf: B::from_slice_i32(&vec![0i32; routed]),
            gate_up_args_buf: B::from_slice_i32(&[0i32, 0, 0]),
            down_args_buf: B::from_slice_i32(&[0i32, 0, 0]),
            ids_2d: B::from_slice_i32(&vec![0i32; experts * routed]),
            tpe_buf: B::from_slice_i32(&vec![0i32; experts]),
            weights_2d: B::from_slice(&vec![0.0f32; routed]),
            last_hidden: B::alloc(hidden),
            last_normed: B::alloc(hidden),
            logits: B::alloc(vocab),
            // NOTE(review): rows × vocab floats — after a long prefill this is the largest
            // scratch buffer even though only batched decode writes it.
            batch_logits: B::alloc(rows * vocab),
            q_single: None,
            k_single: None,
            v_single: None,
            q_head_major_single: None,
            k_head_major_single: None,
            v_head_major_single: None,
            attn_head_major_single: None,
            paged_batch_q: None,
            paged_batch_o: None,
            paged_batch_block_tables: None,
            paged_batch_context_lens: None,
            paged_max_blocks_per_seq: 0,
            max_tokens: rows,
        }
    }

    /// Allocates the paged batched-attention buffers for `max_seqs` sequences.
    ///
    /// Idempotent: once `paged_batch_q` exists, later calls keep the first sizing.
    fn enable_paged_batch(
        &mut self,
        cfg: &Qwen3MoeConfig,
        max_seqs: usize,
        max_blocks_per_seq: usize,
    ) {
        if self.paged_batch_q.is_none() {
            let per_seq_q = cfg.base.num_heads * cfg.base.head_dim;
            let q_slots = max_seqs * per_seq_q;
            self.paged_batch_q = Some(B::alloc(q_slots));
            self.paged_batch_o = Some(B::alloc(q_slots));
            self.paged_batch_block_tables = Some(B::alloc_u32(max_seqs * max_blocks_per_seq));
            self.paged_batch_context_lens = Some(B::alloc_u32(max_seqs));
            self.paged_max_blocks_per_seq = max_blocks_per_seq;
        }
    }

    /// Allocates the single-row Q/K/V/attention buffers used by batched decode.
    ///
    /// Idempotent: skipped once `q_single` has been allocated.
    fn enable_batched_decode_scratch(&mut self, cfg: &Qwen3MoeConfig) {
        if self.q_single.is_none() {
            let head_dim = cfg.base.head_dim;
            let q_width = cfg.base.num_heads * head_dim;
            let kv_width = cfg.base.num_kv_heads * head_dim;
            self.q_single = Some(B::alloc(q_width));
            self.k_single = Some(B::alloc(kv_width));
            self.v_single = Some(B::alloc(kv_width));
            self.q_head_major_single = Some(B::alloc(q_width));
            self.k_head_major_single = Some(B::alloc(kv_width));
            self.v_head_major_single = Some(B::alloc(kv_width));
            self.attn_head_major_single = Some(B::alloc(q_width));
        }
    }
}
/// Qwen3-MoE decoder: Llama-family attention layers paired with per-layer MoE FFNs,
/// plus the scratch arena and per-request KV caches used by prefill / decode.
pub struct Qwen3MoeModel<B: Backend> {
/// Model hyper-parameters (MoE fields plus the shared `base` decoder config).
pub cfg: Qwen3MoeConfig,
/// Runtime configuration derived from `cfg.base` in `new`.
pub runtime_cfg: LlmRuntimeConfig,
/// Token embedding table (`model.embed_tokens.weight`).
pub embed: B::Buffer,
/// Per-layer attention weights. Their dense `gate_up_proj` / `down_proj` slots are
/// filled with stubs in `new`; the FFN is served by `moe_layers`.
pub attn_layers: Vec<LlamaFamilyLayer<B>>,
/// Per-layer router + expert weights.
pub moe_layers: Vec<Qwen3MoeLayerState<B>>,
/// Final RMSNorm weight (`model.norm.weight`).
pub final_norm_w: B::Buffer,
/// Output projection; falls back to the tied embedding when `lm_head.weight` is absent.
pub lm_head: Box<dyn ferrum_quantization::Linear<B>>,
/// Precomputed RoPE cos/sin tables.
pub rope: RopeCache<B>,
/// Reusable device scratch buffers.
pub scratch: Qwen3MoeScratch<B>,
/// Per-request KV caches keyed by cache id; one `KvCache` per layer.
pub kv_caches: HashMap<String, Vec<KvCache<B>>>,
/// Recycled per-layer cache sets that `ensure_kv` reuses before allocating new ones.
kv_free_pool: Vec<Vec<KvCache<B>>>,
/// Per-layer shared (K pool, V pool) buffers when paged KV is enabled.
pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
/// Block allocator for the paged pools, guarded by a mutex.
pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
}
impl<B: Backend> Qwen3MoeModel<B> {
pub fn new(
    cfg: Qwen3MoeConfig,
    loader: &dyn WeightLoader<B>,
    gguf: &ferrum_quantization::gguf::GgufFile,
) -> Result<Self> {
    // Builds the model: loads attention weights through `loader`, expert stacks from
    // `gguf`, and allocates scratch for a single token.
    //
    // Errors: any missing tensor or linear, a router whose shape is not
    // hidden_size -> num_experts, and (when `cfg.base.has_qk_norm` is set) missing
    // per-layer q_norm / k_norm weights.

    // Reset backend graph state before allocating model buffers (ensure_scratch does the
    // same before reallocating scratch).
    {
        let mut ctx = B::new_context();
        B::reset_graph(&mut ctx);
    }
    let rope = build_rope_cache::<B>(&cfg.base);
    let scratch = Qwen3MoeScratch::alloc(&cfg, 1);
    let embed = loader.load_tensor("model.embed_tokens.weight")?;
    let num_layers = cfg.base.num_layers;
    let mut attn_layers = Vec::with_capacity(num_layers);
    let mut moe_layers = Vec::with_capacity(num_layers);
    for li in 0..num_layers {
        let prefix = format!("model.layers.{li}");
        let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
        let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
        let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
        let post_ln_w =
            loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
        // The dense FFN slots of LlamaFamilyLayer are unused here (the MoE replaces
        // them), so they get shape-correct stubs.
        let gate_up_proj: Box<dyn ferrum_quantization::Linear<B>> =
            stub_linear::<B>(2 * cfg.expert_intermediate_size, cfg.base.hidden_size);
        let down_proj: Box<dyn ferrum_quantization::Linear<B>> =
            stub_linear::<B>(cfg.base.hidden_size, cfg.expert_intermediate_size);
        let (q_norm_w, k_norm_w) = if cfg.base.has_qk_norm {
            // The config promises per-head Q/K RMSNorm. Swallowing a load failure here
            // would make forward_layer use `input_layernorm.weight` as the Q/K norm weight
            // (its `unwrap_or(dummy)` fallback) and silently produce wrong activations,
            // so the error is surfaced instead.
            let q = loader.load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))?;
            let k = loader.load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))?;
            (Some(q), Some(k))
        } else {
            (None, None)
        };
        attn_layers.push(LlamaFamilyLayer {
            input_ln_w,
            qkv_proj,
            q_norm_w,
            k_norm_w,
            o_proj,
            post_ln_w,
            gate_up_proj,
            down_proj,
        });

        let router = loader.load_linear(&format!("{prefix}.mlp.router"))?;
        if router.in_features() != cfg.base.hidden_size {
            return Err(FerrumError::model(format!(
                "router layer {li}: in_features {} != hidden {}",
                router.in_features(),
                cfg.base.hidden_size
            )));
        }
        if router.out_features() != cfg.num_experts {
            return Err(FerrumError::model(format!(
                "router layer {li}: out_features {} != num_experts {}",
                router.out_features(),
                cfg.num_experts
            )));
        }
        let experts = ExpertStack::<B>::load_from_gguf(
            gguf,
            li,
            cfg.num_experts,
            cfg.base.hidden_size,
            cfg.expert_intermediate_size,
        )?;
        moe_layers.push(Qwen3MoeLayerState { router, experts });
    }
    let final_norm_w = loader.load_tensor("model.norm.weight")?;
    let lm_head = if loader.has_tensor("lm_head.weight") {
        loader.load_linear("lm_head")?
    } else {
        tracing::info!(
            "Qwen3MoeModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
        );
        loader.load_linear("model.embed_tokens")?
    };
    let runtime_cfg = cfg.base.to_runtime();
    Ok(Self {
        cfg,
        runtime_cfg,
        embed,
        attn_layers,
        moe_layers,
        final_norm_w,
        lm_head,
        rope,
        scratch,
        kv_caches: HashMap::new(),
        kv_free_pool: Vec::new(),
        paged_pools: None,
        paged_block_alloc: None,
    })
}
pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
    // Grows the scratch arena so it can hold at least `tokens` rows; never shrinks.
    //
    // The batched-decode `*_single` buffers and the `paged_batch_*` buffers are sized
    // by head counts / max_seqs, not by `max_tokens`. They are moved into the new arena
    // rather than dropped. Dropping them silently lost state that `ensure_kv` had
    // already enabled: `decode_batch_internal` calls `ensure_kv` before
    // `ensure_scratch(m)`, and `ensure_kv` skips `enable_paged_batch` for caches that
    // already exist.
    if self.scratch.max_tokens >= tokens {
        return;
    }
    // Reset backend graph state before the old scratch buffers are replaced.
    {
        let mut ctx = B::new_context();
        B::reset_graph(&mut ctx);
    }
    let mut grown = Qwen3MoeScratch::alloc(&self.cfg, tokens);
    let old = &mut self.scratch;
    grown.q_single = old.q_single.take();
    grown.k_single = old.k_single.take();
    grown.v_single = old.v_single.take();
    grown.q_head_major_single = old.q_head_major_single.take();
    grown.k_head_major_single = old.k_head_major_single.take();
    grown.v_head_major_single = old.v_head_major_single.take();
    grown.attn_head_major_single = old.attn_head_major_single.take();
    grown.paged_batch_q = old.paged_batch_q.take();
    grown.paged_batch_o = old.paged_batch_o.take();
    grown.paged_batch_block_tables = old.paged_batch_block_tables.take();
    grown.paged_batch_context_lens = old.paged_batch_context_lens.take();
    grown.paged_max_blocks_per_seq = old.paged_max_blocks_per_seq;
    self.scratch = grown;
}
pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
    // Registers a per-layer KV cache set for `cache_id` if none exists yet.
    //
    // Environment knobs:
    // - FERRUM_KV_CAPACITY: per-sequence token capacity, clamped to max_seq_len
    //   (default min(max_seq_len, 4096)).
    // - FERRUM_METAL_PAGED_KV=1: use shared per-layer block pools instead of
    //   contiguous per-cache K/V buffers.
    // - FERRUM_PAGED_MAX_SEQS: number of sequences the paged pool is sized for
    //   (default 16).
    // Zero or unparseable numeric values fall back to the defaults.
    //
    // Cache sets are reused from `kv_free_pool` when available. If the paged pool is
    // exhausted, an error is logged and the cache is NOT registered, so a following
    // forward_layer for this id panics on the missing cache.
    if self.kv_caches.contains_key(cache_id) {
        return;
    }
    const DEFAULT_KV_CAPACITY: usize = 4096;
    const DEFAULT_PAGED_MAX_SEQS: usize = 16;
    const PAGED_BLOCK_SIZE: usize = 16;

    let nkv = self.cfg.base.num_kv_heads;
    let hd = self.cfg.base.head_dim;
    let num_layers = self.cfg.base.num_layers;
    let model_max = self.cfg.base.max_seq_len;

    // Strictly positive integer from the environment. A zero capacity or seat count
    // would make every request fail, so it is treated like "unset".
    let env_positive = |name: &str| -> Option<usize> {
        std::env::var(name)
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .filter(|&n| n > 0)
    };
    let max = env_positive("FERRUM_KV_CAPACITY")
        .map(|cap| cap.min(model_max))
        .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY));
    let paged = std::env::var("FERRUM_METAL_PAGED_KV")
        .map(|v| v == "1")
        .unwrap_or(false);
    let max_seqs = env_positive("FERRUM_PAGED_MAX_SEQS").unwrap_or(DEFAULT_PAGED_MAX_SEQS);
    let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
    let total_pool_blocks = max_seqs * max_blocks_per_seq;

    // Shared per-layer K/V pools are created once, on the first paged cache.
    if paged && self.paged_pools.is_none() {
        let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
        let pools = (0..num_layers)
            .map(|_| (B::alloc(pool_floats), B::alloc(pool_floats)))
            .collect::<Vec<_>>();
        self.paged_pools = Some(pools);
        self.paged_block_alloc = Some(std::sync::Mutex::new(
            crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
        ));
    }
    if paged {
        self.scratch
            .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
    }

    let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
        (0..num_layers)
            .map(|_| {
                if paged {
                    // K/V live in the shared pools; the per-cache k/v are 1-element
                    // placeholders. The block table and context length are filled in below.
                    KvCache {
                        k: B::alloc(1),
                        v: B::alloc(1),
                        len: 0,
                        capacity: max_blocks_per_seq * PAGED_BLOCK_SIZE,
                        num_kv_heads: nkv,
                        head_dim: hd,
                        block_size: PAGED_BLOCK_SIZE,
                        block_table: Some(B::alloc_u32(max_blocks_per_seq)),
                        context_lens: Some(B::alloc_u32(1)),
                        paged_block_indices: Vec::new(),
                    }
                } else {
                    KvCache {
                        k: B::alloc(nkv * max * hd),
                        v: B::alloc(nkv * max * hd),
                        len: 0,
                        capacity: max,
                        num_kv_heads: nkv,
                        head_dim: hd,
                        block_size: 0,
                        block_table: None,
                        context_lens: None,
                        paged_block_indices: Vec::new(),
                    }
                }
            })
            .collect()
    });

    // All device writes below share one context and one sync.
    let mut ctx = B::new_context();
    if paged {
        let allocation = {
            let allocator = self
                .paged_block_alloc
                .as_ref()
                .expect("paged_block_alloc must be initialised when paged=true");
            // A poisoned lock only means another thread panicked mid-update; the
            // allocator is still the best state available.
            let mut guard = allocator.lock().unwrap_or_else(|p| p.into_inner());
            guard.allocate_n(max_blocks_per_seq)
        };
        let block_indices = match allocation {
            Ok(idx) => idx,
            Err(e) => {
                self.kv_free_pool.push(caches);
                eprintln!(
                    "[ferrum] paged KV pool exhausted on ensure_kv for \
                     cache_id={cache_id:?}: {e}. Increase \
                     FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
                     throttle concurrent requests.",
                );
                return;
            }
        };
        let mut padded = block_indices.clone();
        padded.resize(max_blocks_per_seq, 0);
        for c in caches.iter_mut() {
            if let Some(bt) = c.block_table.as_mut() {
                B::write_u32(&mut ctx, bt, &padded);
            }
            c.paged_block_indices = block_indices.clone();
        }
    }
    // Fresh and recycled caches alike start empty.
    for c in caches.iter_mut() {
        c.len = 0;
        if let Some(cl) = c.context_lens.as_mut() {
            B::write_u32(&mut ctx, cl, &[0u32]);
        }
    }
    B::sync(&mut ctx);
    self.kv_caches.insert(cache_id.to_string(), caches);
}
pub(crate) fn forward_layer(
&mut self,
ctx: &mut B::Context,
li: usize,
cache_id: &str,
residual: &mut B::Buffer,
pos_offset: usize,
tokens: usize,
next_layer_idx: Option<usize>,
prev_did_norm_fusion: bool,
) -> Result<bool> {
let cfg_base = &self.cfg.base;
let h = cfg_base.hidden_size;
let nh = cfg_base.num_heads;
let nkv = cfg_base.num_kv_heads;
let hd = cfg_base.head_dim;
let eps = cfg_base.rms_norm_eps;
let q_dim = nh * hd;
let kv_dim = nkv * hd;
let attn_layer = &self.attn_layers[li];
let moe_layer = &self.moe_layers[li];
let attn_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
B::sync(ctx);
Some(std::time::Instant::now())
} else {
None
};
if !prev_did_norm_fusion {
B::rms_norm(
ctx,
residual,
&attn_layer.input_ln_w,
eps,
&mut self.scratch.norm_out,
tokens,
h,
);
}
attn_layer.qkv_proj.forward(
ctx,
&self.scratch.norm_out,
&mut self.scratch.qkv_out,
tokens,
);
let qk_mode: i32 = if cfg_base.has_qk_norm { 1 } else { 2 };
let dummy = &attn_layer.input_ln_w;
let q_norm_w = attn_layer.q_norm_w.as_ref().unwrap_or(dummy);
let k_norm_w = attn_layer.k_norm_w.as_ref().unwrap_or(dummy);
let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
if let Some(pools) = self.paged_pools.as_mut() {
let pool = &mut pools[li];
Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
} else {
None
};
let caches = self
.kv_caches
.get_mut(cache_id)
.expect("ensure_kv must be called before forward_layer");
let cache = &mut caches[li];
let cache_len_before = cache.len;
let cache_capacity = cache.capacity;
if cache_len_before + tokens > cache_capacity {
panic!(
"KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
cache_len_before + tokens
);
}
let used_qkv_into_cache = if cache.block_size > 0 {
let bt = cache
.block_table
.as_ref()
.expect("paged cache missing block_table");
let num_blocks_per_seq = cache.capacity / cache.block_size;
let (pool_k_ptr, pool_v_ptr) =
paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
let pool_k = unsafe { &mut *pool_k_ptr };
let pool_v = unsafe { &mut *pool_v_ptr };
B::split_qkv_norm_rope_into_paged_cache(
ctx,
&self.scratch.qkv_out,
0,
q_norm_w,
k_norm_w,
&self.rope.cos,
&self.rope.sin,
&mut self.scratch.q_head_major,
0,
pool_k,
pool_v,
bt,
tokens,
nh,
nkv,
hd,
pos_offset,
eps,
qk_mode,
cache_len_before,
cache.block_size,
num_blocks_per_seq,
)
.is_ok()
} else {
B::split_qkv_norm_rope_into_cache(
ctx,
&self.scratch.qkv_out,
q_norm_w,
k_norm_w,
&self.rope.cos,
&self.rope.sin,
&mut self.scratch.q_head_major,
&mut cache.k,
&mut cache.v,
tokens,
nh,
nkv,
hd,
pos_offset,
eps,
qk_mode,
cache_len_before,
cache_capacity,
)
.is_ok()
};
if !used_qkv_into_cache {
let used_fused_qkv = B::split_qkv_norm_rope(
ctx,
&self.scratch.qkv_out,
q_norm_w,
k_norm_w,
&self.rope.cos,
&self.rope.sin,
&mut self.scratch.q_head_major,
&mut self.scratch.k_head_major,
&mut self.scratch.v_head_major,
tokens,
nh,
nkv,
hd,
pos_offset,
eps,
qk_mode,
)
.is_ok();
if !used_fused_qkv {
B::split_qkv(
ctx,
&self.scratch.qkv_out,
&mut self.scratch.q_buf,
&mut self.scratch.k_buf,
&mut self.scratch.v_buf,
tokens,
q_dim,
kv_dim,
);
B::qk_norm_rope(
ctx,
&self.scratch.q_buf,
q_norm_w,
&self.rope.cos,
&self.rope.sin,
&mut self.scratch.q_head_major,
tokens,
nh,
hd,
pos_offset,
eps,
qk_mode,
);
B::qk_norm_rope(
ctx,
&self.scratch.k_buf,
k_norm_w,
&self.rope.cos,
&self.rope.sin,
&mut self.scratch.k_head_major,
tokens,
nkv,
hd,
pos_offset,
eps,
qk_mode,
);
B::qk_norm_rope(
ctx,
&self.scratch.v_buf,
dummy,
&self.rope.cos,
&self.rope.sin,
&mut self.scratch.v_head_major,
tokens,
nkv,
hd,
pos_offset,
eps,
0,
);
}
B::kv_cache_append_head_major(
ctx,
&mut cache.k,
&mut cache.v,
cache.len,
cache.capacity,
&self.scratch.k_head_major,
&self.scratch.v_head_major,
tokens,
nkv,
hd,
);
}
cache.len += tokens;
let kv_len = cache.len;
let kv_stride = cache.capacity;
if cache.block_size > 0 {
let bt = cache
.block_table
.as_ref()
.expect("paged cache missing block_table");
let cl_buf = cache
.context_lens
.as_mut()
.expect("paged cache missing context_lens");
let num_blocks_per_seq = cache.capacity / cache.block_size;
let (pool_k_ptr, pool_v_ptr) =
paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
let pool_k = unsafe { &*pool_k_ptr };
let pool_v = unsafe { &*pool_v_ptr };
let final_kv_len = cache.len as u32;
B::write_u32(ctx, cl_buf, &[final_kv_len]);
B::paged_decode_attention(
ctx,
&self.scratch.q_head_major,
pool_k,
pool_v,
&mut self.scratch.attn_head_major_out,
bt,
cl_buf,
1, nh,
nkv,
hd,
cache.block_size,
num_blocks_per_seq,
tokens,
)
.expect("paged_decode_attention");
let _ = kv_stride; } else {
let attn_cfg = ferrum_kernels::backend::AttnConfig {
num_heads: nh,
num_kv_heads: nkv,
head_dim: hd,
causal: true,
scale: 1.0 / (hd as f32).sqrt(),
kv_seq_stride: kv_stride,
sliding_window: cfg_base.sliding_window,
};
B::flash_attention(
ctx,
&self.scratch.q_head_major,
&cache.k,
&cache.v,
&mut self.scratch.attn_head_major_out,
1,
tokens,
kv_len,
pos_offset,
&attn_cfg,
);
}
if let Some(t0) = attn_t0 {
B::sync(ctx);
ATTN_TIME_US.fetch_add(
t0.elapsed().as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let attn_token_major = if tokens == 1 {
&self.scratch.attn_head_major_out
} else {
B::transpose_head_to_token(
ctx,
&self.scratch.attn_head_major_out,
&mut self.scratch.attn_flat,
tokens,
nh,
hd,
);
&self.scratch.attn_flat
};
attn_layer
.o_proj
.forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
B::fused_add_rms_norm(
ctx,
residual,
&self.scratch.o_proj_out,
&attn_layer.post_ln_w,
eps,
&mut self.scratch.norm_out,
tokens,
h,
);
let moe_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
B::sync(ctx);
Some(std::time::Instant::now())
} else {
None
};
moe_layer.router.forward(
ctx,
&self.scratch.norm_out,
&mut self.scratch.router_logits,
tokens,
);
let stacked_path_available = moe_layer.experts.gate_stacked.is_some()
&& moe_layer.experts.up_stacked.is_some()
&& moe_layer.experts.down_stacked.is_some();
let decode_fast_path = stacked_path_available && tokens == 1;
let did_norm_fusion = decode_fast_path && next_layer_idx.is_some();
if stacked_path_available {
if tokens > 1 {
self.moe_forward_batched_prefill(ctx, li, tokens)?;
} else {
self.moe_forward_stacked(ctx, li, tokens, residual, next_layer_idx)?;
}
} else {
moe_forward::<B>(
ctx,
&self.scratch.norm_out,
&self.scratch.router_logits,
&mut self.scratch.moe_out,
tokens,
h,
self.cfg.expert_intermediate_size,
self.cfg.num_experts,
self.cfg.num_experts_per_tok,
self.cfg.norm_topk_prob,
&moe_layer.experts,
&mut self.scratch.x_single,
&mut self.scratch.acc_buf,
&mut self.scratch.gate_up_buf,
&mut self.scratch.silu_buf,
&mut self.scratch.down_buf,
&self.scratch.zero_hidden,
)?;
}
if !decode_fast_path {
B::add_inplace(ctx, residual, &self.scratch.moe_out, tokens * h);
}
if let Some(t0) = moe_t0 {
B::sync(ctx);
MOE_TIME_US.fetch_add(
t0.elapsed().as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
MOE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
Ok(did_norm_fusion)
}
fn moe_forward_stacked(
    &mut self,
    ctx: &mut B::Context,
    li: usize,
    tokens: usize,
    residual: &mut B::Buffer,
    next_layer_idx: Option<usize>,
) -> Result<()> {
    // Single-token MoE over stacked expert weights for layer `li`. When `next_layer_idx`
    // is given, that layer's input-norm weight is handed to the decode impl so it can be
    // fused into this layer's epilogue.
    //
    // `cfg`, `attn_layers` and `moe_layers` are borrowed shared and `scratch`
    // exclusively. They are disjoint fields, so no raw pointers or `unsafe` are needed.
    let cfg = &self.cfg;
    let attn_layers = &self.attn_layers;
    let next_norm_w: Option<&B::Buffer> =
        next_layer_idx.map(move |idx| &attn_layers[idx].input_ln_w);
    moe_forward_stacked_decode_impl::<B>(
        ctx,
        &self.moe_layers[li],
        &mut self.scratch,
        cfg.base.hidden_size,
        cfg.expert_intermediate_size,
        cfg.num_experts_per_tok,
        cfg.num_experts,
        cfg.norm_topk_prob,
        tokens,
        residual,
        next_norm_w,
        cfg.base.rms_norm_eps,
    )
}
fn moe_forward_batched_prefill(
    &mut self,
    ctx: &mut B::Context,
    li: usize,
    tokens: usize,
) -> Result<()> {
    // Multi-token MoE over stacked expert weights for layer `li`. It reads
    // `scratch.norm_out` / `router_logits` and writes `scratch.moe_out` inside the impl.
    //
    // The scalar hyper-parameters are copied out first, so the call itself only borrows
    // the layer weights (shared) and the scratch arena (exclusive).
    let hidden = self.cfg.base.hidden_size;
    let inter = self.cfg.expert_intermediate_size;
    let top_k = self.cfg.num_experts_per_tok;
    let n_experts = self.cfg.num_experts;
    let renormalise_topk = self.cfg.norm_topk_prob;
    let layer = &self.moe_layers[li];
    moe_forward_batched_prefill_impl::<B>(
        ctx,
        layer,
        &mut self.scratch,
        hidden,
        inter,
        top_k,
        n_experts,
        renormalise_topk,
        tokens,
    )
}
pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
    // Runs the prompt `tokens` through every layer, appending to the KV cache for
    // `cache_id` at its current length, and returns the logits (vocab floats) for the
    // last token only.
    //
    // With FERRUM_DECODE_OP_PROFILE set, the attention / MoE / prefill-MoE counters are
    // reset at the start and a breakdown is printed to stderr at the end.
    //
    // Panics: on empty `tokens`, or if any layer's forward fails.
    let seq_len = tokens.len();
    assert!(seq_len > 0);
    self.ensure_scratch(seq_len);
    self.ensure_kv(cache_id);
    // Continue after whatever is already cached for this id.
    let pos_offset = self
        .kv_caches
        .get(cache_id)
        .and_then(|layers| layers.first())
        .map(|c| c.len)
        .unwrap_or(0);
    let h = self.cfg.base.hidden_size;
    let vocab = self.cfg.base.vocab_size;
    let mut ctx = B::new_context();
    let prefill_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
        B::sync(&mut ctx);
        for c in [
            &ATTN_TIME_US,
            &ATTN_CALLS,
            &MOE_TIME_US,
            &MOE_CALLS,
            &MOE_PREFILL_HOST_TOPK_US,
            &MOE_PREFILL_HOST_TOPK_CALLS,
            &MOE_PREFILL_GATE_US,
            &MOE_PREFILL_GATE_CALLS,
            &MOE_PREFILL_UP_US,
            &MOE_PREFILL_UP_CALLS,
            &MOE_PREFILL_SILU_US,
            &MOE_PREFILL_SILU_CALLS,
            &MOE_PREFILL_DOWN_US,
            &MOE_PREFILL_DOWN_CALLS,
            &MOE_PREFILL_WSUM_US,
            &MOE_PREFILL_WSUM_CALLS,
        ] {
            c.store(0, std::sync::atomic::Ordering::Relaxed);
        }
        Some(std::time::Instant::now())
    } else {
        None
    };

    let mut residual = self
        .scratch
        .residual
        .take()
        .expect("scratch residual missing (previous call didn't restore)");
    B::embedding_lookup(&mut ctx, &self.embed, tokens, &mut residual, h);
    let mut prev_did_norm_fusion = false;
    let num_layers = self.cfg.base.num_layers;
    for li in 0..num_layers {
        let next_layer_idx = (li + 1 < num_layers).then_some(li + 1);
        prev_did_norm_fusion = self
            .forward_layer(
                &mut ctx,
                li,
                cache_id,
                &mut residual,
                pos_offset,
                seq_len,
                next_layer_idx,
                prev_did_norm_fusion,
            )
            .expect("forward_layer");
    }

    // Only the last position's logits are needed: copy its hidden row out, then norm
    // and project that single row.
    B::copy_slice(
        &mut ctx,
        &residual,
        (seq_len - 1) * h,
        &mut self.scratch.last_hidden,
        0,
        h,
    );
    B::rms_norm(
        &mut ctx,
        &self.scratch.last_hidden,
        &self.final_norm_w,
        self.cfg.base.rms_norm_eps,
        &mut self.scratch.last_normed,
        1,
        h,
    );
    self.lm_head.forward(
        &mut ctx,
        &self.scratch.last_normed,
        &mut self.scratch.logits,
        1,
    );
    B::sync(&mut ctx);

    if let Some(t0) = prefill_t0 {
        use std::sync::atomic::Ordering;
        let total_us = t0.elapsed().as_micros() as u64;
        // `numerator / total_us`, guarded so a sub-microsecond run never prints NaN/inf
        // (same convention as the `pct` helper in decode_internal).
        let per_total = |numerator: f64| -> f64 {
            if total_us == 0 {
                0.0
            } else {
                numerator / total_us as f64
            }
        };
        let attn_us = ATTN_TIME_US.load(Ordering::Relaxed);
        let attn_n = ATTN_CALLS.load(Ordering::Relaxed);
        let moe_us = MOE_TIME_US.load(Ordering::Relaxed);
        let moe_n = MOE_CALLS.load(Ordering::Relaxed);
        let other_us = total_us.saturating_sub(attn_us).saturating_sub(moe_us);
        eprintln!(
            "[prefill-profile] tokens={seq_len} total={} ms ({:.0} t/s)",
            total_us / 1000,
            per_total(seq_len as f64 * 1e6)
        );
        let bucket = |label: &str, n: u64, us: u64| {
            if n > 0 {
                eprintln!(
                    " {label:>6}: {:7} ms ({:5.1}%) over {n:4} calls",
                    us / 1000,
                    per_total(us as f64 * 100.0)
                );
            }
        };
        bucket("attn", attn_n, attn_us);
        bucket("moe", moe_n, moe_us);
        bucket("other", 1, other_us);
        let host_us = MOE_PREFILL_HOST_TOPK_US.load(Ordering::Relaxed);
        let gate_us = MOE_PREFILL_GATE_US.load(Ordering::Relaxed);
        let up_us = MOE_PREFILL_UP_US.load(Ordering::Relaxed);
        let silu_us = MOE_PREFILL_SILU_US.load(Ordering::Relaxed);
        let down_us = MOE_PREFILL_DOWN_US.load(Ordering::Relaxed);
        let wsum_us = MOE_PREFILL_WSUM_US.load(Ordering::Relaxed);
        let host_n = MOE_PREFILL_HOST_TOPK_CALLS.load(Ordering::Relaxed);
        let gate_n = MOE_PREFILL_GATE_CALLS.load(Ordering::Relaxed);
        let up_n = MOE_PREFILL_UP_CALLS.load(Ordering::Relaxed);
        let silu_n = MOE_PREFILL_SILU_CALLS.load(Ordering::Relaxed);
        let down_n = MOE_PREFILL_DOWN_CALLS.load(Ordering::Relaxed);
        let wsum_n = MOE_PREFILL_WSUM_CALLS.load(Ordering::Relaxed);
        bucket(" host", host_n, host_us);
        bucket(" gate", gate_n, gate_us);
        bucket(" up", up_n, up_us);
        bucket(" silu", silu_n, silu_us);
        bucket(" down", down_n, down_us);
        bucket(" wsum", wsum_n, wsum_us);
    }
    self.scratch.residual = Some(residual);
    B::to_vec(&self.scratch.logits, vocab)
}
pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
self.ensure_scratch(1);
self.ensure_kv(cache_id);
let h = self.cfg.base.hidden_size;
let vocab = self.cfg.base.vocab_size;
let mut ctx = B::new_context();
let decode_t0 = if std::env::var("FERRUM_MOE_PROFILE").is_ok() {
Some(std::time::Instant::now())
} else {
None
};
let stage_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
B::sync(&mut ctx);
for c in [
&ATTN_TIME_US,
&ATTN_CALLS,
&MOE_TIME_US,
&MOE_CALLS,
&DEC_ROUTE_US,
&DEC_GATE_US,
&DEC_UP_US,
&DEC_SILU_US,
&DEC_DOWN_US,
&DEC_WSUM_US,
&DEC_EMBED_US,
&DEC_FINAL_NORM_US,
&DEC_LM_HEAD_US,
] {
c.store(0, std::sync::atomic::Ordering::Relaxed);
}
Some(std::time::Instant::now())
} else {
None
};
let prof = stage_t0.is_some();
let mark = |ctx: &mut B::Context, c: &AtomicU64, t0: std::time::Instant| {
if prof {
B::sync(ctx);
c.fetch_add(
t0.elapsed().as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
}
};
let mt0 = std::time::Instant::now();
let mut residual = self
.scratch
.residual
.take()
.expect("scratch residual missing (previous call didn't restore)");
let t0 = std::time::Instant::now();
B::embedding_lookup(&mut ctx, &self.embed, &[token], &mut residual, h);
mark(&mut ctx, &DEC_EMBED_US, t0);
let _ = mt0;
let mut prev_did_norm_fusion = false;
let num_layers = self.cfg.base.num_layers;
for li in 0..num_layers {
let next_layer_idx = if li + 1 < num_layers {
Some(li + 1)
} else {
None
};
prev_did_norm_fusion = self
.forward_layer(
&mut ctx,
li,
cache_id,
&mut residual,
pos as usize,
1,
next_layer_idx,
prev_did_norm_fusion,
)
.expect("forward_layer");
}
let t0 = std::time::Instant::now();
B::rms_norm(
&mut ctx,
&residual,
&self.final_norm_w,
self.cfg.base.rms_norm_eps,
&mut self.scratch.last_normed,
1,
h,
);
mark(&mut ctx, &DEC_FINAL_NORM_US, t0);
let t0 = std::time::Instant::now();
self.lm_head.forward(
&mut ctx,
&self.scratch.last_normed,
&mut self.scratch.logits,
1,
);
mark(&mut ctx, &DEC_LM_HEAD_US, t0);
B::sync(&mut ctx);
self.scratch.residual = Some(residual);
if let Some(t0) = stage_t0 {
use std::sync::atomic::Ordering;
let total_us = t0.elapsed().as_micros() as u64;
let attn_us = ATTN_TIME_US.swap(0, Ordering::Relaxed);
let moe_us = MOE_TIME_US.swap(0, Ordering::Relaxed);
let route = DEC_ROUTE_US.swap(0, Ordering::Relaxed);
let gate = DEC_GATE_US.swap(0, Ordering::Relaxed);
let up = DEC_UP_US.swap(0, Ordering::Relaxed);
let silu = DEC_SILU_US.swap(0, Ordering::Relaxed);
let down = DEC_DOWN_US.swap(0, Ordering::Relaxed);
let wsum = DEC_WSUM_US.swap(0, Ordering::Relaxed);
let embed = DEC_EMBED_US.swap(0, Ordering::Relaxed);
let fnorm = DEC_FINAL_NORM_US.swap(0, Ordering::Relaxed);
let lmhead = DEC_LM_HEAD_US.swap(0, Ordering::Relaxed);
let other = total_us.saturating_sub(attn_us + moe_us + embed + fnorm + lmhead);
let pct = |us: u64| -> f64 {
if total_us == 0 {
0.0
} else {
100.0 * us as f64 / total_us as f64
}
};
eprintln!(
"[decode-prof] total={} ms | attn={} ({:.1}%) | moe={} ({:.1}%) [route={} gate={} up={} silu={} down={} wsum={}] | embed={} fnorm={} lmhead={} other={} ({:.1}%)",
total_us / 1000,
attn_us / 1000, pct(attn_us),
moe_us / 1000, pct(moe_us),
route / 1000, gate / 1000, up / 1000, silu / 1000, down / 1000, wsum / 1000,
embed / 1000, fnorm / 1000, lmhead / 1000,
other / 1000, pct(other),
);
}
if let Some(t0) = decode_t0 {
use crate::moe::dispatch::*;
use std::sync::atomic::Ordering;
let total_us = t0.elapsed().as_micros() as u64;
let sync_us = MOE_SYNC_US.swap(0, Ordering::Relaxed);
let sync_n = MOE_SYNC_CALLS.swap(0, Ordering::Relaxed);
let topk_us = MOE_HOST_TOPK_US.swap(0, Ordering::Relaxed);
let topk_n = MOE_HOST_TOPK_CALLS.swap(0, Ordering::Relaxed);
let gu_us = MOE_GEMV_GATE_UP_US.swap(0, Ordering::Relaxed);
let gu_n = MOE_GEMV_GATE_UP_CALLS.swap(0, Ordering::Relaxed);
let silu_us = MOE_SILU_US.swap(0, Ordering::Relaxed);
let silu_n = MOE_SILU_CALLS.swap(0, Ordering::Relaxed);
let dn_us = MOE_GEMV_DOWN_US.swap(0, Ordering::Relaxed);
let dn_n = MOE_GEMV_DOWN_CALLS.swap(0, Ordering::Relaxed);
let sa_us = MOE_SCALED_ADD_US.swap(0, Ordering::Relaxed);
let sa_n = MOE_SCALED_ADD_CALLS.swap(0, Ordering::Relaxed);
let cp_us = MOE_COPY_US.swap(0, Ordering::Relaxed);
let cp_n = MOE_COPY_CALLS.swap(0, Ordering::Relaxed);
eprintln!(
"[moe-prof] decode total={} ms | sync={} ms ({}x) | host_topk={} ms ({}x) | gate_up={} ms ({}x) | silu={} ms ({}x) | down={} ms ({}x) | scaled_add={} ms ({}x) | copy={} ms ({}x)",
total_us / 1000,
sync_us / 1000, sync_n,
topk_us / 1000, topk_n,
gu_us / 1000, gu_n,
silu_us / 1000, silu_n,
dn_us / 1000, dn_n,
sa_us / 1000, sa_n,
cp_us / 1000, cp_n,
);
}
B::to_vec(&self.scratch.logits, vocab)
}
/// Runs one decode step for every `(cache_id, token, pos)` entry in `batch`
/// through a single batched forward pass and returns one logits row
/// (`vocab_size` floats) per entry, in batch order.
///
/// An empty batch yields an empty result; a single-entry batch is delegated to
/// `decode_internal`. Otherwise KV caches and scratch are ensured for `m`
/// rows, all layers run via `forward_layer_batched_decode`, and the final norm
/// plus LM head produce `m * vocab` logits that are split per entry.
///
/// With `FERRUM_DECODE_OP_PROFILE` set, a per-stage timing summary (drained
/// from the `BD_*` and `MOE_BATCHED_DECODE_*` counters) is printed to stderr.
///
/// # Panics
/// Panics if the scratch residual buffer is missing or if any layer returns an
/// error.
pub fn decode_batch_internal(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
    let m = batch.len();
    match batch {
        [] => return Vec::new(),
        [(cid, tok, pos)] => return vec![self.decode_internal(cid, *tok, *pos)],
        _ => {}
    }

    let profile_started = std::env::var("FERRUM_DECODE_OP_PROFILE")
        .is_ok()
        .then(std::time::Instant::now);

    for (cid, _, _) in batch {
        self.ensure_kv(cid);
    }
    self.ensure_scratch(m);
    self.scratch.enable_batched_decode_scratch(&self.cfg);

    let hidden = self.cfg.base.hidden_size;
    let vocab = self.cfg.base.vocab_size;
    let mut ctx = B::new_context();
    let token_ids: Vec<u32> = batch.iter().map(|&(_, t, _)| t).collect();

    // The residual buffer is moved out of scratch for the duration of the
    // forward pass and put back after the final sync.
    let mut residual = self
        .scratch
        .residual
        .take()
        .expect("scratch residual missing (previous call didn't restore)");
    B::embedding_lookup(&mut ctx, &self.embed, &token_ids, &mut residual, hidden);
    for layer in 0..self.cfg.base.num_layers {
        self.forward_layer_batched_decode(&mut ctx, layer, batch, &mut residual, m)
            .expect("forward_layer_batched_decode");
    }
    B::rms_norm(
        &mut ctx,
        &residual,
        &self.final_norm_w,
        self.cfg.base.rms_norm_eps,
        &mut self.scratch.norm_out,
        m,
        hidden,
    );
    self.lm_head
        .forward(&mut ctx, &self.scratch.norm_out, &mut self.scratch.batch_logits, m);
    B::sync(&mut ctx);
    self.scratch.residual = Some(residual);
    let flat_logits = B::to_vec(&self.scratch.batch_logits, m * vocab);

    if let Some(started) = profile_started {
        use std::sync::atomic::Ordering::Relaxed;
        let total_us = started.elapsed().as_micros() as u64;
        let dense = BD_DENSE_US.swap(0, Relaxed);
        let attn = BD_ATTN_PERITEM_US.swap(0, Relaxed);
        let moe = BD_MOE_US.swap(0, Relaxed);
        let layers = BD_LAYER_CALLS.swap(0, Relaxed);
        let other = total_us.saturating_sub(dense + attn + moe);
        let share = |us: u64| {
            if total_us == 0 {
                0.0
            } else {
                100.0 * us as f64 / total_us as f64
            }
        };
        let route_us = MOE_BATCHED_DECODE_ROUTE_US.swap(0, Relaxed);
        let gate_us = MOE_BATCHED_DECODE_GATE_US.swap(0, Relaxed);
        let up_us = MOE_BATCHED_DECODE_UP_US.swap(0, Relaxed);
        let silu_us = MOE_BATCHED_DECODE_SILU_US.swap(0, Relaxed);
        let down_us = MOE_BATCHED_DECODE_DOWN_US.swap(0, Relaxed);
        let wsum_us = MOE_BATCHED_DECODE_WSUM_US.swap(0, Relaxed);
        eprintln!(
            "[batched-decode-prof] m={} layers={} total={} ms | dense={} ({:.1}%) | attn_peritem={} ({:.1}%) | moe={} ({:.1}%) [route={} gate={} up={} silu={} down={} wsum={}] | other={} ({:.1}%)",
            m, layers, total_us / 1000,
            dense / 1000, share(dense),
            attn / 1000, share(attn),
            moe / 1000, share(moe),
            route_us / 1000, gate_us / 1000, up_us / 1000,
            silu_us / 1000, down_us / 1000, wsum_us / 1000,
            other / 1000, share(other),
        );
    }

    // Split the flat [m, vocab] logits into one owned row per batch entry.
    let mut rows = Vec::with_capacity(m);
    for row in 0..m {
        let start = row * vocab;
        rows.push(flat_logits[start..start + vocab].to_vec());
    }
    rows
}
/// Runs decoder layer `li` for a batch of `m` single-token decode steps, one
/// per `(cache_id, token, pos)` entry in `batch`, updating `residual` in place.
///
/// Stages visible below:
/// 1. input RMSNorm + QKV projection over all `m` rows;
/// 2. attention — either the paged path (`self.paged_pools` is `Some`): per-row
///    norm/RoPE written straight into the paged KV pools, followed by one
///    batched `paged_decode_attention`; or the contiguous path: a per-sequence
///    loop of norm/RoPE, KV append and `flash_attention`;
/// 3. o_proj + fused residual add / post-attention RMSNorm over all `m` rows;
/// 4. the MoE FFN (dispatch rules documented at the selection point);
/// 5. `residual += moe_out`.
///
/// Each sequence's layer-`li` KV cache has its `len` incremented by one.
///
/// When `FERRUM_DECODE_OP_PROFILE` is set, stage times accumulate into the
/// `BD_*` counters and every stage boundary calls `B::sync`.
///
/// # Panics
/// Panics (via `expect`/`unwrap`) if a sequence's KV cache, its block table
/// (paged path), or a required batched-decode scratch buffer is missing, or if
/// a paged attention kernel returns an error.
///
/// # Errors
/// Returns backend errors raised by the MoE stage.
fn forward_layer_batched_decode(
&mut self,
ctx: &mut B::Context,
li: usize,
batch: &[(String, u32, u32)],
residual: &mut B::Buffer,
m: usize,
) -> Result<()> {
let cfg_base = &self.cfg.base;
let h = cfg_base.hidden_size;
let nh = cfg_base.num_heads;
let nkv = cfg_base.num_kv_heads;
let hd = cfg_base.head_dim;
let eps = cfg_base.rms_norm_eps;
let q_dim = nh * hd;
let kv_dim = nkv * hd;
let attn_layer = &self.attn_layers[li];
// Mode handed to the norm/RoPE kernels: 1 when the model has per-head Q/K
// RMSNorm, 2 otherwise (V uses 0 below).
// NOTE(review): the meaning of each mode is defined by the backend kernels — confirm.
let qk_mode: i32 = if cfg_base.has_qk_norm { 1 } else { 2 };
// Placeholder weight buffer used when Q/K norm weights are absent; presumably
// ignored by the kernel in the no-norm modes — confirm.
let dummy_w = &attn_layer.input_ln_w;
let q_norm_w = attn_layer.q_norm_w.as_ref().unwrap_or(dummy_w);
let k_norm_w = attn_layer.k_norm_w.as_ref().unwrap_or(dummy_w);
// Profiling helpers: `stage_t0` starts a timer only when profiling is on;
// `stage_end` syncs the backend before reading the clock so queued device work
// is included, then adds the elapsed microseconds to the given counter.
let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
let stage_t0 = || -> Option<std::time::Instant> {
if prof {
Some(std::time::Instant::now())
} else {
None
}
};
let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
if let Some(t) = t0 {
B::sync(ctx);
c.fetch_add(
t.elapsed().as_micros() as u64,
std::sync::atomic::Ordering::Relaxed,
);
}
};
if prof {
BD_LAYER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
let dense_t0 = stage_t0();
// Input RMSNorm and fused QKV projection over all m rows.
B::rms_norm(
ctx,
residual,
&attn_layer.input_ln_w,
eps,
&mut self.scratch.norm_out,
m,
h,
);
attn_layer
.qkv_proj
.forward(ctx, &self.scratch.norm_out, &mut self.scratch.qkv_out, m);
let is_paged = self.paged_pools.is_some();
if is_paged {
stage_end(dense_t0, ctx, &BD_DENSE_US);
let attn_t0 = stage_t0();
let max_blocks_per_seq = self.scratch.paged_max_blocks_per_seq;
// NOTE(review): block_size is hard-coded to 16; it must match the block size
// the paged pools / block allocator were built with — confirm.
let block_size = 16; let qkv_stride = q_dim + 2 * kv_dim;
let q_head_major_size_bytes = (q_dim * std::mem::size_of::<f32>()) as u64;
let qkv_stride_bytes = (qkv_stride * std::mem::size_of::<f32>()) as u64;
// Raw pointers let this layer's K/V pools be held as `&mut` for the rest of
// the paged branch while other fields of `self` are borrowed.
// SAFETY (reviewer note): relies on `self.paged_pools` not being resized,
// replaced or otherwise accessed while these references are live — confirm.
let pool_ptr = {
let pools = self.paged_pools.as_mut().unwrap();
(
&mut pools[li].0 as *mut B::Buffer,
&mut pools[li].1 as *mut B::Buffer,
)
};
let (pool_k, pool_v) = unsafe { (&mut *pool_ptr.0, &mut *pool_ptr.1) };
// Per sequence: take row i of qkv_out (byte offset i * qkv_stride_bytes),
// apply norm/RoPE at its position, write Q into slot i of paged_batch_q and
// K/V into the paged pools through the sequence's block table.
for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
let pos_i = *pos as usize;
let caches = self
.kv_caches
.get(cache_id)
.expect("paged batched: cache not present");
let cache = &caches[li];
let bt = cache
.block_table
.as_ref()
.expect("paged batched: block_table missing");
// Length before this step's token; `len` is bumped in the next loop.
let cache_len_before = cache.len;
// Raw-pointer round trip detaches the block-table reference from the
// `self.kv_caches` borrow.
let bt_ptr = bt as *const B::Buffer;
let bt_safe: &B::Buffer = unsafe { &*bt_ptr };
B::split_qkv_norm_rope_into_paged_cache(
ctx,
&self.scratch.qkv_out,
(i as u64) * qkv_stride_bytes,
q_norm_w,
k_norm_w,
&self.rope.cos,
&self.rope.sin,
self.scratch
.paged_batch_q
.as_mut()
.expect("paged_batch_q missing"),
(i as u64) * q_head_major_size_bytes,
pool_k,
pool_v,
bt_safe,
1,
nh,
nkv,
hd,
pos_i,
eps,
qk_mode,
cache_len_before,
block_size,
max_blocks_per_seq,
)
.expect("split_qkv_norm_rope_into_paged_cache (batched)");
}
// Build host-side block tables ([m, max_blocks_per_seq], zero-padded; lists
// longer than max_blocks_per_seq are truncated) and context lengths ([m],
// counting this step's token), then upload both to scratch buffers.
let mut stacked_bt: Vec<u32> = vec![0u32; m * max_blocks_per_seq];
let mut stacked_cl: Vec<u32> = vec![0u32; m];
for (i, (cache_id, _, _)) in batch.iter().enumerate() {
let caches = self
.kv_caches
.get_mut(cache_id)
.expect("paged batched: cache not present");
let cache = &mut caches[li];
cache.len += 1;
let len = cache.len as u32;
stacked_cl[i] = len;
let blocks = &cache.paged_block_indices;
let n_to_copy = blocks.len().min(max_blocks_per_seq);
stacked_bt[i * max_blocks_per_seq..i * max_blocks_per_seq + n_to_copy]
.copy_from_slice(&blocks[..n_to_copy]);
}
let bt_buf = self
.scratch
.paged_batch_block_tables
.as_mut()
.expect("paged_batch_block_tables missing");
B::write_u32(ctx, bt_buf, &stacked_bt);
let cl_buf = self
.scratch
.paged_batch_context_lens
.as_mut()
.expect("paged_batch_context_lens missing");
B::write_u32(ctx, cl_buf, &stacked_cl);
// Raw-pointer reborrows so several `self.scratch` buffers can be passed to
// the attention kernel at once (shared inputs plus the `&mut` output).
let bt_ptr =
self.scratch.paged_batch_block_tables.as_ref().unwrap() as *const B::Buffer;
let cl_ptr =
self.scratch.paged_batch_context_lens.as_ref().unwrap() as *const B::Buffer;
let q_ptr = self.scratch.paged_batch_q.as_ref().unwrap() as *const B::Buffer;
let o_ptr = self.scratch.paged_batch_o.as_mut().unwrap() as *mut B::Buffer;
let bt_safe = unsafe { &*bt_ptr };
let cl_safe = unsafe { &*cl_ptr };
let q_safe = unsafe { &*q_ptr };
let o_safe = unsafe { &mut *o_ptr };
B::paged_decode_attention(
ctx,
q_safe,
pool_k,
pool_v,
o_safe,
bt_safe,
cl_safe,
m,
nh,
nkv,
hd,
block_size,
max_blocks_per_seq,
1, )
.expect("paged batched decode");
// Copy each sequence's q_dim-wide attention output into row i of attn_flat.
for i in 0..m {
B::copy_slice(
ctx,
self.scratch.paged_batch_o.as_ref().unwrap(),
i * q_dim,
&mut self.scratch.attn_flat,
i * q_dim,
q_dim,
);
}
stage_end(attn_t0, ctx, &BD_ATTN_PERITEM_US);
} else {
// Contiguous-cache path: split batched QKV into q_buf / k_buf / v_buf with
// row widths q_dim and kv_dim.
B::split_qkv(
ctx,
&self.scratch.qkv_out,
&mut self.scratch.q_buf,
&mut self.scratch.k_buf,
&mut self.scratch.v_buf,
m,
q_dim,
kv_dim,
);
// Raw pointers to the single-row scratch buffers so they can be reborrowed
// inside the per-sequence loop alongside other `self.scratch` borrows.
let q_single = self
.scratch
.q_single
.as_ref()
.expect("q_single missing — enable_batched_decode_scratch not called")
as *const B::Buffer;
let k_single =
self.scratch.k_single.as_ref().expect("k_single missing") as *const B::Buffer;
let v_single =
self.scratch.v_single.as_ref().expect("v_single missing") as *const B::Buffer;
let q_hm_single =
self.scratch
.q_head_major_single
.as_mut()
.expect("q_head_major_single missing") as *mut B::Buffer;
let k_hm_single =
self.scratch
.k_head_major_single
.as_mut()
.expect("k_head_major_single missing") as *mut B::Buffer;
let v_hm_single =
self.scratch
.v_head_major_single
.as_mut()
.expect("v_head_major_single missing") as *mut B::Buffer;
let attn_hm_single =
self.scratch
.attn_head_major_single
.as_mut()
.expect("attn_head_major_single missing") as *mut B::Buffer;
stage_end(dense_t0, ctx, &BD_DENSE_US);
let attn_t0 = stage_t0();
for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
let pos_i = *pos as usize;
// NOTE(review): q_single_ref / k_single_ref / v_single_ref are shared
// references that remain live while the copy_slice calls below take `&mut`
// to the same buffers via `self.scratch.*_single.as_mut()`. This may violate
// Rust's aliasing rules; consider deriving these references after the copies.
let q_single_ref = unsafe { &*q_single };
let k_single_ref = unsafe { &*k_single };
let v_single_ref = unsafe { &*v_single };
let q_hm_single_mut = unsafe { &mut *q_hm_single };
let k_hm_single_mut = unsafe { &mut *k_hm_single };
let v_hm_single_mut = unsafe { &mut *v_hm_single };
let attn_hm_single_mut = unsafe { &mut *attn_hm_single };
// Gather row i of q/k/v into the single-row buffers.
B::copy_slice(
ctx,
&self.scratch.q_buf,
i * q_dim,
self.scratch.q_single.as_mut().unwrap(),
0,
q_dim,
);
B::copy_slice(
ctx,
&self.scratch.k_buf,
i * kv_dim,
self.scratch.k_single.as_mut().unwrap(),
0,
kv_dim,
);
B::copy_slice(
ctx,
&self.scratch.v_buf,
i * kv_dim,
self.scratch.v_single.as_mut().unwrap(),
0,
kv_dim,
);
// Q and K: norm (per qk_mode) + RoPE at pos_i into head-major buffers.
B::qk_norm_rope(
ctx,
q_single_ref,
q_norm_w,
&self.rope.cos,
&self.rope.sin,
q_hm_single_mut,
1,
nh,
hd,
pos_i,
eps,
qk_mode,
);
B::qk_norm_rope(
ctx,
k_single_ref,
k_norm_w,
&self.rope.cos,
&self.rope.sin,
k_hm_single_mut,
1,
nkv,
hd,
pos_i,
eps,
qk_mode,
);
// V goes through the same kernel with mode 0 and placeholder weights;
// presumably a layout conversion only (no norm / RoPE) — confirm.
B::qk_norm_rope(
ctx,
v_single_ref,
dummy_w,
&self.rope.cos,
&self.rope.sin,
v_hm_single_mut,
1,
nkv,
hd,
pos_i,
eps,
0,
);
// Append this token's K/V to the sequence's cache, then attend over all
// kv_len cached entries.
let caches = self
.kv_caches
.get_mut(cache_id)
.expect("ensure_kv must be called before forward_layer_batched");
let cache = &mut caches[li];
B::kv_cache_append_head_major(
ctx,
&mut cache.k,
&mut cache.v,
cache.len,
cache.capacity,
k_hm_single_mut,
v_hm_single_mut,
1,
nkv,
hd,
);
cache.len += 1;
let kv_len = cache.len;
let kv_stride = cache.capacity;
let attn_cfg = ferrum_kernels::backend::AttnConfig {
num_heads: nh,
num_kv_heads: nkv,
head_dim: hd,
causal: true,
scale: 1.0 / (hd as f32).sqrt(),
kv_seq_stride: kv_stride,
sliding_window: cfg_base.sliding_window,
};
B::flash_attention(
ctx,
q_hm_single_mut,
&cache.k,
&cache.v,
attn_hm_single_mut,
1,
1,
kv_len,
pos_i,
&attn_cfg,
);
// Scatter this sequence's attention output into row i of attn_flat.
B::copy_slice(
ctx,
attn_hm_single_mut,
0,
&mut self.scratch.attn_flat,
i * q_dim,
q_dim,
);
}
stage_end(attn_t0, ctx, &BD_ATTN_PERITEM_US);
}
// Output projection over all m rows, then a fused residual add + post-attention
// RMSNorm (per the kernel name) whose normalized output lands in norm_out.
let post_attn_t0 = stage_t0();
attn_layer.o_proj.forward(
ctx,
&self.scratch.attn_flat,
&mut self.scratch.o_proj_out,
m,
);
B::fused_add_rms_norm(
ctx,
residual,
&self.scratch.o_proj_out,
&attn_layer.post_ln_w,
eps,
&mut self.scratch.norm_out,
m,
h,
);
stage_end(post_attn_t0, ctx, &BD_DENSE_US);
let moe_t0 = stage_t0();
let moe_layer = &self.moe_layers[li];
// Router logits for all m tokens.
moe_layer.router.forward(
ctx,
&self.scratch.norm_out,
&mut self.scratch.router_logits,
m,
);
// MoE implementation selection (env vars are re-read on every call, i.e. once
// per layer per step):
// - FERRUM_MOE_BATCH_THRESHOLD (default 8): m threshold for the batched-prefill
//   path when the new batched-decode path is disabled.
// - FERRUM_MOE_PREFILL_THRESHOLD (default 32): m threshold for the
//   batched-prefill path when the new batched-decode path is enabled.
// - FERRUM_MOE_BATCHED_DECODE=1 (plus backend support and stacked weights)
//   enables moe_forward_batched_decode_impl for 2 <= m below that threshold.
// Otherwise: a per-token loop over stacked weights, or the generic moe_forward
// when stacked weights are unavailable.
let stacked_path_available = moe_layer.experts.gate_stacked.is_some()
&& moe_layer.experts.up_stacked.is_some()
&& moe_layer.experts.down_stacked.is_some();
let legacy_prefill_threshold: usize = std::env::var("FERRUM_MOE_BATCH_THRESHOLD")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(8);
let new_prefill_threshold: usize = std::env::var("FERRUM_MOE_PREFILL_THRESHOLD")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(32);
let new_batched_enabled = stacked_path_available
&& B::supports_batched_moe_gemv()
&& std::env::var("FERRUM_MOE_BATCHED_DECODE").as_deref() == Ok("1");
let use_prefill_batched = if new_batched_enabled {
stacked_path_available && m >= new_prefill_threshold
} else {
stacked_path_available && m >= legacy_prefill_threshold
};
let use_batched_decode = new_batched_enabled && !use_prefill_batched && m >= 2;
if use_prefill_batched {
moe_forward_batched_prefill_impl::<B>(
ctx,
moe_layer,
&mut self.scratch,
h,
self.cfg.expert_intermediate_size,
self.cfg.num_experts_per_tok,
self.cfg.num_experts,
self.cfg.norm_topk_prob,
m,
)?;
} else if use_batched_decode {
moe_forward_batched_decode_impl::<B>(
ctx,
moe_layer,
&mut self.scratch,
h,
self.cfg.expert_intermediate_size,
self.cfg.num_experts_per_tok,
self.cfg.num_experts,
self.cfg.norm_topk_prob,
m,
)?;
} else if stacked_path_available {
// Per-token loop over stacked expert weights; routing is computed once for
// all m tokens into selected_ids_buf / weights_2d.
let inter = self.cfg.expert_intermediate_size;
let top_k = self.cfg.num_experts_per_tok;
let n_exp = self.cfg.num_experts;
let norm_topk_prob = self.cfg.norm_topk_prob;
let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
B::route_topk_softmax(
ctx,
&self.scratch.router_logits,
&mut self.scratch.selected_ids_buf,
&mut self.scratch.weights_2d,
m,
n_exp,
top_k,
norm_topk_prob,
)?;
for i in 0..m {
let ids_offset = i * top_k;
let activation_offset = i * h;
let weights_offset = i * top_k;
let moe_out_offset = i * h;
// First try the offset-addressed gate GEMV. If it returns Err (e.g. the
// backend lacks an offset kernel — unconfirmed), the error is discarded and
// this token is recomputed by copying its ids / weights / activation into
// single-token buffers and using the non-offset kernels. The offset
// attempt is repeated for every token.
let gate_res = B::gemv_quant_moe_id_offset(
ctx,
&self.scratch.norm_out,
activation_offset,
gate_stacked,
&self.scratch.selected_ids_buf,
ids_offset,
&mut self.scratch.gate_out_stacked,
top_k,
0,
);
if gate_res.is_err() {
B::copy_slice(
ctx,
&self.scratch.selected_ids_buf,
ids_offset,
&mut self.scratch.ids_buf,
0,
top_k,
);
B::copy_slice(
ctx,
&self.scratch.weights_2d,
weights_offset,
&mut self.scratch.weights_buf,
0,
top_k,
);
B::copy_slice(
ctx,
&self.scratch.norm_out,
activation_offset,
&mut self.scratch.x_single,
0,
h,
);
B::gemv_quant_moe_id(
ctx,
&self.scratch.x_single,
gate_stacked,
&self.scratch.ids_buf,
&mut self.scratch.gate_out_stacked,
top_k,
0,
)?;
B::gemv_quant_moe_id(
ctx,
&self.scratch.x_single,
up_stacked,
&self.scratch.ids_buf,
&mut self.scratch.up_out_stacked,
top_k,
0,
)?;
B::silu_mul_stacked(
ctx,
&self.scratch.gate_out_stacked,
&self.scratch.up_out_stacked,
&mut self.scratch.silu_stacked,
top_k,
inter,
)?;
B::gemv_quant_moe_id(
ctx,
&self.scratch.silu_stacked,
down_stacked,
&self.scratch.ids_buf,
&mut self.scratch.down_out_stacked,
top_k,
inter,
)?;
B::weighted_sum_batched(
ctx,
&self.scratch.down_out_stacked,
&self.scratch.weights_buf,
&mut self.scratch.acc_buf,
1,
top_k,
h,
)?;
B::copy_slice(
ctx,
&self.scratch.acc_buf,
0,
&mut self.scratch.moe_out,
moe_out_offset,
h,
);
continue;
}
// Offset path: up GEMV, SiLU*mul, down GEMV, then the top-k weighted sum
// written directly into row i of moe_out.
B::gemv_quant_moe_id_offset(
ctx,
&self.scratch.norm_out,
activation_offset,
up_stacked,
&self.scratch.selected_ids_buf,
ids_offset,
&mut self.scratch.up_out_stacked,
top_k,
0,
)?;
B::silu_mul_stacked(
ctx,
&self.scratch.gate_out_stacked,
&self.scratch.up_out_stacked,
&mut self.scratch.silu_stacked,
top_k,
inter,
)?;
B::gemv_quant_moe_id_offset(
ctx,
&self.scratch.silu_stacked,
0, down_stacked,
&self.scratch.selected_ids_buf,
ids_offset,
&mut self.scratch.down_out_stacked,
top_k,
inter,
)?;
B::weighted_sum_batched_offset(
ctx,
&self.scratch.down_out_stacked,
&self.scratch.weights_2d,
weights_offset,
&mut self.scratch.moe_out,
moe_out_offset,
1,
top_k,
h,
)?;
}
} else {
// Generic MoE path used when stacked expert weights are unavailable.
moe_forward::<B>(
ctx,
&self.scratch.norm_out,
&self.scratch.router_logits,
&mut self.scratch.moe_out,
m,
h,
self.cfg.expert_intermediate_size,
self.cfg.num_experts,
self.cfg.num_experts_per_tok,
self.cfg.norm_topk_prob,
&moe_layer.experts,
&mut self.scratch.x_single,
&mut self.scratch.acc_buf,
&mut self.scratch.gate_up_buf,
&mut self.scratch.silu_buf,
&mut self.scratch.down_buf,
&self.scratch.zero_hidden,
)?;
}
// residual += moe_out over all m rows.
B::add_inplace(ctx, residual, &self.scratch.moe_out, m * h);
stage_end(moe_t0, ctx, &BD_MOE_US);
Ok(())
}
}
impl<B: Backend> Qwen3MoeModel<B> {
    /// Detaches the per-layer KV caches registered under `cache_id` (no-op if
    /// there are none) and parks them in `kv_free_pool` for reuse.
    ///
    /// When a paged block allocator is configured, the blocks listed in layer
    /// 0's `paged_block_indices` are returned to it first, and every layer's
    /// index list is cleared so no pooled cache carries stale block indices.
    fn recycle_kv_caches_for(&mut self, cache_id: &str) {
        let Some(mut caches) = self.kv_caches.remove(cache_id) else {
            return;
        };
        if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
            // Recover from a poisoned lock instead of panicking.
            let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
            if let Some(c0) = caches.first() {
                if !c0.paged_block_indices.is_empty() {
                    alloc.free(&c0.paged_block_indices);
                }
            }
            for c in caches.iter_mut() {
                c.paged_block_indices.clear();
            }
        }
        self.kv_free_pool.push(caches);
    }
}

impl<B: Backend> DecoderOnlyLLM for Qwen3MoeModel<B> {
    fn config(&self) -> &LlmRuntimeConfig {
        &self.runtime_cfg
    }

    /// Sizes scratch buffers for `max_tokens`, ensures a KV cache exists for
    /// `cache_id`, and runs a one-token prefill on a throwaway warmup cache.
    ///
    /// The warmup cache is recycled through the same path as `release`, so any
    /// paged KV blocks it acquired are returned to the block allocator. They
    /// no longer travel into `kv_free_pool` still marked as owned.
    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
        self.ensure_scratch(max_tokens);
        self.ensure_kv(cache_id);
        const WARMUP_CACHE: &str = "__ferrum_warmup__";
        // The warmup logits are discarded; presumably the call exists for its
        // warmup side effects (kernel/buffer initialisation) — confirm.
        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
        self.recycle_kv_caches_for(WARMUP_CACHE);
    }

    /// KV capacity per sequence: `FERRUM_KV_CAPACITY` if set and parseable
    /// (capped at the model's `max_seq_len`), otherwise
    /// `min(max_seq_len, 4096)`.
    fn kv_capacity(&self) -> usize {
        let model_max = self.cfg.base.max_seq_len;
        const DEFAULT_KV_CAPACITY: usize = 4096;
        std::env::var("FERRUM_KV_CAPACITY")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .map(|cap| cap.min(model_max))
            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
    }

    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
        self.prefill_internal(cache_id, tokens)
    }

    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
        self.decode_internal(cache_id, token, pos)
    }

    /// Decodes one token for each `(cache_id, token, pos)` entry.
    ///
    /// The batched forward pass is used only when `FERRUM_MOE_BATCHED=1` and
    /// the batch has at least `FERRUM_MOE_BATCH_THRESHOLD` entries (default
    /// 12). Otherwise entries are decoded one at a time.
    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
        let m = batch.len();
        let opted_in = std::env::var("FERRUM_MOE_BATCHED").as_deref() == Ok("1");
        let threshold = std::env::var("FERRUM_MOE_BATCH_THRESHOLD")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(12);
        if opted_in && m >= threshold {
            self.decode_batch_internal(batch)
        } else {
            batch
                .iter()
                .map(|(cid, tok, p)| self.decode(cid, *tok, *p))
                .collect()
        }
    }

    /// Syncs the backend, resets its graph state, and recycles `cache_id`'s KV
    /// caches (returning paged blocks to the allocator when one is configured).
    fn release(&mut self, cache_id: &str) {
        let mut ctx = B::new_context();
        B::sync(&mut ctx);
        B::reset_graph(&mut ctx);
        B::sync(&mut ctx);
        self.recycle_kv_caches_for(cache_id);
    }

    /// Syncs the backend, resets its graph state, and drops every KV cache,
    /// including those parked in the free pool.
    fn reset(&mut self) {
        let mut ctx = B::new_context();
        B::sync(&mut ctx);
        B::reset_graph(&mut ctx);
        B::sync(&mut ctx);
        // NOTE(review): unlike `release`, caches dropped here do not return
        // their paged blocks to `paged_block_alloc`; confirm the allocator is
        // reset elsewhere, or blocks held by live caches may leak.
        self.kv_caches.clear();
        self.kv_free_pool.clear();
    }
}
/// Single-token MoE FFN over stacked expert weights (decode path).
///
/// Routes `scratch.router_logits` to `top_k` experts (ids into
/// `scratch.ids_buf`, weights into `scratch.weights_buf`), runs the selected
/// experts' gate/up/SiLU and down projections on `scratch.norm_out`, and then
/// applies the weighted expert sum to `residual`. When `next_norm_w` is given,
/// the `*_norm_*` kernel variant also receives `scratch.norm_out` and `eps`,
/// presumably producing the next layer's normalized input — confirm.
///
/// The gate/up/SiLU stage uses the backend's fused kernel when supported,
/// unless `FERRUM_MOE_FUSED_GATE_UP_SILU=0` (read once per process).
///
/// # Panics
/// Panics if any stacked expert weight is missing. In debug builds it also
/// panics if `tokens != 1`.
#[allow(clippy::too_many_arguments)]
fn moe_forward_stacked_decode_impl<B: Backend>(
    ctx: &mut B::Context,
    moe_layer: &Qwen3MoeLayerState<B>,
    scratch: &mut Qwen3MoeScratch<B>,
    h: usize,
    inter: usize,
    top_k: usize,
    n_exp: usize,
    norm_topk_prob: bool,
    tokens: usize,
    residual: &mut B::Buffer,
    next_norm_w: Option<&B::Buffer>,
    eps: f32,
) -> Result<()> {
    use std::sync::atomic::Ordering;
    use std::time::Instant;

    let profiling = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
    let begin = || profiling.then(Instant::now);
    let record = |started: Option<Instant>, ctx: &mut B::Context, counter: &AtomicU64| {
        let Some(at) = started else { return };
        // Sync first so the interval covers work still queued on the device.
        B::sync(ctx);
        counter.fetch_add(at.elapsed().as_micros() as u64, Ordering::Relaxed);
    };

    let started = begin();
    B::route_topk_softmax(
        ctx,
        &scratch.router_logits,
        &mut scratch.ids_buf,
        &mut scratch.weights_buf,
        tokens,
        n_exp,
        top_k,
        norm_topk_prob,
    )?;
    record(started, ctx, &DEC_ROUTE_US);

    let gate_w = moe_layer.experts.gate_stacked.as_ref().unwrap();
    let up_w = moe_layer.experts.up_stacked.as_ref().unwrap();
    let down_w = moe_layer.experts.down_stacked.as_ref().unwrap();
    debug_assert_eq!(
        tokens, 1,
        "moe_forward_stacked_decode_impl expects tokens=1 (prefill goes through moe_forward_batched_prefill_impl)"
    );

    static FUSED_DISABLED: OnceLock<bool> = OnceLock::new();
    let fused_disabled = *FUSED_DISABLED
        .get_or_init(|| std::env::var("FERRUM_MOE_FUSED_GATE_UP_SILU").as_deref() == Ok("0"));

    if B::supports_fused_moe_gate_up_silu() && !fused_disabled {
        // One kernel produces SiLU(gate) * up for every selected expert.
        let started = begin();
        B::gemv_quant_moe_id_gate_up_silu(
            ctx,
            &scratch.norm_out,
            gate_w,
            up_w,
            &scratch.ids_buf,
            &mut scratch.silu_stacked,
            top_k,
        )?;
        record(started, ctx, &DEC_SILU_US);
    } else {
        let started = begin();
        B::gemv_quant_moe_id(
            ctx,
            &scratch.norm_out,
            gate_w,
            &scratch.ids_buf,
            &mut scratch.gate_out_stacked,
            top_k,
            0,
        )?;
        record(started, ctx, &DEC_GATE_US);

        let started = begin();
        B::gemv_quant_moe_id(
            ctx,
            &scratch.norm_out,
            up_w,
            &scratch.ids_buf,
            &mut scratch.up_out_stacked,
            top_k,
            0,
        )?;
        record(started, ctx, &DEC_UP_US);

        let started = begin();
        B::silu_mul_stacked(
            ctx,
            &scratch.gate_out_stacked,
            &scratch.up_out_stacked,
            &mut scratch.silu_stacked,
            top_k,
            inter,
        )?;
        record(started, ctx, &DEC_SILU_US);
    }

    let started = begin();
    B::gemv_quant_moe_id(
        ctx,
        &scratch.silu_stacked,
        down_w,
        &scratch.ids_buf,
        &mut scratch.down_out_stacked,
        top_k,
        inter,
    )?;
    record(started, ctx, &DEC_DOWN_US);

    let started = begin();
    match next_norm_w {
        Some(norm_w) => B::weighted_sum_residual_norm_stacked(
            ctx,
            &scratch.down_out_stacked,
            &scratch.weights_buf,
            residual,
            norm_w,
            &mut scratch.norm_out,
            top_k,
            h,
            eps,
        )?,
        None => B::weighted_sum_residual_stacked(
            ctx,
            &scratch.down_out_stacked,
            &scratch.weights_buf,
            residual,
            top_k,
            h,
        )?,
    };
    record(started, ctx, &DEC_WSUM_US);
    Ok(())
}
/// Batched MoE FFN for `tokens` rows (prefill-sized batches).
///
/// Routing runs on the device by default. `route_topk_softmax` fills
/// `selected_ids_buf` / `weights_2d`, and `compute_ids_tpe_gpu` fills `tpe_buf`,
/// `ids_2d` and the gate/up and down argument buffers. With
/// `FERRUM_MOE_HOST_TOPK=1`, router logits are instead copied to the host,
/// routed there, and the resulting `tpe_buf`, `ids_2d` and `weights_2d` are
/// uploaded.
///
/// Expert projections are grouped GEMMs over those lists. They use the
/// indirect-dispatch kernel unless host routing is active or
/// `FERRUM_MOE_DIRECT_DISPATCH=1`. The weighted top-k sum for each token is
/// written to `scratch.moe_out`.
///
/// With `FERRUM_DECODE_OP_PROFILE` set, each stage is synced and its time and
/// call count are added to the `MOE_PREFILL_*` counters.
///
/// # Panics
/// Panics if any stacked expert weight is missing.
#[allow(clippy::too_many_arguments)]
fn moe_forward_batched_prefill_impl<B: Backend>(
    ctx: &mut B::Context,
    moe_layer: &Qwen3MoeLayerState<B>,
    scratch: &mut Qwen3MoeScratch<B>,
    h: usize,
    inter: usize,
    top_k: usize,
    n_exp: usize,
    norm_topk_prob: bool,
    tokens: usize,
) -> Result<()> {
    use std::sync::atomic::Ordering;
    use std::time::Instant;

    let profiling = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
    let begin = || profiling.then(Instant::now);
    let record = |started: Option<Instant>,
                  ctx: &mut B::Context,
                  total_us: &AtomicU64,
                  calls: &AtomicU64| {
        let Some(at) = started else { return };
        B::sync(ctx);
        total_us.fetch_add(at.elapsed().as_micros() as u64, Ordering::Relaxed);
        calls.fetch_add(1, Ordering::Relaxed);
    };

    let device_routing = std::env::var("FERRUM_MOE_HOST_TOPK").as_deref() != Ok("1");
    let indirect = device_routing
        && std::env::var("FERRUM_MOE_DIRECT_DISPATCH").as_deref() != Ok("1");

    let started = begin();
    let max_per_expert = if device_routing {
        B::route_topk_softmax(
            ctx,
            &scratch.router_logits,
            &mut scratch.selected_ids_buf,
            &mut scratch.weights_2d,
            tokens,
            n_exp,
            top_k,
            norm_topk_prob,
        )?;
        B::compute_ids_tpe_gpu(
            ctx,
            &scratch.selected_ids_buf,
            &mut scratch.tpe_buf,
            &mut scratch.ids_2d,
            &mut scratch.gate_up_args_buf,
            &mut scratch.down_args_buf,
            tokens,
            n_exp,
            top_k,
            inter,
            h,
        )?;
        // Per-expert counts stay on the device, so the host passes
        // tokens * top_k (presumably an upper bound on any expert's load).
        tokens * top_k
    } else {
        use ferrum_kernels::moe_host::compute_ids_tpe;
        B::sync(ctx);
        let logits_host = B::to_vec(&scratch.router_logits, tokens * n_exp);
        let route = crate::moe::router::route(&logits_host, tokens, n_exp, top_k, norm_topk_prob);
        let (tpe_host, ids_host, host_max) =
            compute_ids_tpe(&route.expert_ids, n_exp, tokens, top_k);
        B::write_i32_into(&mut scratch.tpe_buf, &tpe_host);
        B::write_i32_into(&mut scratch.ids_2d, &ids_host);
        B::write_f32_into(&mut scratch.weights_2d, &route.expert_weights);
        host_max
    };
    record(
        started,
        ctx,
        &MOE_PREFILL_HOST_TOPK_US,
        &MOE_PREFILL_HOST_TOPK_CALLS,
    );

    let gate_w = moe_layer.experts.gate_stacked.as_ref().unwrap();
    let up_w = moe_layer.experts.up_stacked.as_ref().unwrap();
    let down_w = moe_layer.experts.down_stacked.as_ref().unwrap();

    // Issues one grouped expert GEMM through either the indirect or the direct
    // kernel. `$args` is only consumed by the indirect variant.
    macro_rules! grouped_gemm {
        ($input:expr, $weights:expr, $out:expr, $args:expr, $rows_per_token:expr) => {
            if indirect {
                B::gemm_quant_moe_id_indirect(
                    ctx,
                    $input,
                    $weights,
                    &scratch.ids_2d,
                    &scratch.tpe_buf,
                    $out,
                    $args,
                    $rows_per_token,
                    top_k,
                    max_per_expert,
                    tokens,
                )?;
            } else {
                B::gemm_quant_moe_id(
                    ctx,
                    $input,
                    $weights,
                    &scratch.ids_2d,
                    &scratch.tpe_buf,
                    $out,
                    $rows_per_token,
                    top_k,
                    max_per_expert,
                    tokens,
                )?;
            }
        };
    }

    let started = begin();
    grouped_gemm!(
        &scratch.norm_out,
        gate_w,
        &mut scratch.gate_out_stacked,
        &scratch.gate_up_args_buf,
        1
    );
    record(started, ctx, &MOE_PREFILL_GATE_US, &MOE_PREFILL_GATE_CALLS);

    let started = begin();
    grouped_gemm!(
        &scratch.norm_out,
        up_w,
        &mut scratch.up_out_stacked,
        &scratch.gate_up_args_buf,
        1
    );
    record(started, ctx, &MOE_PREFILL_UP_US, &MOE_PREFILL_UP_CALLS);

    let pair_count = tokens * top_k;
    let started = begin();
    B::silu_mul_batched(
        ctx,
        &scratch.gate_out_stacked,
        &scratch.up_out_stacked,
        &mut scratch.silu_stacked,
        pair_count,
        inter,
    )?;
    record(started, ctx, &MOE_PREFILL_SILU_US, &MOE_PREFILL_SILU_CALLS);

    let started = begin();
    grouped_gemm!(
        &scratch.silu_stacked,
        down_w,
        &mut scratch.down_out_stacked,
        &scratch.down_args_buf,
        top_k
    );
    record(started, ctx, &MOE_PREFILL_DOWN_US, &MOE_PREFILL_DOWN_CALLS);

    let started = begin();
    B::weighted_sum_batched(
        ctx,
        &scratch.down_out_stacked,
        &scratch.weights_2d,
        &mut scratch.moe_out,
        tokens,
        top_k,
        h,
    )?;
    record(started, ctx, &MOE_PREFILL_WSUM_US, &MOE_PREFILL_WSUM_CALLS);
    Ok(())
}
/// Batched MoE FFN for `tokens` decode rows.
///
/// Routes on the device into `selected_ids_buf` / `weights_2d`. It then runs
/// the batched expert GEMVs: a fused gate/up/SiLU kernel when the backend
/// reports support, otherwise separate gate, up and SiLU·mul steps. After the
/// batched down GEMV, the top-k weighted sum is written into `scratch.moe_out`.
///
/// With `FERRUM_DECODE_OP_PROFILE` set, each stage is synced and its time is
/// added to the `MOE_BATCHED_DECODE_*_US` counters.
///
/// # Panics
/// Panics if any stacked expert weight is missing.
#[allow(clippy::too_many_arguments)]
fn moe_forward_batched_decode_impl<B: Backend>(
    ctx: &mut B::Context,
    moe_layer: &Qwen3MoeLayerState<B>,
    scratch: &mut Qwen3MoeScratch<B>,
    h: usize,
    inter: usize,
    top_k: usize,
    n_exp: usize,
    norm_topk_prob: bool,
    tokens: usize,
) -> Result<()> {
    use std::sync::atomic::Ordering;
    use std::time::Instant;

    let profiling = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
    let begin = || profiling.then(Instant::now);
    let record = |started: Option<Instant>, ctx: &mut B::Context, counter: &AtomicU64| {
        let Some(at) = started else { return };
        B::sync(ctx);
        counter.fetch_add(at.elapsed().as_micros() as u64, Ordering::Relaxed);
    };

    let started = begin();
    B::route_topk_softmax(
        ctx,
        &scratch.router_logits,
        &mut scratch.selected_ids_buf,
        &mut scratch.weights_2d,
        tokens,
        n_exp,
        top_k,
        norm_topk_prob,
    )?;
    record(started, ctx, &MOE_BATCHED_DECODE_ROUTE_US);

    let gate_w = moe_layer.experts.gate_stacked.as_ref().unwrap();
    let up_w = moe_layer.experts.up_stacked.as_ref().unwrap();
    let down_w = moe_layer.experts.down_stacked.as_ref().unwrap();

    // Gate/up read `norm_out` with stride arguments (h, 0). The down GEMV
    // below reads `silu_stacked` with (top_k * inter, inter). These are
    // presumably per-token and per-slot activation strides — confirm against
    // the backend.
    if B::supports_batched_moe_gate_up_silu() {
        let started = begin();
        B::gemv_quant_moe_id_gate_up_silu_batched(
            ctx,
            &scratch.norm_out,
            gate_w,
            up_w,
            &scratch.selected_ids_buf,
            &mut scratch.silu_stacked,
            tokens,
            top_k,
            h,
            0,
        )?;
        record(started, ctx, &MOE_BATCHED_DECODE_SILU_US);
    } else {
        let started = begin();
        B::gemv_quant_moe_id_batched(
            ctx,
            &scratch.norm_out,
            gate_w,
            &scratch.selected_ids_buf,
            &mut scratch.gate_out_stacked,
            tokens,
            top_k,
            h,
            0,
        )?;
        record(started, ctx, &MOE_BATCHED_DECODE_GATE_US);

        let started = begin();
        B::gemv_quant_moe_id_batched(
            ctx,
            &scratch.norm_out,
            up_w,
            &scratch.selected_ids_buf,
            &mut scratch.up_out_stacked,
            tokens,
            top_k,
            h,
            0,
        )?;
        record(started, ctx, &MOE_BATCHED_DECODE_UP_US);

        let pair_count = tokens * top_k;
        let started = begin();
        B::silu_mul_batched(
            ctx,
            &scratch.gate_out_stacked,
            &scratch.up_out_stacked,
            &mut scratch.silu_stacked,
            pair_count,
            inter,
        )?;
        record(started, ctx, &MOE_BATCHED_DECODE_SILU_US);
    }

    let started = begin();
    B::gemv_quant_moe_id_batched(
        ctx,
        &scratch.silu_stacked,
        down_w,
        &scratch.selected_ids_buf,
        &mut scratch.down_out_stacked,
        tokens,
        top_k,
        top_k * inter,
        inter,
    )?;
    record(started, ctx, &MOE_BATCHED_DECODE_DOWN_US);

    let started = begin();
    B::weighted_sum_batched(
        ctx,
        &scratch.down_out_stacked,
        &scratch.weights_2d,
        &mut scratch.moe_out,
        tokens,
        top_k,
        h,
    )?;
    record(started, ctx, &MOE_BATCHED_DECODE_WSUM_US);
    Ok(())
}
/// Builds a placeholder dense linear layer with `out_features` rows of
/// `in_features` weights, every weight set to zero.
fn stub_linear<B: Backend>(
    out_features: usize,
    in_features: usize,
) -> Box<dyn ferrum_quantization::Linear<B>> {
    let element_count = out_features * in_features;
    let zero_rows = vec![0.0f32; element_count];
    let layer =
        ferrum_quantization::DenseLinear::<B>::from_rows(&zero_rows, out_features, in_features);
    Box::new(layer)
}
/// Precomputes rotary-embedding cosine/sine tables and uploads them to the
/// backend.
///
/// Both tables have `cfg.max_seq_len * (cfg.head_dim / 2)` entries laid out
/// row-major by position: entry `pos * half + i` holds `cos`/`sin` of
/// `pos / rope_theta^(2i / head_dim)`. Angles are evaluated in `f64` and only
/// narrowed to `f32` when stored.
fn build_rope_cache<B: Backend>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
    let hd = cfg.head_dim;
    let half = hd / 2;
    let max = cfg.max_seq_len;
    // The inverse frequencies depend only on `i`. Computing them once avoids
    // `max_seq_len * half` redundant `powf` calls while producing bit-identical
    // values (same expression, same f64 arithmetic).
    let inv_freq: Vec<f64> = (0..half)
        .map(|i| 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64))
        .collect();
    let mut cos = vec![0.0f32; max * half];
    let mut sin = vec![0.0f32; max * half];
    for pos in 0..max {
        let row = pos * half;
        let p = pos as f64;
        for (i, &freq) in inv_freq.iter().enumerate() {
            let angle = p * freq;
            cos[row + i] = angle.cos() as f32;
            sin[row + i] = angle.sin() as f32;
        }
    }
    RopeCache {
        cos: B::from_slice(&cos),
        sin: B::from_slice(&sin),
    }
}