ferrum_models/models/qwen3_moe.rs
1//! `Qwen3MoeModel<B>` — Qwen3-MoE family decoder (Qwen3-30B-A3B and friends).
2//!
3//! Architectural delta vs [`LlamaFamilyModel`]:
4//! * Each transformer layer's FFN is a top-K MoE block instead of a
5//! fused `gate_up_proj → silu → down_proj` MLP.
6//! - One small router linear (`[hidden] → [num_experts]`) picks
7//! top-K experts per token.
//!   - Each expert is itself a fused `gate_up + down` SwiGLU MLP with
//!     the same structure as the dense MLP (the post-attention RMSNorm
//!     runs once on the layer input, not inside each expert), just with
//!     `expert_intermediate_size` (typically much smaller than the
//!     dense `intermediate_size`).
//!   - Output is the router-weighted sum of the K selected expert
//!     outputs.
14//! * Attention path is unchanged from dense Qwen3 (GQA + QK-norm + RoPE).
15//!
16//! Implementation re-uses the dense layer's attention machinery
17//! verbatim — RMSNorm, fused QKV, QK-norm + RoPE, KV cache append,
18//! flash attention, O-projection, residual + post-norm. The only new
19//! code is the MoE FFN block at the tail of each layer's forward.
20//!
21//! Memory model: experts are loaded as `QuantLinear<B>` per expert,
22//! slicing the on-disk 3-D `ffn_{gate,up,down}_exps.weight` tensors
23//! byte-wise so weights stay compressed (Q4_K / Q6_K). For a 32 GB
24//! Mac to run Qwen3-30B-A3B at all, this is non-negotiable: an
25//! eager-fp32 expert stack would weigh ~110 GB.
26
27use std::collections::HashMap;
28use std::sync::atomic::AtomicU64;
29use std::sync::OnceLock;
30
31use ferrum_kernels::backend::{Backend, KvCache};
32use ferrum_quantization::WeightLoader;
33use ferrum_types::{FerrumError, Result};
34
35use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
36use crate::models::llama_family::{LlamaFamilyConfig, LlamaFamilyLayer, RopeCache};
37use crate::moe::{moe_forward, ExpertStack};
38use crate::moe_config::Qwen3MoeConfig;
39
// Decode-side per-op profile counters — same names as the dense path
// so existing tooling (`FERRUM_DECODE_OP_PROFILE=1` log scrapers) keeps
// working without a separate switch for MoE.
//
// Every counter in this section is a process-wide `AtomicU64` starting at 0.
// By naming convention `*_US` accumulates elapsed microseconds and `*_CALLS`
// counts invocations. NOTE(review): units are inferred from the names — the
// code that updates and drains these counters is outside this excerpt.
static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_TIME_US: AtomicU64 = AtomicU64::new(0);
static MOE_CALLS: AtomicU64 = AtomicU64::new(0);

// Fine-grained decode-only counters, populated by
// `moe_forward_stacked_decode_impl` when FERRUM_DECODE_OP_PROFILE is set.
// Each is per-layer summed over the layers in one decode token; drained
// at the bottom of `decode_internal`.
static DEC_ROUTE_US: AtomicU64 = AtomicU64::new(0);
static DEC_GATE_US: AtomicU64 = AtomicU64::new(0);
static DEC_UP_US: AtomicU64 = AtomicU64::new(0);
static DEC_SILU_US: AtomicU64 = AtomicU64::new(0);
static DEC_DOWN_US: AtomicU64 = AtomicU64::new(0);
static DEC_WSUM_US: AtomicU64 = AtomicU64::new(0);
// Single-shot per decode token (not per-layer).
static DEC_EMBED_US: AtomicU64 = AtomicU64::new(0);
static DEC_FINAL_NORM_US: AtomicU64 = AtomicU64::new(0);
static DEC_LM_HEAD_US: AtomicU64 = AtomicU64::new(0);

// MoE batched-prefill sub-stage counters (gate / up / down mul_mm_id +
// silu + weighted_sum + host topk). Same FERRUM_DECODE_OP_PROFILE gate.
// Each stage has a time counter and a matching call counter.
static MOE_PREFILL_HOST_TOPK_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_HOST_TOPK_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_GATE_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_GATE_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_UP_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_UP_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_SILU_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_SILU_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_DOWN_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_DOWN_CALLS: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_WSUM_US: AtomicU64 = AtomicU64::new(0);
static MOE_PREFILL_WSUM_CALLS: AtomicU64 = AtomicU64::new(0);

// MoE batched-DECODE sub-stage counters (small-m path that uses the
// batched-pair GEMV in place of the per-token loop).
static MOE_BATCHED_DECODE_ROUTE_US: AtomicU64 = AtomicU64::new(0);
static MOE_BATCHED_DECODE_GATE_US: AtomicU64 = AtomicU64::new(0);
static MOE_BATCHED_DECODE_UP_US: AtomicU64 = AtomicU64::new(0);
static MOE_BATCHED_DECODE_SILU_US: AtomicU64 = AtomicU64::new(0);
static MOE_BATCHED_DECODE_DOWN_US: AtomicU64 = AtomicU64::new(0);
static MOE_BATCHED_DECODE_WSUM_US: AtomicU64 = AtomicU64::new(0);

// Coarse stage counters for `forward_layer_batched_decode` so we can
// see where the time goes without per-op instrumentation. Summed
// across all layers in one decode_batch_internal call.
static BD_DENSE_US: AtomicU64 = AtomicU64::new(0); // rms_norm + qkv_proj + split_qkv + o_proj + fused_add_rms_norm
static BD_ATTN_PERITEM_US: AtomicU64 = AtomicU64::new(0); // the for-i in 0..m attention loop (incl. plumbing)
static BD_MOE_US: AtomicU64 = AtomicU64::new(0); // router + MoE FFN + residual add
static BD_LAYER_CALLS: AtomicU64 = AtomicU64::new(0);
94
/// Per-layer MoE state: router linear (small) + per-expert MLP stack.
///
/// One instance per transformer layer, built by `Qwen3MoeModel::new`,
/// which rejects routers whose shape is not `[hidden_size] → [num_experts]`.
pub struct Qwen3MoeLayerState<B: Backend> {
    /// Router projection `[hidden] → [num_experts]` — tiny, never sparse,
    /// always runs the full GEMV.
    pub router: Box<dyn ferrum_quantization::Linear<B>>,
    /// Per-expert weight stack, loaded from the GGUF rank-3 expert tensors
    /// by `ExpertStack::load_from_gguf`. Each entry's `gate_up` is the fused
    /// `[gate; up]` projection; `down` is the post-SwiGLU output proj.
    pub experts: ExpertStack<B>,
}
104
/// Reusable scratch buffers for the MoE forward path. All sized at
/// allocation time and reused across layers / forward calls.
///
/// Sizes below are element counts. `max_tokens` is the token capacity the
/// token-dependent buffers were allocated for.
pub struct Qwen3MoeScratch<B: Backend> {
    /// See [`crate::models::llama_family::LlamaFamilyScratch`] for the
    /// attention scratch — we re-use those names verbatim.
    /// `residual` is `[max_tokens, hidden]` (held in an `Option`).
    pub residual: Option<B::Buffer>,
    /// Input-RMSNorm output `[max_tokens, hidden]`; read by the fused QKV
    /// projection. May already be filled by the previous layer when
    /// cross-layer norm fusion is active.
    pub norm_out: B::Buffer,
    /// Fused QKV projection output `[max_tokens, q_dim + 2 * kv_dim]`.
    pub qkv_out: B::Buffer,
    /// Split Q `[max_tokens, q_dim]`.
    pub q_buf: B::Buffer,
    /// Split K `[max_tokens, kv_dim]`.
    pub k_buf: B::Buffer,
    /// Split V `[max_tokens, kv_dim]`.
    pub v_buf: B::Buffer,
    /// Head-major Q `[num_heads * max_tokens * head_dim]`.
    pub q_head_major: B::Buffer,
    /// Head-major K `[num_kv_heads * max_tokens * head_dim]`.
    pub k_head_major: B::Buffer,
    /// Head-major V `[num_kv_heads * max_tokens * head_dim]`.
    pub v_head_major: B::Buffer,
    /// Head-major attention output `[num_heads * max_tokens * head_dim]`.
    pub attn_head_major_out: B::Buffer,
    /// Token-major attention output `[max_tokens, q_dim]`.
    pub attn_flat: B::Buffer,
    /// O-projection output `[max_tokens, hidden]`.
    pub o_proj_out: B::Buffer,

    // ── MoE-specific scratch ─────────────────────────────────────────
    /// Router logits for the whole batch: `[max_tokens, num_experts]`.
    pub router_logits: B::Buffer,
    /// Per-(token, expert) gate||up projection output — `[2 * expert_inter]`.
    pub gate_up_buf: B::Buffer,
    /// SiLU(gate) * up scratch — `[expert_inter]`.
    pub silu_buf: B::Buffer,
    /// Per-(token, expert) down-projection output — `[hidden]`.
    pub down_buf: B::Buffer,
    /// Per-token input row scratch — `[hidden]`. Holds the post-RMSNorm
    /// activation slice that the per-(expert) gate_up gemv reads, kept
    /// stable across the entire top_k loop for one token.
    pub x_single: B::Buffer,
    /// Per-token output accumulator — `[hidden]`. Holds the running
    /// `Σ_k weight_k · expert_k(x[b])` sum that grows across the top_k
    /// loop and is flushed to `moe_out[b]` once per token.
    pub acc_buf: B::Buffer,
    /// MoE output `[max_tokens, hidden]`. Zeroed each forward.
    pub moe_out: B::Buffer,
    /// Pre-allocated `[hidden]` zero scratch — `acc_buf` is reset to
    /// this each token without going through `B::from_slice` on the
    /// hot path.
    pub zero_hidden: B::Buffer,

    // ── MoE batched-fast-path scratch (Metal `gemv_q*kw_moe_id_f32` /
    // `gemm_q*kw_moe_id_f32`) ─────────────────────────────────────
    //
    // Sized for `max_tokens * top_k * X` so the same buffers cover both
    // decode (m=1, uses the first `top_k * X` slice) and prefill
    // (m>1, uses the full `max_tokens * top_k * X`). Decode-only
    // workloads pay no extra memory because `max_tokens` was 1 there.
    /// `[max_tokens * top_k * expert_inter]` — gate gemm output per pair.
    pub gate_out_stacked: B::Buffer,
    /// `[max_tokens * top_k * expert_inter]` — up gemm output per pair.
    pub up_out_stacked: B::Buffer,
    /// `[max_tokens * top_k * expert_inter]` — SiLU(gate)·up per pair.
    pub silu_stacked: B::Buffer,
    /// `[max_tokens * top_k * hidden]` — down gemm output per pair.
    pub down_out_stacked: B::Buffer,
    /// `[top_k]` i32 expert IDs for the current token (decode reuses;
    /// prefill writes per-pair indices into `ids_2d` instead).
    pub ids_buf: B::Buffer,
    /// `[top_k]` f32 router combine weights for the current decode
    /// token. Decode hot-path uses `write_f32_into` to update.
    pub weights_buf: B::Buffer,
    /// `[max_tokens * top_k]` i32 — flat selected-expert IDs from the
    /// GPU router for the prefill batch. Consumed by `compute_ids_tpe_gpu`
    /// to bucket pairs by expert into `tpe_buf` / `ids_2d`.
    pub selected_ids_buf: B::Buffer,
    /// `[3]` u32 indirect-dispatch args (`grid_x, grid_y, grid_z`) for
    /// the gate / up MoE GEMM. Written by `compute_ids_tpe_gpu` so the
    /// consumer GEMM grid covers exactly `max(tpe[e])` columns instead
    /// of the worst-case `tokens * top_k`.
    pub gate_up_args_buf: B::Buffer,
    /// Same shape as `gate_up_args_buf` but for the down MoE GEMM
    /// (different `grid_y` because down's `M = hidden_size` vs gate/up's
    /// `M = expert_intermediate_size`).
    pub down_args_buf: B::Buffer,
    /// `[num_experts * max_per_expert_max]` i32 — per-expert pair
    /// index lists for prefill 2-D mul_mm_id. `max_per_expert_max`
    /// is bounded by `max_tokens * top_k` (worst-case: one expert
    /// gets every pair). Sized at scratch alloc time.
    pub ids_2d: B::Buffer,
    /// `[num_experts]` i32 — `tpe[e]` = number of pairs assigned to
    /// expert `e`. Companion to `ids_2d`.
    pub tpe_buf: B::Buffer,
    /// `[max_tokens * top_k]` f32 — combine weights per pair, in
    /// natural `[batch, top_k]` layout for `weighted_sum_batched`.
    pub weights_2d: B::Buffer,

    // ── Final-token / lm_head outputs ────────────────────────────────
    /// `[hidden]` — hidden state of the token whose logits are produced.
    pub last_hidden: B::Buffer,
    /// `[hidden]` — `last_hidden` after the final norm.
    pub last_normed: B::Buffer,
    /// `[vocab]` — single-token logits.
    pub logits: B::Buffer,
    /// `[max_tokens * vocab]` — per-token logits for a whole batch.
    pub batch_logits: B::Buffer,

    // ── Per-item single-token buffers for decode_batch (Phase 4b) ────
    //
    // The batched-decode path runs M GEMMs at m=M (qkv_proj / o_proj /
    // router / MoE expert mul_mm_id) but attention stays a per-item loop
    // (each cache_id has its own contiguous K/V buffer — no way to fan
    // M items into a single attention dispatch without paged KV). These
    // 1-token-shaped scratches hold the per-item slice during the loop:
    // `copy_slice` extracts q/k/v from the batched buffers, qk_norm_rope
    // writes head-major into _single, kv_cache_append + flash_attention
    // run on it, then copy_slice writes back into attn_flat[i*q_dim].
    //
    // None until `enable_batched_decode_scratch` is called. `alloc`
    // always starts them as None. Sizes once allocated: q/attn variants
    // `[q_dim]`, k/v variants `[kv_dim]`.
    pub q_single: Option<B::Buffer>,
    pub k_single: Option<B::Buffer>,
    pub v_single: Option<B::Buffer>,
    pub q_head_major_single: Option<B::Buffer>,
    pub k_head_major_single: Option<B::Buffer>,
    pub v_head_major_single: Option<B::Buffer>,
    pub attn_head_major_single: Option<B::Buffer>,

    // ── Paged batched dispatch scratch ──────────────────────────────────
    //
    // Mirrors the same fields on `LlamaFamilyScratch`. `Some` only after
    // `enable_paged_batch` has run, which `Qwen3MoeModel::ensure_kv` does
    // when paged KV is enabled (`FERRUM_METAL_PAGED_KV`, defaulting on when
    // the backend reports `supports_paged_kv()`). Sized for
    // `FERRUM_PAGED_MAX_SEQS × q_dim` so the multi-seq decode path can fan
    // in M items' Q into a single batched buffer for one
    // `paged_decode_attention(num_seqs=M)` call instead of running M
    // sequential m=1 attentions.
    pub paged_batch_q: Option<B::Buffer>,
    pub paged_batch_o: Option<B::Buffer>,
    /// u32 `[max_seqs * paged_max_blocks_per_seq]` once allocated.
    pub paged_batch_block_tables: Option<B::Buffer>,
    /// u32 `[max_seqs]` once allocated.
    pub paged_batch_context_lens: Option<B::Buffer>,
    /// Block-table width used for `paged_batch_block_tables`; 0 until
    /// `enable_paged_batch` runs.
    pub paged_max_blocks_per_seq: usize,

    /// Token capacity the token-dependent buffers were allocated for.
    pub max_tokens: usize,
}
236
237impl<B: Backend> Qwen3MoeScratch<B> {
238 fn alloc(cfg: &Qwen3MoeConfig, max_tokens: usize) -> Self {
239 let h = cfg.base.hidden_size;
240 let q_dim = cfg.base.num_heads * cfg.base.head_dim;
241 let kv_dim = cfg.base.num_kv_heads * cfg.base.head_dim;
242 let qkv_dim = q_dim + 2 * kv_dim;
243 let t = max_tokens;
244 let inter = cfg.expert_intermediate_size;
245 let n_exp = cfg.num_experts;
246 let vocab = cfg.base.vocab_size;
247 Self {
248 residual: Some(B::alloc(t * h)),
249 norm_out: B::alloc(t * h),
250 qkv_out: B::alloc(t * qkv_dim),
251 q_buf: B::alloc(t * q_dim),
252 k_buf: B::alloc(t * kv_dim),
253 v_buf: B::alloc(t * kv_dim),
254 q_head_major: B::alloc(cfg.base.num_heads * t * cfg.base.head_dim),
255 k_head_major: B::alloc(cfg.base.num_kv_heads * t * cfg.base.head_dim),
256 v_head_major: B::alloc(cfg.base.num_kv_heads * t * cfg.base.head_dim),
257 attn_head_major_out: B::alloc(cfg.base.num_heads * t * cfg.base.head_dim),
258 attn_flat: B::alloc(t * q_dim),
259 o_proj_out: B::alloc(t * h),
260 router_logits: B::alloc(t * n_exp),
261 gate_up_buf: B::alloc(2 * inter),
262 silu_buf: B::alloc(inter),
263 down_buf: B::alloc(h),
264 x_single: B::alloc(h),
265 acc_buf: B::alloc(h),
266 moe_out: B::alloc(t * h),
267 zero_hidden: B::from_slice(&vec![0.0f32; h]),
268 gate_out_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
269 up_out_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
270 silu_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
271 down_out_stacked: B::alloc(t * cfg.num_experts_per_tok * h),
272 ids_buf: B::from_slice_i32(&vec![0i32; cfg.num_experts_per_tok]),
273 weights_buf: B::from_slice(&vec![0.0f32; cfg.num_experts_per_tok]),
274 selected_ids_buf: B::from_slice_i32(&vec![0i32; t * cfg.num_experts_per_tok]),
275 // 3 u32s per indirect args buffer; allocated as 3 i32s so we
276 // can reuse `from_slice_i32`. The kernel writes them as
277 // `device uint *` and the bit pattern is consumed by
278 // `dispatch_thread_groups_indirect`.
279 gate_up_args_buf: B::from_slice_i32(&[0i32, 0, 0]),
280 down_args_buf: B::from_slice_i32(&[0i32, 0, 0]),
281 ids_2d: B::from_slice_i32(&vec![0i32; n_exp * t * cfg.num_experts_per_tok]),
282 tpe_buf: B::from_slice_i32(&vec![0i32; n_exp]),
283 weights_2d: B::from_slice(&vec![0.0f32; t * cfg.num_experts_per_tok]),
284 last_hidden: B::alloc(h),
285 last_normed: B::alloc(h),
286 logits: B::alloc(vocab),
287 batch_logits: B::alloc(t * vocab),
288 // Lazily-allocated; `enable_batched_decode_scratch` populates
289 // these the first time decode_batch is called with M > 1.
290 q_single: None,
291 k_single: None,
292 v_single: None,
293 q_head_major_single: None,
294 k_head_major_single: None,
295 v_head_major_single: None,
296 attn_head_major_single: None,
297 // Lazily-allocated; `enable_paged_batch` populates these when
298 // FERRUM_METAL_PAGED_KV=1 + we know the pool dimensions.
299 paged_batch_q: None,
300 paged_batch_o: None,
301 paged_batch_block_tables: None,
302 paged_batch_context_lens: None,
303 paged_max_blocks_per_seq: 0,
304 max_tokens: t,
305 }
306 }
307
308 /// Allocate scratch for paged batched dispatch. Mirrors
309 /// `LlamaFamilyScratch::enable_paged_batch`. Idempotent.
310 fn enable_paged_batch(
311 &mut self,
312 cfg: &Qwen3MoeConfig,
313 max_seqs: usize,
314 max_blocks_per_seq: usize,
315 ) {
316 if self.paged_batch_q.is_some() {
317 return;
318 }
319 let q_dim = cfg.base.num_heads * cfg.base.head_dim;
320 self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
321 self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
322 self.paged_batch_block_tables = Some(B::alloc_u32(max_seqs * max_blocks_per_seq));
323 self.paged_batch_context_lens = Some(B::alloc_u32(max_seqs));
324 self.paged_max_blocks_per_seq = max_blocks_per_seq;
325 }
326
327 /// Allocate the per-item single-token scratch buffers used by
328 /// `forward_layer_batched_decode`. Idempotent.
329 fn enable_batched_decode_scratch(&mut self, cfg: &Qwen3MoeConfig) {
330 if self.q_single.is_some() {
331 return;
332 }
333 let q_dim = cfg.base.num_heads * cfg.base.head_dim;
334 let kv_dim = cfg.base.num_kv_heads * cfg.base.head_dim;
335 self.q_single = Some(B::alloc(q_dim));
336 self.k_single = Some(B::alloc(kv_dim));
337 self.v_single = Some(B::alloc(kv_dim));
338 self.q_head_major_single = Some(B::alloc(q_dim));
339 self.k_head_major_single = Some(B::alloc(kv_dim));
340 self.v_head_major_single = Some(B::alloc(kv_dim));
341 self.attn_head_major_single = Some(B::alloc(q_dim));
342 }
343}
344
/// Qwen3-MoE decoder model.
///
/// Holds the same per-layer attention weights as [`LlamaFamilyModel`]
/// plus a [`Qwen3MoeLayerState`] per layer for the MoE FFN. Routing,
/// expert dispatch, and weighted combine all happen inside
/// [`moe_forward`]; this struct only owns the storage and orchestrates
/// the per-layer call sequence.
pub struct Qwen3MoeModel<B: Backend> {
    /// Model hyper-parameters (dense `base` config plus the MoE fields).
    pub cfg: Qwen3MoeConfig,
    /// Runtime view derived from `cfg.base.to_runtime()` at construction.
    pub runtime_cfg: LlmRuntimeConfig,

    /// Token embedding table (`model.embed_tokens.weight`).
    pub embed: B::Buffer,
    /// Per-layer attention weights (re-uses dense `LlamaFamilyLayer`).
    /// The dense `gate_up_proj` / `down_proj` slots hold stub Linears;
    /// the MoE FFN in `moe_layers` replaces them.
    pub attn_layers: Vec<LlamaFamilyLayer<B>>,
    /// Per-layer MoE state (router + expert stack).
    pub moe_layers: Vec<Qwen3MoeLayerState<B>>,
    /// Final RMSNorm weight (`model.norm.weight`).
    pub final_norm_w: B::Buffer,
    /// Output projection: `lm_head` when the checkpoint has it, otherwise
    /// the tied `model.embed_tokens` weights.
    pub lm_head: Box<dyn ferrum_quantization::Linear<B>>,

    /// RoPE tables (`cos` / `sin`) shared by every layer.
    pub rope: RopeCache<B>,
    /// Shared forward scratch; grown by `ensure_scratch` when a forward
    /// needs more tokens than it was sized for.
    pub scratch: Qwen3MoeScratch<B>,

    /// Live KV caches keyed by cache id; one `KvCache` per layer.
    pub kv_caches: HashMap<String, Vec<KvCache<B>>>,
    /// Recycled per-layer cache sets; `ensure_kv` pops from here before
    /// allocating a new set.
    kv_free_pool: Vec<Vec<KvCache<B>>>,

    // ── Paged-KV multi-seq state ────────────────────────────────────────
    //
    // Mirrors `LlamaFamilyModel`. Populated lazily by `ensure_kv` when
    // paged KV is enabled (`FERRUM_METAL_PAGED_KV`; defaults on when the
    // backend reports `supports_paged_kv()`). In that mode `kv_caches`
    // entries become metadata-only views (block_table + context_lens) into
    // the shared `paged_pools`.
    /// One `(k_pool, v_pool)` buffer pair per layer.
    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
    /// Allocator that hands out pool blocks to cache ids.
    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
}
378
379impl<B: Backend> Qwen3MoeModel<B> {
380 /// Build a Qwen3-MoE model from a generic `WeightLoader<B>` plus a
381 /// GGUF reader for the experts (which `WeightLoader` doesn't model
382 /// directly — its API is rank-2 only).
383 ///
384 /// `loader` provides: token embedding, attention projections, layer
385 /// norms, lm_head — all the rank-2 weights.
386 /// `gguf` provides: the rank-3 expert tensors, sliced per-expert
387 /// inside [`ExpertStack::load_from_gguf`].
388 pub fn new(
389 cfg: Qwen3MoeConfig,
390 loader: &dyn WeightLoader<B>,
391 gguf: &ferrum_quantization::gguf::GgufFile,
392 ) -> Result<Self> {
393 {
394 let mut ctx = B::new_context();
395 B::reset_graph(&mut ctx);
396 }
397 let rope = build_rope_cache::<B>(&cfg.base);
398 let scratch = Qwen3MoeScratch::alloc(&cfg, 1);
399
400 let embed = loader.load_tensor("model.embed_tokens.weight")?;
401
402 let mut attn_layers = Vec::with_capacity(cfg.base.num_layers);
403 let mut moe_layers = Vec::with_capacity(cfg.base.num_layers);
404 for li in 0..cfg.base.num_layers {
405 let prefix = format!("model.layers.{li}");
406 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
407 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
408 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
409 let post_ln_w =
410 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
411
412 // Dense gate_up_proj / down_proj are absent in MoE GGUFs —
413 // we synthesise stub Linears so the LlamaFamilyLayer struct
414 // type-checks. They're never invoked because forward_layer
415 // calls the MoE path. Cheap: tiny zero-sized DenseLinears.
416 let gate_up_proj: Box<dyn ferrum_quantization::Linear<B>> =
417 stub_linear::<B>(2 * cfg.expert_intermediate_size, cfg.base.hidden_size);
418 let down_proj: Box<dyn ferrum_quantization::Linear<B>> =
419 stub_linear::<B>(cfg.base.hidden_size, cfg.expert_intermediate_size);
420
421 let (q_norm_w, k_norm_w) = if cfg.base.has_qk_norm {
422 let q = loader
423 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
424 .ok();
425 let k = loader
426 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
427 .ok();
428 (q, k)
429 } else {
430 (None, None)
431 };
432
433 attn_layers.push(LlamaFamilyLayer {
434 input_ln_w,
435 qkv_proj,
436 q_norm_w,
437 k_norm_w,
438 o_proj,
439 post_ln_w,
440 gate_up_proj,
441 down_proj,
442 });
443
444 // Router lives at `model.layers.{li}.mlp.router.weight` in
445 // ferrum-name space (see ferrum_to_gguf mapping). It's a
446 // plain rank-2 linear so the standard loader path covers
447 // it without going through the MoE-specific GGUF helper.
448 let router = loader.load_linear(&format!("{prefix}.mlp.router"))?;
449 if router.in_features() != cfg.base.hidden_size {
450 return Err(FerrumError::model(format!(
451 "router layer {li}: in_features {} != hidden {}",
452 router.in_features(),
453 cfg.base.hidden_size
454 )));
455 }
456 if router.out_features() != cfg.num_experts {
457 return Err(FerrumError::model(format!(
458 "router layer {li}: out_features {} != num_experts {}",
459 router.out_features(),
460 cfg.num_experts
461 )));
462 }
463
464 let experts = ExpertStack::<B>::load_from_gguf(
465 gguf,
466 li,
467 cfg.num_experts,
468 cfg.base.hidden_size,
469 cfg.expert_intermediate_size,
470 )?;
471
472 moe_layers.push(Qwen3MoeLayerState { router, experts });
473 }
474
475 let final_norm_w = loader.load_tensor("model.norm.weight")?;
476 let lm_head = if loader.has_tensor("lm_head.weight") {
477 loader.load_linear("lm_head")?
478 } else {
479 // Tied embeddings — same as dense path.
480 tracing::info!(
481 "Qwen3MoeModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
482 );
483 loader.load_linear("model.embed_tokens")?
484 };
485
486 let runtime_cfg = cfg.base.to_runtime();
487 Ok(Self {
488 cfg,
489 runtime_cfg,
490 embed,
491 attn_layers,
492 moe_layers,
493 final_norm_w,
494 lm_head,
495 rope,
496 scratch,
497 kv_caches: HashMap::new(),
498 kv_free_pool: Vec::new(),
499 paged_pools: None,
500 paged_block_alloc: None,
501 })
502 }
503
504 pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
505 if self.scratch.max_tokens < tokens {
506 {
507 let mut ctx = B::new_context();
508 B::reset_graph(&mut ctx);
509 }
510 self.scratch = Qwen3MoeScratch::alloc(&self.cfg, tokens);
511 }
512 }
513
514 pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
515 if self.kv_caches.contains_key(cache_id) {
516 return;
517 }
518 let nkv = self.cfg.base.num_kv_heads;
519 let hd = self.cfg.base.head_dim;
520 // 512 in 0.7.2 — same value the published bench used to hit 79
521 // tok/s at c=16 on this exact MoE model. See
522 // `LlamaFamilyModel::ensure_kv` for the full rationale.
523 let model_max = self.cfg.base.max_seq_len;
524 const DEFAULT_KV_CAPACITY: usize = 512;
525 let max = std::env::var("FERRUM_KV_CAPACITY")
526 .ok()
527 .and_then(|s| s.parse::<usize>().ok())
528 .map(|cap| cap.min(model_max))
529 .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY));
530
531 // Paged-KV mode: `FERRUM_METAL_PAGED_KV=1` switches caches into
532 // block-table-indirect layout. Mirrors LlamaFamilyModel's path so
533 // the existing `paged_decode_attention` Metal kernel can fire
534 // once at num_seqs=m for batched decode (replacing the per-item
535 // attention loop that currently dominates `attn_peritem` in the
536 // c=16 profile).
537 // Default ON when the backend supports paged-KV (Metal). Users
538 // can force off with `FERRUM_METAL_PAGED_KV=0`. The flag was
539 // opt-in pre-0.7.2; flipping the default so default `ferrum
540 // serve` matches the bench-quality numbers without requiring
541 // env-var knowledge.
542 let paged = std::env::var("FERRUM_METAL_PAGED_KV")
543 .map(|v| v != "0")
544 .unwrap_or_else(|_| B::supports_paged_kv());
545 const PAGED_BLOCK_SIZE: usize = 16;
546
547 // Default 32: covers c=16 burst with 2× headroom for the
548 // fresh-cache-id-per-request pattern that bench/server harnesses
549 // use. Pool memory unchanged from pre-0.7.2 default because
550 // DEFAULT_KV_CAPACITY dropped 4096 → 2048 in lockstep.
551 let max_seqs = std::env::var("FERRUM_PAGED_MAX_SEQS")
552 .ok()
553 .and_then(|s| s.parse::<usize>().ok())
554 .unwrap_or(32);
555 let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
556 let total_pool_blocks = max_seqs * max_blocks_per_seq;
557
558 // Lazy-allocate the shared paged pools on the first paged
559 // ensure_kv call.
560 if paged && self.paged_pools.is_none() {
561 let mut pools = Vec::with_capacity(self.cfg.base.num_layers);
562 for _ in 0..self.cfg.base.num_layers {
563 let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
564 pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
565 }
566 self.paged_pools = Some(pools);
567 self.paged_block_alloc = Some(std::sync::Mutex::new(
568 crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
569 ));
570 }
571 if paged {
572 self.scratch
573 .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
574 }
575
576 let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
577 (0..self.cfg.base.num_layers)
578 .map(|_| {
579 if paged {
580 // Paged mode: cache holds metadata only. K/V are
581 // 1-element placeholders. Real data lives in
582 // `self.paged_pools[li].{k,v}`.
583 let mut block_table = B::alloc_u32(max_blocks_per_seq);
584 let _ = &mut block_table; // suppress unused-mut on backends that no-op write_u32
585 let mut context_lens = B::alloc_u32(1);
586 let mut bt_ctx = B::new_context();
587 B::write_u32(&mut bt_ctx, &mut context_lens, &[0u32]);
588 B::sync(&mut bt_ctx);
589 KvCache {
590 k: B::alloc(1),
591 v: B::alloc(1),
592 len: 0,
593 capacity: max_blocks_per_seq * PAGED_BLOCK_SIZE,
594 num_kv_heads: nkv,
595 head_dim: hd,
596 block_size: PAGED_BLOCK_SIZE,
597 block_table: Some(block_table),
598 context_lens: Some(context_lens),
599 paged_block_indices: Vec::new(),
600 }
601 } else {
602 KvCache {
603 k: B::alloc(nkv * max * hd),
604 v: B::alloc(nkv * max * hd),
605 len: 0,
606 capacity: max,
607 num_kv_heads: nkv,
608 head_dim: hd,
609 block_size: 0,
610 block_table: None,
611 context_lens: None,
612 paged_block_indices: Vec::new(),
613 }
614 }
615 })
616 .collect()
617 });
618
619 // Allocate physical blocks for THIS cache_id from the shared pool.
620 if paged {
621 let alloc_arc = self
622 .paged_block_alloc
623 .as_ref()
624 .expect("paged_block_alloc must be initialised when paged=true");
625 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
626 let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
627 Ok(idx) => idx,
628 Err(e) => {
629 drop(alloc);
630 self.kv_free_pool.push(caches);
631 eprintln!(
632 "[ferrum] paged KV pool exhausted on ensure_kv for \
633 cache_id={cache_id:?}: {e}. Increase \
634 FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
635 throttle concurrent requests.",
636 );
637 return;
638 }
639 };
640 let mut padded = block_indices.clone();
641 padded.resize(max_blocks_per_seq, 0);
642 let mut ctx_tmp = B::new_context();
643 for c in caches.iter_mut() {
644 if let Some(bt) = c.block_table.as_mut() {
645 B::write_u32(&mut ctx_tmp, bt, &padded);
646 }
647 c.paged_block_indices = block_indices.clone();
648 }
649 B::sync(&mut ctx_tmp);
650 }
651
652 for c in caches.iter_mut() {
653 c.len = 0;
654 if let Some(cl) = c.context_lens.as_mut() {
655 let mut ctx_tmp = B::new_context();
656 B::write_u32(&mut ctx_tmp, cl, &[0u32]);
657 B::sync(&mut ctx_tmp);
658 }
659 }
660 self.kv_caches.insert(cache_id.to_string(), caches);
661 }
662
663 /// Run one full transformer layer (attention + MoE FFN).
664 pub(crate) fn forward_layer(
665 &mut self,
666 ctx: &mut B::Context,
667 li: usize,
668 cache_id: &str,
669 residual: &mut B::Buffer,
670 pos_offset: usize,
671 tokens: usize,
672 // If `Some(idx)` and we land on the decode fast path, fold the
673 // next layer's leading rms_norm into this layer's MoE tail
674 // (cross-layer norm fusion). The next layer's caller must pass
675 // `prev_did_norm_fusion = true` so it skips its own rms_norm.
676 next_layer_idx: Option<usize>,
677 // If `true`, skip step 1's input rms_norm — the previous
678 // layer's tail already populated `scratch.norm_out`.
679 prev_did_norm_fusion: bool,
680 ) -> Result<bool> {
681 let cfg_base = &self.cfg.base;
682 let h = cfg_base.hidden_size;
683 let nh = cfg_base.num_heads;
684 let nkv = cfg_base.num_kv_heads;
685 let hd = cfg_base.head_dim;
686 let eps = cfg_base.rms_norm_eps;
687 let q_dim = nh * hd;
688 let kv_dim = nkv * hd;
689 let attn_layer = &self.attn_layers[li];
690 let moe_layer = &self.moe_layers[li];
691
692 let attn_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
693 B::sync(ctx);
694 Some(std::time::Instant::now())
695 } else {
696 None
697 };
698
699 // 1. Input RMSNorm — skipped when the previous layer's MoE tail
700 // fused this norm via `weighted_sum_residual_norm_stacked`.
701 if !prev_did_norm_fusion {
702 B::rms_norm(
703 ctx,
704 residual,
705 &attn_layer.input_ln_w,
706 eps,
707 &mut self.scratch.norm_out,
708 tokens,
709 h,
710 );
711 }
712
713 // 2. Fused QKV
714 attn_layer.qkv_proj.forward(
715 ctx,
716 &self.scratch.norm_out,
717 &mut self.scratch.qkv_out,
718 tokens,
719 );
720
721 // 3-4. Fused split-QKV + QK-norm + RoPE + head-major transpose.
722 //
723 // One Metal dispatch replaces (split_qkv → 3× qk_norm_rope), the
724 // four-launch chain that used to dominate the attention prelude.
725 // Reads qkv_out once, writes head-major Q/K (norm+RoPE) and V
726 // (transpose only) directly into attention scratch. Saves 3
727 // dispatches per layer (×48 = 144 dispatches per decode token).
728 //
729 // CPU and other backends without the fused kernel return
730 // Unsupported and we fall through to the original four-launch
731 // path. q_buf / k_buf / v_buf stay in scratch because that path
732 // and the per-expert MoE fallback still want them.
733 let qk_mode: i32 = if cfg_base.has_qk_norm { 1 } else { 2 };
734 let dummy = &attn_layer.input_ln_w;
735 let q_norm_w = attn_layer.q_norm_w.as_ref().unwrap_or(dummy);
736 let k_norm_w = attn_layer.k_norm_w.as_ref().unwrap_or(dummy);
737
738 // 5. Grab the per-layer KV cache up front — the deepest fused
739 // variant writes K/V straight into it, avoiding a trailing
740 // `kv_cache_append_head_major` dispatch.
741 //
742 // Paged mode: extract a raw pointer to the layer's pool buffers
743 // BEFORE the &mut cache borrow, so we can pass &mut to the
744 // paged kernel below without holding two simultaneous mutable
745 // borrows on `self`. Safety: `paged_pools` is allocated once at
746 // first ensure_kv call and never resized; the only concurrent
747 // mutation is the pool's own kernel writes (sequenced via
748 // command buffers), so the raw pointer remains valid for the
749 // duration of this layer call.
750 let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
751 if let Some(pools) = self.paged_pools.as_mut() {
752 let pool = &mut pools[li];
753 Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
754 } else {
755 None
756 };
757 let caches = self
758 .kv_caches
759 .get_mut(cache_id)
760 .expect("ensure_kv must be called before forward_layer");
761 let cache = &mut caches[li];
762 let cache_len_before = cache.len;
763 let cache_capacity = cache.capacity;
764
765 // Defense in depth: refuse to write past the KV buffer. Silent
766 // overflow has visible failure modes (garbage output, stale token
767 // attention, slowdowns from reading uninitialised memory). The
768 // graceful path is the caller pre-checking via `kv_capacity()` and
769 // either compacting or refusing the request; this panic only
770 // fires when that contract is broken.
771 if cache_len_before + tokens > cache_capacity {
772 panic!(
773 "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
774 cache_len_before + tokens
775 );
776 }
777
778 // Try the deepest fusion: fused split-QKV-norm-rope that writes
779 // K/V directly into the cache slot. Paged mode writes into the
780 // shared pool via block_table indirection; contiguous mode
781 // writes into the per-cache_id k/v buffers directly.
782 let used_qkv_into_cache = if cache.block_size > 0 {
783 // Paged path.
784 let bt = cache
785 .block_table
786 .as_ref()
787 .expect("paged cache missing block_table");
788 let num_blocks_per_seq = cache.capacity / cache.block_size;
789 let (pool_k_ptr, pool_v_ptr) =
790 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
791 // SAFETY: pools allocated-once, see paged_pool_ptr setup above.
792 let pool_k = unsafe { &mut *pool_k_ptr };
793 let pool_v = unsafe { &mut *pool_v_ptr };
794 B::split_qkv_norm_rope_into_paged_cache(
795 ctx,
796 &self.scratch.qkv_out,
797 0,
798 q_norm_w,
799 k_norm_w,
800 &self.rope.cos,
801 &self.rope.sin,
802 &mut self.scratch.q_head_major,
803 0,
804 pool_k,
805 pool_v,
806 bt,
807 tokens,
808 nh,
809 nkv,
810 hd,
811 pos_offset,
812 eps,
813 qk_mode,
814 cache_len_before,
815 cache.block_size,
816 num_blocks_per_seq,
817 )
818 .is_ok()
819 } else {
820 B::split_qkv_norm_rope_into_cache(
821 ctx,
822 &self.scratch.qkv_out,
823 q_norm_w,
824 k_norm_w,
825 &self.rope.cos,
826 &self.rope.sin,
827 &mut self.scratch.q_head_major,
828 &mut cache.k,
829 &mut cache.v,
830 tokens,
831 nh,
832 nkv,
833 hd,
834 pos_offset,
835 eps,
836 qk_mode,
837 cache_len_before,
838 cache_capacity,
839 )
840 .is_ok()
841 };
842 if !used_qkv_into_cache {
843 // Fallback 1: fused split-QKV-norm-rope to head-major scratch
844 // (Metal pre-decode-fusion path), then explicit cache append.
845 let used_fused_qkv = B::split_qkv_norm_rope(
846 ctx,
847 &self.scratch.qkv_out,
848 q_norm_w,
849 k_norm_w,
850 &self.rope.cos,
851 &self.rope.sin,
852 &mut self.scratch.q_head_major,
853 &mut self.scratch.k_head_major,
854 &mut self.scratch.v_head_major,
855 tokens,
856 nh,
857 nkv,
858 hd,
859 pos_offset,
860 eps,
861 qk_mode,
862 )
863 .is_ok();
864 if !used_fused_qkv {
865 // Fallback 2: original four-launch chain.
866 B::split_qkv(
867 ctx,
868 &self.scratch.qkv_out,
869 &mut self.scratch.q_buf,
870 &mut self.scratch.k_buf,
871 &mut self.scratch.v_buf,
872 tokens,
873 q_dim,
874 kv_dim,
875 );
876 B::qk_norm_rope(
877 ctx,
878 &self.scratch.q_buf,
879 q_norm_w,
880 &self.rope.cos,
881 &self.rope.sin,
882 &mut self.scratch.q_head_major,
883 tokens,
884 nh,
885 hd,
886 pos_offset,
887 eps,
888 qk_mode,
889 );
890 B::qk_norm_rope(
891 ctx,
892 &self.scratch.k_buf,
893 k_norm_w,
894 &self.rope.cos,
895 &self.rope.sin,
896 &mut self.scratch.k_head_major,
897 tokens,
898 nkv,
899 hd,
900 pos_offset,
901 eps,
902 qk_mode,
903 );
904 B::qk_norm_rope(
905 ctx,
906 &self.scratch.v_buf,
907 dummy,
908 &self.rope.cos,
909 &self.rope.sin,
910 &mut self.scratch.v_head_major,
911 tokens,
912 nkv,
913 hd,
914 pos_offset,
915 eps,
916 0,
917 );
918 }
919 B::kv_cache_append_head_major(
920 ctx,
921 &mut cache.k,
922 &mut cache.v,
923 cache.len,
924 cache.capacity,
925 &self.scratch.k_head_major,
926 &self.scratch.v_head_major,
927 tokens,
928 nkv,
929 hd,
930 );
931 }
932 cache.len += tokens;
933 let kv_len = cache.len;
934 let kv_stride = cache.capacity;
935
936 if cache.block_size > 0 {
937 // Paged decode: read from the shared pool via block_table.
938 let bt = cache
939 .block_table
940 .as_ref()
941 .expect("paged cache missing block_table");
942 let cl_buf = cache
943 .context_lens
944 .as_mut()
945 .expect("paged cache missing context_lens");
946 let num_blocks_per_seq = cache.capacity / cache.block_size;
947 let (pool_k_ptr, pool_v_ptr) =
948 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
949 // SAFETY: see paged_pool_ptr setup above.
950 let pool_k = unsafe { &*pool_k_ptr };
951 let pool_v = unsafe { &*pool_v_ptr };
952 let final_kv_len = cache.len as u32;
953 B::write_u32(ctx, cl_buf, &[final_kv_len]);
954 B::paged_decode_attention(
955 ctx,
956 &self.scratch.q_head_major,
957 pool_k,
958 pool_v,
959 &mut self.scratch.attn_head_major_out,
960 bt,
961 cl_buf,
962 1, // num_seqs (single-seq m=1 path)
963 nh,
964 nkv,
965 hd,
966 cache.block_size,
967 num_blocks_per_seq,
968 tokens,
969 )
970 .expect("paged_decode_attention");
971 let _ = kv_stride; // consumed by contig path only
972 } else {
973 let attn_cfg = ferrum_kernels::backend::AttnConfig {
974 num_heads: nh,
975 num_kv_heads: nkv,
976 head_dim: hd,
977 causal: true,
978 scale: 1.0 / (hd as f32).sqrt(),
979 kv_seq_stride: kv_stride,
980 sliding_window: cfg_base.sliding_window,
981 };
982 B::flash_attention(
983 ctx,
984 &self.scratch.q_head_major,
985 &cache.k,
986 &cache.v,
987 &mut self.scratch.attn_head_major_out,
988 1,
989 tokens,
990 kv_len,
991 pos_offset,
992 &attn_cfg,
993 );
994 }
995
996 if let Some(t0) = attn_t0 {
997 B::sync(ctx);
998 ATTN_TIME_US.fetch_add(
999 t0.elapsed().as_micros() as u64,
1000 std::sync::atomic::Ordering::Relaxed,
1001 );
1002 ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1003 }
1004
1005 // 7. transpose head-major → token-major.
1006 //
1007 // For tokens=1 the two layouts are byte-identical: both
1008 // collapse to the flat [heads * head_dim] vector at offset
1009 // `head*hd + d`. Skip the dispatch and point o_proj at
1010 // attn_head_major_out directly. Saves 1 dispatch per layer
1011 // (×48 = 48 dispatches per decode token) on Qwen3-30B-A3B.
1012 let attn_token_major = if tokens == 1 {
1013 &self.scratch.attn_head_major_out
1014 } else {
1015 B::transpose_head_to_token(
1016 ctx,
1017 &self.scratch.attn_head_major_out,
1018 &mut self.scratch.attn_flat,
1019 tokens,
1020 nh,
1021 hd,
1022 );
1023 &self.scratch.attn_flat
1024 };
1025
1026 // 8. O-proj.
1027 attn_layer
1028 .o_proj
1029 .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1030
1031 // 9. fused residual-add + post-attention RMSNorm.
1032 B::fused_add_rms_norm(
1033 ctx,
1034 residual,
1035 &self.scratch.o_proj_out,
1036 &attn_layer.post_ln_w,
1037 eps,
1038 &mut self.scratch.norm_out,
1039 tokens,
1040 h,
1041 );
1042
1043 // ── MoE FFN block ────────────────────────────────────────────
1044 let moe_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1045 B::sync(ctx);
1046 Some(std::time::Instant::now())
1047 } else {
1048 None
1049 };
1050
1051 // 10. Router gemv: norm_out [tokens, hidden] → router_logits [tokens, num_experts]
1052 moe_layer.router.forward(
1053 ctx,
1054 &self.scratch.norm_out,
1055 &mut self.scratch.router_logits,
1056 tokens,
1057 );
1058
1059 // 11. Per-(token, expert) MLP dispatch + weighted combine.
1060 //
1061 // Two paths:
1062 // - **Batched fast path** (decode m=1, all stacked variants
1063 // present): single `gemv_quant_moe_id` dispatch covers all
1064 // 8 selected expert × 1 token gate gemvs in parallel; same
1065 // for up and down. Cuts per-layer expert dispatches from
1066 // ~32 (8 × 4 ops/pair) to 4 (gate + up + silu + down + 1 acc).
1067 // Routes Qwen3-30B-A3B decode close to llama.cpp's
1068 // `kernel_mul_mm_id`.
1069 // - **Per-(token, expert) fallback** via `moe_forward` —
1070 // used for prefill (m > 1), or when the backend doesn't
1071 // populate stacked variants (CPU, synthetic-MoE tests).
1072 let stacked_path_available = moe_layer.experts.gate_stacked.is_some()
1073 && moe_layer.experts.up_stacked.is_some()
1074 && moe_layer.experts.down_stacked.is_some();
1075
1076 // Fast path for decode (tokens=1): the stacked decode impl
1077 // writes the weighted-sum result *directly* into `residual` via
1078 // `weighted_sum_residual_stacked`, skipping the moe_out scratch
1079 // and the trailing `add_inplace`. Saves 1 dispatch per layer.
1080 // Prefill (m>1) and the per-expert fallback still go through
1081 // moe_out + add_inplace.
1082 let decode_fast_path = stacked_path_available && tokens == 1;
1083 // Cross-layer fusion: when on the decode fast path AND there is
1084 // a next layer, fold its leading rms_norm into this layer's
1085 // tail (`weighted_sum_residual_norm_stacked`). Returns whether
1086 // the fusion ran so the caller can signal the next layer to
1087 // skip its standalone rms_norm.
1088 let did_norm_fusion = decode_fast_path && next_layer_idx.is_some();
1089
1090 if stacked_path_available {
1091 if tokens > 1 {
1092 // Prefill: one batched 2-D mul_mm_id covers all
1093 // (token, expert) pairs in parallel.
1094 self.moe_forward_batched_prefill(ctx, li, tokens)?;
1095 } else {
1096 // Decode m=1: dedicated per-token path that fuses
1097 // residual-add into the final weighted-sum, and
1098 // optionally folds the next layer's rms_norm in too.
1099 self.moe_forward_stacked(ctx, li, tokens, residual, next_layer_idx)?;
1100 }
1101 } else {
1102 moe_forward::<B>(
1103 ctx,
1104 &self.scratch.norm_out,
1105 &self.scratch.router_logits,
1106 &mut self.scratch.moe_out,
1107 tokens,
1108 h,
1109 self.cfg.expert_intermediate_size,
1110 self.cfg.num_experts,
1111 self.cfg.num_experts_per_tok,
1112 self.cfg.norm_topk_prob,
1113 &moe_layer.experts,
1114 &mut self.scratch.x_single,
1115 &mut self.scratch.acc_buf,
1116 &mut self.scratch.gate_up_buf,
1117 &mut self.scratch.silu_buf,
1118 &mut self.scratch.down_buf,
1119 &self.scratch.zero_hidden,
1120 )?;
1121 }
1122
1123 // 12. residual += moe_out (skipped on decode fast path — already
1124 // accumulated by `weighted_sum_residual_stacked`).
1125 if !decode_fast_path {
1126 B::add_inplace(ctx, residual, &self.scratch.moe_out, tokens * h);
1127 }
1128
1129 if let Some(t0) = moe_t0 {
1130 B::sync(ctx);
1131 MOE_TIME_US.fetch_add(
1132 t0.elapsed().as_micros() as u64,
1133 std::sync::atomic::Ordering::Relaxed,
1134 );
1135 MOE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1136 }
1137
1138 Ok(did_norm_fusion)
1139 }
1140
1141 fn moe_forward_stacked(
1142 &mut self,
1143 ctx: &mut B::Context,
1144 li: usize,
1145 tokens: usize,
1146 residual: &mut B::Buffer,
1147 next_layer_idx: Option<usize>,
1148 ) -> Result<()> {
1149 let cfg = &self.cfg;
1150 // `next_norm_w` is the next layer's `attn_layer.input_ln_w`.
1151 // We can't borrow `self.attn_layers[idx]` and pass &mut
1152 // self.scratch to the impl simultaneously, so collect the raw
1153 // pointer here. Safety: forward_layer holds &mut self for the
1154 // call; the borrow scopes are fully sequential.
1155 let next_norm_w_ptr: Option<*const B::Buffer> =
1156 next_layer_idx.map(|idx| &self.attn_layers[idx].input_ln_w as *const _);
1157 // SAFETY: pointer dereference is valid because:
1158 // * The buffer lives in `self.attn_layers[idx]` which we
1159 // borrowed immutably to take the pointer. We do not mutate
1160 // `self.attn_layers` while `next_norm_w_ptr` is in use.
1161 // * `&mut self.scratch` and `&self.moe_layers[li]` are disjoint
1162 // fields from `self.attn_layers` so this is safe.
1163 let next_norm_w: Option<&B::Buffer> = next_norm_w_ptr.map(|p| unsafe { &*p });
1164 moe_forward_stacked_decode_impl::<B>(
1165 ctx,
1166 &self.moe_layers[li],
1167 &mut self.scratch,
1168 cfg.base.hidden_size,
1169 cfg.expert_intermediate_size,
1170 cfg.num_experts_per_tok,
1171 cfg.num_experts,
1172 cfg.norm_topk_prob,
1173 tokens,
1174 residual,
1175 next_norm_w,
1176 cfg.base.rms_norm_eps,
1177 )
1178 }
1179
1180 fn moe_forward_batched_prefill(
1181 &mut self,
1182 ctx: &mut B::Context,
1183 li: usize,
1184 tokens: usize,
1185 ) -> Result<()> {
1186 let cfg = &self.cfg;
1187 moe_forward_batched_prefill_impl::<B>(
1188 ctx,
1189 &self.moe_layers[li],
1190 &mut self.scratch,
1191 cfg.base.hidden_size,
1192 cfg.expert_intermediate_size,
1193 cfg.num_experts_per_tok,
1194 cfg.num_experts,
1195 cfg.norm_topk_prob,
1196 tokens,
1197 )
1198 }
1199
1200 /// Prefill: process `tokens` prompt tokens, return last-token logits.
1201 pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1202 let seq_len = tokens.len();
1203 assert!(seq_len > 0);
1204 self.ensure_scratch(seq_len);
1205 self.ensure_kv(cache_id);
1206
1207 let pos_offset = self
1208 .kv_caches
1209 .get(cache_id)
1210 .and_then(|layers| layers.first())
1211 .map(|c| c.len)
1212 .unwrap_or(0);
1213
1214 let h = self.cfg.base.hidden_size;
1215 let vocab = self.cfg.base.vocab_size;
1216 let mut ctx = B::new_context();
1217
1218 // FERRUM_DECODE_OP_PROFILE doubles as the prefill-profile gate
1219 // for Qwen3-MoE: when set, dump (attn-us, moe-us, total-us) at
1220 // the end of prefill so we can attribute the prefill bottleneck
1221 // between attention and MoE.
1222 let prefill_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1223 B::sync(&mut ctx);
1224 for c in [
1225 &ATTN_TIME_US,
1226 &ATTN_CALLS,
1227 &MOE_TIME_US,
1228 &MOE_CALLS,
1229 &MOE_PREFILL_HOST_TOPK_US,
1230 &MOE_PREFILL_HOST_TOPK_CALLS,
1231 &MOE_PREFILL_GATE_US,
1232 &MOE_PREFILL_GATE_CALLS,
1233 &MOE_PREFILL_UP_US,
1234 &MOE_PREFILL_UP_CALLS,
1235 &MOE_PREFILL_SILU_US,
1236 &MOE_PREFILL_SILU_CALLS,
1237 &MOE_PREFILL_DOWN_US,
1238 &MOE_PREFILL_DOWN_CALLS,
1239 &MOE_PREFILL_WSUM_US,
1240 &MOE_PREFILL_WSUM_CALLS,
1241 ] {
1242 c.store(0, std::sync::atomic::Ordering::Relaxed);
1243 }
1244 Some(std::time::Instant::now())
1245 } else {
1246 None
1247 };
1248
1249 let mut residual = self
1250 .scratch
1251 .residual
1252 .take()
1253 .expect("scratch residual missing (previous call didn't restore)");
1254 B::embedding_lookup(&mut ctx, &self.embed, tokens, &mut residual, h);
1255
1256 // For prefill (seq_len > 1) the cross-layer norm fusion does
1257 // not apply (it lives on the decode fast path). We still pass
1258 // `next_layer_idx = None` so forward_layer emits the regular
1259 // tail.
1260 let mut prev_did_norm_fusion = false;
1261 let num_layers = self.cfg.base.num_layers;
1262 for li in 0..num_layers {
1263 let next_layer_idx = if li + 1 < num_layers {
1264 Some(li + 1)
1265 } else {
1266 None
1267 };
1268 prev_did_norm_fusion = self
1269 .forward_layer(
1270 &mut ctx,
1271 li,
1272 cache_id,
1273 &mut residual,
1274 pos_offset,
1275 seq_len,
1276 next_layer_idx,
1277 prev_did_norm_fusion,
1278 )
1279 .expect("forward_layer");
1280 }
1281
1282 // Last-token slice → final RMSNorm → lm_head.
1283 B::copy_slice(
1284 &mut ctx,
1285 &residual,
1286 (seq_len - 1) * h,
1287 &mut self.scratch.last_hidden,
1288 0,
1289 h,
1290 );
1291 B::rms_norm(
1292 &mut ctx,
1293 &self.scratch.last_hidden,
1294 &self.final_norm_w,
1295 self.cfg.base.rms_norm_eps,
1296 &mut self.scratch.last_normed,
1297 1,
1298 h,
1299 );
1300 self.lm_head.forward(
1301 &mut ctx,
1302 &self.scratch.last_normed,
1303 &mut self.scratch.logits,
1304 1,
1305 );
1306
1307 B::sync(&mut ctx);
1308 if let Some(t0) = prefill_t0 {
1309 let total_us = t0.elapsed().as_micros() as u64;
1310 let attn_us = ATTN_TIME_US.load(std::sync::atomic::Ordering::Relaxed);
1311 let attn_n = ATTN_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1312 let moe_us = MOE_TIME_US.load(std::sync::atomic::Ordering::Relaxed);
1313 let moe_n = MOE_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1314 let other_us = total_us.saturating_sub(attn_us).saturating_sub(moe_us);
1315 eprintln!(
1316 "[prefill-profile] tokens={seq_len} total={} ms ({:.0} t/s)",
1317 total_us / 1000,
1318 seq_len as f64 * 1e6 / total_us as f64
1319 );
1320 let bucket = |label: &str, n: u64, us: u64| {
1321 if n > 0 {
1322 eprintln!(
1323 " {label:>6}: {:7} ms ({:5.1}%) over {n:4} calls",
1324 us / 1000,
1325 us as f64 * 100.0 / total_us as f64
1326 );
1327 }
1328 };
1329 bucket("attn", attn_n, attn_us);
1330 bucket("moe", moe_n, moe_us);
1331 bucket("other", 1, other_us);
1332 // MoE sub-stages — show as % of total prefill time so they
1333 // reconcile against the `moe` bucket above.
1334 let host_us = MOE_PREFILL_HOST_TOPK_US.load(std::sync::atomic::Ordering::Relaxed);
1335 let gate_us = MOE_PREFILL_GATE_US.load(std::sync::atomic::Ordering::Relaxed);
1336 let up_us = MOE_PREFILL_UP_US.load(std::sync::atomic::Ordering::Relaxed);
1337 let silu_us = MOE_PREFILL_SILU_US.load(std::sync::atomic::Ordering::Relaxed);
1338 let down_us = MOE_PREFILL_DOWN_US.load(std::sync::atomic::Ordering::Relaxed);
1339 let wsum_us = MOE_PREFILL_WSUM_US.load(std::sync::atomic::Ordering::Relaxed);
1340 let host_n = MOE_PREFILL_HOST_TOPK_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1341 let gate_n = MOE_PREFILL_GATE_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1342 let up_n = MOE_PREFILL_UP_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1343 let silu_n = MOE_PREFILL_SILU_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1344 let down_n = MOE_PREFILL_DOWN_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1345 let wsum_n = MOE_PREFILL_WSUM_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1346 bucket(" host", host_n, host_us);
1347 bucket(" gate", gate_n, gate_us);
1348 bucket(" up", up_n, up_us);
1349 bucket(" silu", silu_n, silu_us);
1350 bucket(" down", down_n, down_us);
1351 bucket(" wsum", wsum_n, wsum_us);
1352 }
1353 self.scratch.residual = Some(residual);
1354 B::to_vec(&self.scratch.logits, vocab)
1355 }
1356
1357 /// Decode: 1 token at position `pos`, return next-step logits.
1358 pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1359 self.ensure_scratch(1);
1360 self.ensure_kv(cache_id);
1361
1362 let h = self.cfg.base.hidden_size;
1363 let vocab = self.cfg.base.vocab_size;
1364 let mut ctx = B::new_context();
1365
1366 let decode_t0 = if std::env::var("FERRUM_MOE_PROFILE").is_ok() {
1367 Some(std::time::Instant::now())
1368 } else {
1369 None
1370 };
1371
1372 // FERRUM_DECODE_OP_PROFILE gates the per-stage breakdown emitted
1373 // at the bottom of every decode token. Reuses the same atomic
1374 // counters that `forward_layer` already populates (ATTN_TIME_US,
1375 // MOE_TIME_US — drained here per-token instead of per-prefill).
1376 let stage_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1377 B::sync(&mut ctx);
1378 for c in [
1379 &ATTN_TIME_US,
1380 &ATTN_CALLS,
1381 &MOE_TIME_US,
1382 &MOE_CALLS,
1383 &DEC_ROUTE_US,
1384 &DEC_GATE_US,
1385 &DEC_UP_US,
1386 &DEC_SILU_US,
1387 &DEC_DOWN_US,
1388 &DEC_WSUM_US,
1389 &DEC_EMBED_US,
1390 &DEC_FINAL_NORM_US,
1391 &DEC_LM_HEAD_US,
1392 ] {
1393 c.store(0, std::sync::atomic::Ordering::Relaxed);
1394 }
1395 Some(std::time::Instant::now())
1396 } else {
1397 None
1398 };
1399 let prof = stage_t0.is_some();
1400 let mark = |ctx: &mut B::Context, c: &AtomicU64, t0: std::time::Instant| {
1401 if prof {
1402 B::sync(ctx);
1403 c.fetch_add(
1404 t0.elapsed().as_micros() as u64,
1405 std::sync::atomic::Ordering::Relaxed,
1406 );
1407 }
1408 };
1409 let mt0 = std::time::Instant::now();
1410
1411 let mut residual = self
1412 .scratch
1413 .residual
1414 .take()
1415 .expect("scratch residual missing (previous call didn't restore)");
1416 let t0 = std::time::Instant::now();
1417 B::embedding_lookup(&mut ctx, &self.embed, &[token], &mut residual, h);
1418 mark(&mut ctx, &DEC_EMBED_US, t0);
1419 let _ = mt0; // silence if unused on non-profile builds
1420
1421 // Cross-layer rms_norm fusion: layer L's MoE tail folds the
1422 // next layer's leading rms_norm into its weighted-sum-residual
1423 // when the decode fast path applies. The flag carries forward.
1424 let mut prev_did_norm_fusion = false;
1425 let num_layers = self.cfg.base.num_layers;
1426 for li in 0..num_layers {
1427 let next_layer_idx = if li + 1 < num_layers {
1428 Some(li + 1)
1429 } else {
1430 None
1431 };
1432 prev_did_norm_fusion = self
1433 .forward_layer(
1434 &mut ctx,
1435 li,
1436 cache_id,
1437 &mut residual,
1438 pos as usize,
1439 1,
1440 next_layer_idx,
1441 prev_did_norm_fusion,
1442 )
1443 .expect("forward_layer");
1444 }
1445
1446 let t0 = std::time::Instant::now();
1447 B::rms_norm(
1448 &mut ctx,
1449 &residual,
1450 &self.final_norm_w,
1451 self.cfg.base.rms_norm_eps,
1452 &mut self.scratch.last_normed,
1453 1,
1454 h,
1455 );
1456 mark(&mut ctx, &DEC_FINAL_NORM_US, t0);
1457
1458 let t0 = std::time::Instant::now();
1459 self.lm_head.forward(
1460 &mut ctx,
1461 &self.scratch.last_normed,
1462 &mut self.scratch.logits,
1463 1,
1464 );
1465 mark(&mut ctx, &DEC_LM_HEAD_US, t0);
1466
1467 B::sync(&mut ctx);
1468 self.scratch.residual = Some(residual);
1469
1470 // FERRUM_DECODE_OP_PROFILE: per-token decode breakdown.
1471 if let Some(t0) = stage_t0 {
1472 use std::sync::atomic::Ordering;
1473 let total_us = t0.elapsed().as_micros() as u64;
1474 let attn_us = ATTN_TIME_US.swap(0, Ordering::Relaxed);
1475 let moe_us = MOE_TIME_US.swap(0, Ordering::Relaxed);
1476 let route = DEC_ROUTE_US.swap(0, Ordering::Relaxed);
1477 let gate = DEC_GATE_US.swap(0, Ordering::Relaxed);
1478 let up = DEC_UP_US.swap(0, Ordering::Relaxed);
1479 let silu = DEC_SILU_US.swap(0, Ordering::Relaxed);
1480 let down = DEC_DOWN_US.swap(0, Ordering::Relaxed);
1481 let wsum = DEC_WSUM_US.swap(0, Ordering::Relaxed);
1482 let embed = DEC_EMBED_US.swap(0, Ordering::Relaxed);
1483 let fnorm = DEC_FINAL_NORM_US.swap(0, Ordering::Relaxed);
1484 let lmhead = DEC_LM_HEAD_US.swap(0, Ordering::Relaxed);
1485 let other = total_us.saturating_sub(attn_us + moe_us + embed + fnorm + lmhead);
1486 let pct = |us: u64| -> f64 {
1487 if total_us == 0 {
1488 0.0
1489 } else {
1490 100.0 * us as f64 / total_us as f64
1491 }
1492 };
1493 eprintln!(
1494 "[decode-prof] total={} ms | attn={} ({:.1}%) | moe={} ({:.1}%) [route={} gate={} up={} silu={} down={} wsum={}] | embed={} fnorm={} lmhead={} other={} ({:.1}%)",
1495 total_us / 1000,
1496 attn_us / 1000, pct(attn_us),
1497 moe_us / 1000, pct(moe_us),
1498 route / 1000, gate / 1000, up / 1000, silu / 1000, down / 1000, wsum / 1000,
1499 embed / 1000, fnorm / 1000, lmhead / 1000,
1500 other / 1000, pct(other),
1501 );
1502 }
1503
1504 // Drain MoE per-op counters every decode step. The counters
1505 // accumulate across all 48 layers; printing per-step gives a
1506 // per-token breakdown.
1507 if let Some(t0) = decode_t0 {
1508 use crate::moe::dispatch::*;
1509 use std::sync::atomic::Ordering;
1510 let total_us = t0.elapsed().as_micros() as u64;
1511 let sync_us = MOE_SYNC_US.swap(0, Ordering::Relaxed);
1512 let sync_n = MOE_SYNC_CALLS.swap(0, Ordering::Relaxed);
1513 let topk_us = MOE_HOST_TOPK_US.swap(0, Ordering::Relaxed);
1514 let topk_n = MOE_HOST_TOPK_CALLS.swap(0, Ordering::Relaxed);
1515 let gu_us = MOE_GEMV_GATE_UP_US.swap(0, Ordering::Relaxed);
1516 let gu_n = MOE_GEMV_GATE_UP_CALLS.swap(0, Ordering::Relaxed);
1517 let silu_us = MOE_SILU_US.swap(0, Ordering::Relaxed);
1518 let silu_n = MOE_SILU_CALLS.swap(0, Ordering::Relaxed);
1519 let dn_us = MOE_GEMV_DOWN_US.swap(0, Ordering::Relaxed);
1520 let dn_n = MOE_GEMV_DOWN_CALLS.swap(0, Ordering::Relaxed);
1521 let sa_us = MOE_SCALED_ADD_US.swap(0, Ordering::Relaxed);
1522 let sa_n = MOE_SCALED_ADD_CALLS.swap(0, Ordering::Relaxed);
1523 let cp_us = MOE_COPY_US.swap(0, Ordering::Relaxed);
1524 let cp_n = MOE_COPY_CALLS.swap(0, Ordering::Relaxed);
1525 eprintln!(
1526 "[moe-prof] decode total={} ms | sync={} ms ({}x) | host_topk={} ms ({}x) | gate_up={} ms ({}x) | silu={} ms ({}x) | down={} ms ({}x) | scaled_add={} ms ({}x) | copy={} ms ({}x)",
1527 total_us / 1000,
1528 sync_us / 1000, sync_n,
1529 topk_us / 1000, topk_n,
1530 gu_us / 1000, gu_n,
1531 silu_us / 1000, silu_n,
1532 dn_us / 1000, dn_n,
1533 sa_us / 1000, sa_n,
1534 cp_us / 1000, cp_n,
1535 );
1536 }
1537
1538 B::to_vec(&self.scratch.logits, vocab)
1539 }
1540
1541 /// Multi-sequence batched decode (Phase 4b for MoE).
1542 ///
1543 /// Mirrors `LlamaFamilyModel::decode_batch_internal` but adapted to
1544 /// the MoE forward. The wins come from running the GEMM-heavy ops
1545 /// (qkv_proj, o_proj, router, MoE expert mul_mm_id, lm_head) at
1546 /// m=M, even though attention stays a per-item loop because
1547 /// Qwen3-MoE uses contiguous KV — no paged path here.
1548 ///
1549 /// Cross-layer rms_norm fusion (the `weighted_sum_residual_norm_stacked`
1550 /// fast path) is disabled in batched mode: the prefill MoE path
1551 /// (`moe_forward_batched_prefill_impl`) writes to `moe_out` and we
1552 /// add to residual explicitly. Each layer's leading rms_norm runs
1553 /// at m=M, which is one fused dispatch on M rows — cheap.
1554 pub fn decode_batch_internal(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1555 let m = batch.len();
1556 if m == 0 {
1557 return Vec::new();
1558 }
1559 if m == 1 {
1560 let (cid, tok, pos) = &batch[0];
1561 return vec![self.decode_internal(cid, *tok, *pos)];
1562 }
1563
1564 let prof_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1565 Some(std::time::Instant::now())
1566 } else {
1567 None
1568 };
1569
1570 for (cid, _, _) in batch {
1571 self.ensure_kv(cid);
1572 }
1573 self.ensure_scratch(m);
1574 self.scratch.enable_batched_decode_scratch(&self.cfg);
1575
1576 let h = self.cfg.base.hidden_size;
1577 let vocab = self.cfg.base.vocab_size;
1578 let mut ctx = B::new_context();
1579
1580 // 0. Embed all M tokens into residual [M, H]
1581 let tokens: Vec<u32> = batch.iter().map(|(_, t, _)| *t).collect();
1582 let mut residual = self
1583 .scratch
1584 .residual
1585 .take()
1586 .expect("scratch residual missing (previous call didn't restore)");
1587 B::embedding_lookup(&mut ctx, &self.embed, &tokens, &mut residual, h);
1588
1589 // 1..num_layers: batched forward for each layer
1590 for li in 0..self.cfg.base.num_layers {
1591 self.forward_layer_batched_decode(&mut ctx, li, batch, &mut residual, m)
1592 .expect("forward_layer_batched_decode");
1593 }
1594
1595 // Final RMSNorm on [M, H] → norm_out [M, H]
1596 B::rms_norm(
1597 &mut ctx,
1598 &residual,
1599 &self.final_norm_w,
1600 self.cfg.base.rms_norm_eps,
1601 &mut self.scratch.norm_out,
1602 m,
1603 h,
1604 );
1605
1606 // LM head with m=M → batch_logits [M, vocab]
1607 self.lm_head.forward(
1608 &mut ctx,
1609 &self.scratch.norm_out,
1610 &mut self.scratch.batch_logits,
1611 m,
1612 );
1613
1614 B::sync(&mut ctx);
1615 self.scratch.residual = Some(residual);
1616
1617 let all = B::to_vec(&self.scratch.batch_logits, m * vocab);
1618
1619 // Profile dump (one decode_batch_internal call = one decode step
1620 // covering all m tokens).
1621 if let Some(t0) = prof_t0 {
1622 use std::sync::atomic::Ordering;
1623 let total_us = t0.elapsed().as_micros() as u64;
1624 let dense = BD_DENSE_US.swap(0, Ordering::Relaxed);
1625 let attn = BD_ATTN_PERITEM_US.swap(0, Ordering::Relaxed);
1626 let moe = BD_MOE_US.swap(0, Ordering::Relaxed);
1627 let layers = BD_LAYER_CALLS.swap(0, Ordering::Relaxed);
1628 let other = total_us.saturating_sub(dense + attn + moe);
1629 let pct = |us: u64| -> f64 {
1630 if total_us == 0 {
1631 0.0
1632 } else {
1633 100.0 * us as f64 / total_us as f64
1634 }
1635 };
1636 // MoE sub-stage breakdown — meaningful when
1637 // moe_forward_batched_decode_impl was used.
1638 let moe_route = MOE_BATCHED_DECODE_ROUTE_US.swap(0, Ordering::Relaxed);
1639 let moe_gate = MOE_BATCHED_DECODE_GATE_US.swap(0, Ordering::Relaxed);
1640 let moe_up = MOE_BATCHED_DECODE_UP_US.swap(0, Ordering::Relaxed);
1641 let moe_silu = MOE_BATCHED_DECODE_SILU_US.swap(0, Ordering::Relaxed);
1642 let moe_down = MOE_BATCHED_DECODE_DOWN_US.swap(0, Ordering::Relaxed);
1643 let moe_wsum = MOE_BATCHED_DECODE_WSUM_US.swap(0, Ordering::Relaxed);
1644 eprintln!(
1645 "[batched-decode-prof] m={} layers={} total={} ms | dense={} ({:.1}%) | attn_peritem={} ({:.1}%) | moe={} ({:.1}%) [route={} gate={} up={} silu={} down={} wsum={}] | other={} ({:.1}%)",
1646 m, layers, total_us / 1000,
1647 dense / 1000, pct(dense),
1648 attn / 1000, pct(attn),
1649 moe / 1000, pct(moe),
1650 moe_route / 1000, moe_gate / 1000, moe_up / 1000,
1651 moe_silu / 1000, moe_down / 1000, moe_wsum / 1000,
1652 other / 1000, pct(other),
1653 );
1654 }
1655
1656 (0..m)
1657 .map(|i| all[i * vocab..(i + 1) * vocab].to_vec())
1658 .collect()
1659 }
1660
1661 /// One transformer layer over M items: GEMMs at m=M, per-item
1662 /// attention loop, MoE FFN at m=M via the prefill batched path.
1663 /// Mirrors `LlamaFamilyModel::forward_layer_batched_decode` minus
1664 /// the paged branch.
1665 fn forward_layer_batched_decode(
1666 &mut self,
1667 ctx: &mut B::Context,
1668 li: usize,
1669 batch: &[(String, u32, u32)],
1670 residual: &mut B::Buffer,
1671 m: usize,
1672 ) -> Result<()> {
1673 let cfg_base = &self.cfg.base;
1674 let h = cfg_base.hidden_size;
1675 let nh = cfg_base.num_heads;
1676 let nkv = cfg_base.num_kv_heads;
1677 let hd = cfg_base.head_dim;
1678 let eps = cfg_base.rms_norm_eps;
1679 let q_dim = nh * hd;
1680 let kv_dim = nkv * hd;
1681
1682 let attn_layer = &self.attn_layers[li];
1683 let qk_mode: i32 = if cfg_base.has_qk_norm { 1 } else { 2 };
1684 let dummy_w = &attn_layer.input_ln_w;
1685 let q_norm_w = attn_layer.q_norm_w.as_ref().unwrap_or(dummy_w);
1686 let k_norm_w = attn_layer.k_norm_w.as_ref().unwrap_or(dummy_w);
1687
1688 let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
1689 let stage_t0 = || -> Option<std::time::Instant> {
1690 if prof {
1691 Some(std::time::Instant::now())
1692 } else {
1693 None
1694 }
1695 };
1696 let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
1697 if let Some(t) = t0 {
1698 B::sync(ctx);
1699 c.fetch_add(
1700 t.elapsed().as_micros() as u64,
1701 std::sync::atomic::Ordering::Relaxed,
1702 );
1703 }
1704 };
1705 if prof {
1706 BD_LAYER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1707 }
1708
1709 let dense_t0 = stage_t0();
1710
1711 // 1. rms_norm [M, H] → norm_out
1712 B::rms_norm(
1713 ctx,
1714 residual,
1715 &attn_layer.input_ln_w,
1716 eps,
1717 &mut self.scratch.norm_out,
1718 m,
1719 h,
1720 );
1721
1722 // 2. qkv_proj GEMM at m=M: norm_out [M, H] → qkv_out [M, QKV]
1723 attn_layer
1724 .qkv_proj
1725 .forward(ctx, &self.scratch.norm_out, &mut self.scratch.qkv_out, m);
1726
1727 // ── Paged batched attention path ───────────────────────────────
1728 //
1729 // Mirrors LlamaFamilyModel's Phase 4b paged batched-decode. When
1730 // `FERRUM_METAL_PAGED_KV=1` was set at ensure_kv time, each
1731 // cache_id has paged metadata (block_table + context_lens) and
1732 // K/V live in the shared `paged_pools[layer]` pool. This path:
1733 // 1. m × `split_qkv_norm_rope_into_paged_cache` writes K/V into
1734 // the pool at each item's allocated blocks AND fills
1735 // `paged_batch_q[i*q_dim ..]` with that item's head-major Q.
1736 // 2. Build `paged_batch_block_tables [m, max_blocks_per_seq]`
1737 // and `paged_batch_context_lens [m]` host-side, upload.
1738 // 3. ONE `paged_decode_attention(num_seqs=m)` call reads all m
1739 // sequences' K/V from the pool via per-seq block_tables,
1740 // writes outputs to `paged_batch_o [m, q_dim]`.
1741 // 4. Per-item copy_slice paged_batch_o[i] → attn_flat[i*q_dim].
1742 //
1743 // This is the structural fix for the c=16 attn_peritem cliff
1744 // (~55 ms / round of 16 sequential m=1 flash_attn + plumbing).
1745 let is_paged = self.paged_pools.is_some();
1746 if is_paged {
1747 stage_end(dense_t0, ctx, &BD_DENSE_US);
1748 let attn_t0 = stage_t0();
1749
1750 let max_blocks_per_seq = self.scratch.paged_max_blocks_per_seq;
1751 let block_size = 16; // matches PAGED_BLOCK_SIZE in ensure_kv
1752 let qkv_stride = q_dim + 2 * kv_dim;
1753
1754 // Step 1: per-item paged write. Read each item's qkv out of
1755 // the batched qkv_out buffer; write its head-major Q into
1756 // paged_batch_q[i*q_dim ..]; write its K/V into the pool at
1757 // its allocated blocks via block_table.
1758 let q_head_major_size_bytes = (q_dim * std::mem::size_of::<f32>()) as u64;
1759 let qkv_stride_bytes = (qkv_stride * std::mem::size_of::<f32>()) as u64;
1760 let pool_ptr = {
1761 let pools = self.paged_pools.as_mut().unwrap();
1762 (
1763 &mut pools[li].0 as *mut B::Buffer,
1764 &mut pools[li].1 as *mut B::Buffer,
1765 )
1766 };
1767 // SAFETY: pools allocated-once, see paged_pools field comment.
1768 let (pool_k, pool_v) = unsafe { (&mut *pool_ptr.0, &mut *pool_ptr.1) };
1769 for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
1770 let pos_i = *pos as usize;
1771 let caches = self
1772 .kv_caches
1773 .get(cache_id)
1774 .expect("paged batched: cache not present");
1775 let cache = &caches[li];
1776 let bt = cache
1777 .block_table
1778 .as_ref()
1779 .expect("paged batched: block_table missing");
1780 let cache_len_before = cache.len;
1781 let bt_ptr = bt as *const B::Buffer;
1782 // SAFETY: bt is read-only during the dispatch; we don't
1783 // touch self.kv_caches between this raw deref and the
1784 // call below.
1785 let bt_safe: &B::Buffer = unsafe { &*bt_ptr };
1786 B::split_qkv_norm_rope_into_paged_cache(
1787 ctx,
1788 &self.scratch.qkv_out,
1789 (i as u64) * qkv_stride_bytes,
1790 q_norm_w,
1791 k_norm_w,
1792 &self.rope.cos,
1793 &self.rope.sin,
1794 self.scratch
1795 .paged_batch_q
1796 .as_mut()
1797 .expect("paged_batch_q missing"),
1798 (i as u64) * q_head_major_size_bytes,
1799 pool_k,
1800 pool_v,
1801 bt_safe,
1802 1,
1803 nh,
1804 nkv,
1805 hd,
1806 pos_i,
1807 eps,
1808 qk_mode,
1809 cache_len_before,
1810 block_size,
1811 max_blocks_per_seq,
1812 )
1813 .expect("split_qkv_norm_rope_into_paged_cache (batched)");
1814 }
1815
1816 // Step 2: bump cache.len and stack block_tables + context_lens
1817 // host-side, then upload to device scratch.
1818 let mut stacked_bt: Vec<u32> = vec![0u32; m * max_blocks_per_seq];
1819 let mut stacked_cl: Vec<u32> = vec![0u32; m];
1820 for (i, (cache_id, _, _)) in batch.iter().enumerate() {
1821 let caches = self
1822 .kv_caches
1823 .get_mut(cache_id)
1824 .expect("paged batched: cache not present");
1825 let cache = &mut caches[li];
1826 cache.len += 1;
1827 let len = cache.len as u32;
1828 stacked_cl[i] = len;
1829 let blocks = &cache.paged_block_indices;
1830 let n_to_copy = blocks.len().min(max_blocks_per_seq);
1831 stacked_bt[i * max_blocks_per_seq..i * max_blocks_per_seq + n_to_copy]
1832 .copy_from_slice(&blocks[..n_to_copy]);
1833 }
1834 let bt_buf = self
1835 .scratch
1836 .paged_batch_block_tables
1837 .as_mut()
1838 .expect("paged_batch_block_tables missing");
1839 B::write_u32(ctx, bt_buf, &stacked_bt);
1840 let cl_buf = self
1841 .scratch
1842 .paged_batch_context_lens
1843 .as_mut()
1844 .expect("paged_batch_context_lens missing");
1845 B::write_u32(ctx, cl_buf, &stacked_cl);
1846
1847 // Step 3: one batched paged_decode_attention(num_seqs=m).
1848 let bt_ptr =
1849 self.scratch.paged_batch_block_tables.as_ref().unwrap() as *const B::Buffer;
1850 let cl_ptr =
1851 self.scratch.paged_batch_context_lens.as_ref().unwrap() as *const B::Buffer;
1852 let q_ptr = self.scratch.paged_batch_q.as_ref().unwrap() as *const B::Buffer;
1853 let o_ptr = self.scratch.paged_batch_o.as_mut().unwrap() as *mut B::Buffer;
1854 // SAFETY: scratch buffers are not aliased; we hold &mut self
1855 // through this entire block.
1856 let bt_safe = unsafe { &*bt_ptr };
1857 let cl_safe = unsafe { &*cl_ptr };
1858 let q_safe = unsafe { &*q_ptr };
1859 let o_safe = unsafe { &mut *o_ptr };
1860 B::paged_decode_attention(
1861 ctx,
1862 q_safe,
1863 pool_k,
1864 pool_v,
1865 o_safe,
1866 bt_safe,
1867 cl_safe,
1868 m,
1869 nh,
1870 nkv,
1871 hd,
1872 block_size,
1873 max_blocks_per_seq,
1874 1, // q_len
1875 )
1876 .expect("paged batched decode");
1877
1878 // Step 4: per-item copy paged_batch_o[i] → attn_flat[i*q_dim].
1879 for i in 0..m {
1880 B::copy_slice(
1881 ctx,
1882 self.scratch.paged_batch_o.as_ref().unwrap(),
1883 i * q_dim,
1884 &mut self.scratch.attn_flat,
1885 i * q_dim,
1886 q_dim,
1887 );
1888 }
1889
1890 stage_end(attn_t0, ctx, &BD_ATTN_PERITEM_US);
1891 } else {
1892 // 3. split_qkv [M, QKV] → q_buf [M, Q], k_buf [M, KV], v_buf [M, KV]
1893 B::split_qkv(
1894 ctx,
1895 &self.scratch.qkv_out,
1896 &mut self.scratch.q_buf,
1897 &mut self.scratch.k_buf,
1898 &mut self.scratch.v_buf,
1899 m,
1900 q_dim,
1901 kv_dim,
1902 );
1903
1904 // 4-6. Per-item loop: rope + kv_append + attention.
1905 // Each item has its own cache_id + pos + kv_len.
1906 let q_single = self
1907 .scratch
1908 .q_single
1909 .as_ref()
1910 .expect("q_single missing — enable_batched_decode_scratch not called")
1911 as *const B::Buffer;
1912 let k_single =
1913 self.scratch.k_single.as_ref().expect("k_single missing") as *const B::Buffer;
1914 let v_single =
1915 self.scratch.v_single.as_ref().expect("v_single missing") as *const B::Buffer;
1916 let q_hm_single =
1917 self.scratch
1918 .q_head_major_single
1919 .as_mut()
1920 .expect("q_head_major_single missing") as *mut B::Buffer;
1921 let k_hm_single =
1922 self.scratch
1923 .k_head_major_single
1924 .as_mut()
1925 .expect("k_head_major_single missing") as *mut B::Buffer;
1926 let v_hm_single =
1927 self.scratch
1928 .v_head_major_single
1929 .as_mut()
1930 .expect("v_head_major_single missing") as *mut B::Buffer;
1931 let attn_hm_single =
1932 self.scratch
1933 .attn_head_major_single
1934 .as_mut()
1935 .expect("attn_head_major_single missing") as *mut B::Buffer;
1936 // SAFETY: each Option holds a stable B::Buffer; we don't mutate
1937 // self.scratch in a way that would invalidate them inside the loop
1938 // (the kv_caches mutation is on a disjoint field).
1939
1940 // End of dense block (rms_norm + qkv_proj + split_qkv); start
1941 // per-item attention loop instrumentation.
1942 stage_end(dense_t0, ctx, &BD_DENSE_US);
1943 let attn_t0 = stage_t0();
1944
1945 for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
1946 let pos_i = *pos as usize;
1947
1948 // SAFETY: borrows of disjoint scratch fields, see above.
1949 let q_single_ref = unsafe { &*q_single };
1950 let k_single_ref = unsafe { &*k_single };
1951 let v_single_ref = unsafe { &*v_single };
1952 let q_hm_single_mut = unsafe { &mut *q_hm_single };
1953 let k_hm_single_mut = unsafe { &mut *k_hm_single };
1954 let v_hm_single_mut = unsafe { &mut *v_hm_single };
1955 let attn_hm_single_mut = unsafe { &mut *attn_hm_single };
1956
1957 // Extract item i's Q/K/V slice from the batched buffers.
1958 B::copy_slice(
1959 ctx,
1960 &self.scratch.q_buf,
1961 i * q_dim,
1962 // copy_slice signature wants &mut for dst, but q_single
1963 // is shared; we need a *mut variant — since enable_*
1964 // gives us Option, we can do as_mut() here.
1965 self.scratch.q_single.as_mut().unwrap(),
1966 0,
1967 q_dim,
1968 );
1969 B::copy_slice(
1970 ctx,
1971 &self.scratch.k_buf,
1972 i * kv_dim,
1973 self.scratch.k_single.as_mut().unwrap(),
1974 0,
1975 kv_dim,
1976 );
1977 B::copy_slice(
1978 ctx,
1979 &self.scratch.v_buf,
1980 i * kv_dim,
1981 self.scratch.v_single.as_mut().unwrap(),
1982 0,
1983 kv_dim,
1984 );
1985
1986 // qk_norm_rope with tokens=1, per-item pos.
1987 B::qk_norm_rope(
1988 ctx,
1989 q_single_ref,
1990 q_norm_w,
1991 &self.rope.cos,
1992 &self.rope.sin,
1993 q_hm_single_mut,
1994 1,
1995 nh,
1996 hd,
1997 pos_i,
1998 eps,
1999 qk_mode,
2000 );
2001 B::qk_norm_rope(
2002 ctx,
2003 k_single_ref,
2004 k_norm_w,
2005 &self.rope.cos,
2006 &self.rope.sin,
2007 k_hm_single_mut,
2008 1,
2009 nkv,
2010 hd,
2011 pos_i,
2012 eps,
2013 qk_mode,
2014 );
2015 B::qk_norm_rope(
2016 ctx,
2017 v_single_ref,
2018 dummy_w,
2019 &self.rope.cos,
2020 &self.rope.sin,
2021 v_hm_single_mut,
2022 1,
2023 nkv,
2024 hd,
2025 pos_i,
2026 eps,
2027 0,
2028 );
2029
2030 // KV append + attention for item i's cache.
2031 let caches = self
2032 .kv_caches
2033 .get_mut(cache_id)
2034 .expect("ensure_kv must be called before forward_layer_batched");
2035 let cache = &mut caches[li];
2036 B::kv_cache_append_head_major(
2037 ctx,
2038 &mut cache.k,
2039 &mut cache.v,
2040 cache.len,
2041 cache.capacity,
2042 k_hm_single_mut,
2043 v_hm_single_mut,
2044 1,
2045 nkv,
2046 hd,
2047 );
2048 cache.len += 1;
2049 let kv_len = cache.len;
2050 let kv_stride = cache.capacity;
2051
2052 let attn_cfg = ferrum_kernels::backend::AttnConfig {
2053 num_heads: nh,
2054 num_kv_heads: nkv,
2055 head_dim: hd,
2056 causal: true,
2057 scale: 1.0 / (hd as f32).sqrt(),
2058 kv_seq_stride: kv_stride,
2059 sliding_window: cfg_base.sliding_window,
2060 };
2061 B::flash_attention(
2062 ctx,
2063 q_hm_single_mut,
2064 &cache.k,
2065 &cache.v,
2066 attn_hm_single_mut,
2067 1,
2068 1,
2069 kv_len,
2070 pos_i,
2071 &attn_cfg,
2072 );
2073
2074 // Untranspose head-major → token-major: for tokens=1 the
2075 // layouts are byte-identical, so copy_slice straight into
2076 // attn_flat at the per-item offset (saves a transpose).
2077 B::copy_slice(
2078 ctx,
2079 attn_hm_single_mut,
2080 0,
2081 &mut self.scratch.attn_flat,
2082 i * q_dim,
2083 q_dim,
2084 );
2085 }
2086
2087 // End of per-item attention loop.
2088 stage_end(attn_t0, ctx, &BD_ATTN_PERITEM_US);
2089 } // end of `else` for non-paged path
2090
2091 let post_attn_t0 = stage_t0();
2092
2093 // 7. o_proj GEMM at m=M: attn_flat [M, Q] → o_proj_out [M, H]
2094 attn_layer.o_proj.forward(
2095 ctx,
2096 &self.scratch.attn_flat,
2097 &mut self.scratch.o_proj_out,
2098 m,
2099 );
2100
2101 // 8. fused residual_add + post_attention_layernorm
2102 B::fused_add_rms_norm(
2103 ctx,
2104 residual,
2105 &self.scratch.o_proj_out,
2106 &attn_layer.post_ln_w,
2107 eps,
2108 &mut self.scratch.norm_out,
2109 m,
2110 h,
2111 );
2112
2113 // o_proj + post-norm count under DENSE.
2114 stage_end(post_attn_t0, ctx, &BD_DENSE_US);
2115 let moe_t0 = stage_t0();
2116
2117 // 9. Router gemv: norm_out [M, H] → router_logits [M, n_exp]
2118 let moe_layer = &self.moe_layers[li];
2119 moe_layer.router.forward(
2120 ctx,
2121 &self.scratch.norm_out,
2122 &mut self.scratch.router_logits,
2123 m,
2124 );
2125
2126 // 10. MoE expert dispatch — per-item loop using the cheap
2127 // stacked decode kernels (gemv_quant_moe_id + silu_mul_stacked
2128 // + weighted_sum_batched). NOT the batched prefill path:
2129 // `moe_forward_batched_prefill_impl` is tuned for large M
2130 // (prefill) and the GPU bucketing overhead
2131 // (`compute_ids_tpe_gpu` + indirect-dispatch arg-buffer
2132 // setup) costs more than M sequential gemv calls at small M.
2133 //
2134 // Strategy: route ALL M tokens once via batched
2135 // `route_topk_softmax`, then loop M iterations of the stacked
2136 // decode kernels. Each iteration:
2137 // - extract item i's selected ids + weights from the M-batch
2138 // buffers via copy_slice
2139 // - copy norm_out[i*h..(i+1)*h] → x_single
2140 // - 3× gemv_quant_moe_id (gate/up/down) reading from x_single
2141 // - silu_mul_stacked
2142 // - weighted_sum_batched(batch=1) → acc_buf (fresh write,
2143 // no residual fusion)
2144 // - copy_slice acc_buf → moe_out[i*h..(i+1)*h]
2145 // After the loop, single add_inplace residual += moe_out [M, H].
2146 let stacked_path_available = moe_layer.experts.gate_stacked.is_some()
2147 && moe_layer.experts.up_stacked.is_some()
2148 && moe_layer.experts.down_stacked.is_some();
2149 // MoE FFN dispatch tiers (m = batch size of this layer call):
2150 //
2151 // m = 1 : `moe_forward_stacked_decode_impl`
2152 // (decode m=1 fast path, fused gate+up+silu)
2153 // m ≥ 8 (default): `moe_forward_batched_prefill_impl`
2154 // (GEMM with simdgroup_matmul + GPU bucketing)
2155 // else (m=2..7) : per-item stacked decode loop
2156 //
2157 // EXPERIMENTAL — opt-in `FERRUM_MOE_BATCHED_DECODE=1` engages the
2158 // new `moe_forward_batched_decode_impl` for 2 ≤ m < 32. The
2159 // kernel itself is bitwise correct and ports llama.cpp's
2160 // `kernel_mul_mv_id` strategy to ferrum (one indirect-dispatch
2161 // GEMV per linear covering all m*top_k pairs). Empirically OFF
2162 // by default because the existing `forward_layer_batched_decode`
2163 // attention plumbing (per-item copy_slice × m × 6 dispatches)
2164 // scales linearly with m and overshadows the FFN savings —
2165 // regression measured at -19% (c=4) and -36% (c=16) on
2166 // Qwen3-30B-A3B Q4_K_M / M1 Max. Closing that gap requires a
2167 // batched attention path with offset-aware QKV slicing, which
2168 // is the next PR's job. Until then the kernel sits as
2169 // infrastructure.
2170 // Two independent thresholds:
2171 // * `FERRUM_MOE_BATCH_THRESHOLD` (default 8) — m above which
2172 // the LEGACY non-experimental path uses the prefill GEMM.
2173 // Shared with `decode_batch`'s engine-level gate, so users
2174 // who set it to a small value to engage batched decode
2175 // don't accidentally also push the inner FFN to GEMM.
2176 // * `FERRUM_MOE_PREFILL_THRESHOLD` (default 32) — m above
2177 // which the EXPERIMENTAL batched-decode path defers to the
2178 // prefill GEMM path. Mirrors llama.cpp's `ne21_mm_id_min=32`
2179 // GEMV→GEMM boundary.
2180 let legacy_prefill_threshold: usize = std::env::var("FERRUM_MOE_BATCH_THRESHOLD")
2181 .ok()
2182 .and_then(|s| s.parse().ok())
2183 .unwrap_or(8);
2184 let new_prefill_threshold: usize = std::env::var("FERRUM_MOE_PREFILL_THRESHOLD")
2185 .ok()
2186 .and_then(|s| s.parse().ok())
2187 .unwrap_or(32);
2188 // 0.7.2: default to ON when paged-KV is also on (which is now
2189 // the default for Metal). The historical regression for this
2190 // flag (-19% c=4 / -36% c=16) was measured in the pre-paged-KV
2191 // world where `forward_layer_batched_decode`'s per-item
2192 // copy_slice × m × 6 attention dispatches cost more than the
2193 // batched MoE FFN saved. Once paged-KV is on, attention runs as
2194 // one `paged_decode_attention(num_seqs=m)` dispatch, the
2195 // plumbing cost drops, and the batched MoE GEMV's win net out
2196 // to ~+50% at c=16. `FERRUM_MOE_BATCHED_DECODE=0` forces off.
2197 let new_batched_default = stacked_path_available && B::supports_batched_moe_gemv();
2198 let new_batched_enabled = new_batched_default
2199 && std::env::var("FERRUM_MOE_BATCHED_DECODE")
2200 .map(|v| v != "0")
2201 .unwrap_or(true);
2202
2203 // When the new path is opted in, it owns the m=2..new_prefill_threshold
2204 // range; the legacy threshold is overridden upward.
2205 let use_prefill_batched = if new_batched_enabled {
2206 stacked_path_available && m >= new_prefill_threshold
2207 } else {
2208 stacked_path_available && m >= legacy_prefill_threshold
2209 };
2210 let use_batched_decode = new_batched_enabled && !use_prefill_batched && m >= 2;
2211
2212 if use_prefill_batched {
2213 moe_forward_batched_prefill_impl::<B>(
2214 ctx,
2215 moe_layer,
2216 &mut self.scratch,
2217 h,
2218 self.cfg.expert_intermediate_size,
2219 self.cfg.num_experts_per_tok,
2220 self.cfg.num_experts,
2221 self.cfg.norm_topk_prob,
2222 m,
2223 )?;
2224 } else if use_batched_decode {
2225 moe_forward_batched_decode_impl::<B>(
2226 ctx,
2227 moe_layer,
2228 &mut self.scratch,
2229 h,
2230 self.cfg.expert_intermediate_size,
2231 self.cfg.num_experts_per_tok,
2232 self.cfg.num_experts,
2233 self.cfg.norm_topk_prob,
2234 m,
2235 )?;
2236 } else if stacked_path_available {
2237 let inter = self.cfg.expert_intermediate_size;
2238 let top_k = self.cfg.num_experts_per_tok;
2239 let n_exp = self.cfg.num_experts;
2240 let norm_topk_prob = self.cfg.norm_topk_prob;
2241 let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2242 let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2243 let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2244
2245 // Single batched router pass: writes selected_ids_buf [M, top_k]
2246 // and weights_2d [M, top_k]. Replaces M individual route calls.
2247 B::route_topk_softmax(
2248 ctx,
2249 &self.scratch.router_logits,
2250 &mut self.scratch.selected_ids_buf,
2251 &mut self.scratch.weights_2d,
2252 m,
2253 n_exp,
2254 top_k,
2255 norm_topk_prob,
2256 )?;
2257
2258 // Per-item loop using offset-aware kernel APIs — eliminates
2259 // the 4 copy_slice round-trips per iteration that the
2260 // earlier implementation needed (ids, weights, x_single,
2261 // moe_out). At c=16 / 48 layers that's ~3,072 dispatches
2262 // saved per token. Uses `gemv_quant_moe_id_offset` to read
2263 // `selected_ids_buf` at the i-th `top_k` block and
2264 // `norm_out` at the i-th hidden row directly. Falls back
2265 // to copy_slice path if backend doesn't support offsets.
2266 for i in 0..m {
2267 let ids_offset = i * top_k;
2268 let activation_offset = i * h;
2269 let weights_offset = i * top_k;
2270 let moe_out_offset = i * h;
2271
2272 // Stacked gate / up gemvs — broadcast item i's row of
2273 // norm_out across top_k slots, read item i's ids.
2274 let gate_res = B::gemv_quant_moe_id_offset(
2275 ctx,
2276 &self.scratch.norm_out,
2277 activation_offset,
2278 gate_stacked,
2279 &self.scratch.selected_ids_buf,
2280 ids_offset,
2281 &mut self.scratch.gate_out_stacked,
2282 top_k,
2283 0,
2284 );
2285 if gate_res.is_err() {
2286 // Backend doesn't support offset variants — fall back
2287 // to the legacy copy_slice path. Same as before.
2288 B::copy_slice(
2289 ctx,
2290 &self.scratch.selected_ids_buf,
2291 ids_offset,
2292 &mut self.scratch.ids_buf,
2293 0,
2294 top_k,
2295 );
2296 B::copy_slice(
2297 ctx,
2298 &self.scratch.weights_2d,
2299 weights_offset,
2300 &mut self.scratch.weights_buf,
2301 0,
2302 top_k,
2303 );
2304 B::copy_slice(
2305 ctx,
2306 &self.scratch.norm_out,
2307 activation_offset,
2308 &mut self.scratch.x_single,
2309 0,
2310 h,
2311 );
2312 B::gemv_quant_moe_id(
2313 ctx,
2314 &self.scratch.x_single,
2315 gate_stacked,
2316 &self.scratch.ids_buf,
2317 &mut self.scratch.gate_out_stacked,
2318 top_k,
2319 0,
2320 )?;
2321 B::gemv_quant_moe_id(
2322 ctx,
2323 &self.scratch.x_single,
2324 up_stacked,
2325 &self.scratch.ids_buf,
2326 &mut self.scratch.up_out_stacked,
2327 top_k,
2328 0,
2329 )?;
2330 B::silu_mul_stacked(
2331 ctx,
2332 &self.scratch.gate_out_stacked,
2333 &self.scratch.up_out_stacked,
2334 &mut self.scratch.silu_stacked,
2335 top_k,
2336 inter,
2337 )?;
2338 B::gemv_quant_moe_id(
2339 ctx,
2340 &self.scratch.silu_stacked,
2341 down_stacked,
2342 &self.scratch.ids_buf,
2343 &mut self.scratch.down_out_stacked,
2344 top_k,
2345 inter,
2346 )?;
2347 B::weighted_sum_batched(
2348 ctx,
2349 &self.scratch.down_out_stacked,
2350 &self.scratch.weights_buf,
2351 &mut self.scratch.acc_buf,
2352 1,
2353 top_k,
2354 h,
2355 )?;
2356 B::copy_slice(
2357 ctx,
2358 &self.scratch.acc_buf,
2359 0,
2360 &mut self.scratch.moe_out,
2361 moe_out_offset,
2362 h,
2363 );
2364 continue;
2365 }
2366 // Fast path: offset-aware all the way through.
2367 B::gemv_quant_moe_id_offset(
2368 ctx,
2369 &self.scratch.norm_out,
2370 activation_offset,
2371 up_stacked,
2372 &self.scratch.selected_ids_buf,
2373 ids_offset,
2374 &mut self.scratch.up_out_stacked,
2375 top_k,
2376 0,
2377 )?;
2378 B::silu_mul_stacked(
2379 ctx,
2380 &self.scratch.gate_out_stacked,
2381 &self.scratch.up_out_stacked,
2382 &mut self.scratch.silu_stacked,
2383 top_k,
2384 inter,
2385 )?;
2386 B::gemv_quant_moe_id_offset(
2387 ctx,
2388 &self.scratch.silu_stacked,
2389 0, // silu_stacked itself stays at offset 0 each iter
2390 down_stacked,
2391 &self.scratch.selected_ids_buf,
2392 ids_offset,
2393 &mut self.scratch.down_out_stacked,
2394 top_k,
2395 inter,
2396 )?;
2397 // Write directly into moe_out at the per-item offset —
2398 // skips the copy_slice from acc_buf.
2399 B::weighted_sum_batched_offset(
2400 ctx,
2401 &self.scratch.down_out_stacked,
2402 &self.scratch.weights_2d,
2403 weights_offset,
2404 &mut self.scratch.moe_out,
2405 moe_out_offset,
2406 1,
2407 top_k,
2408 h,
2409 )?;
2410 }
2411 } else {
2412 // Backend without stacked variants — fall back to the legacy
2413 // per-(token, expert) host-routed path.
2414 moe_forward::<B>(
2415 ctx,
2416 &self.scratch.norm_out,
2417 &self.scratch.router_logits,
2418 &mut self.scratch.moe_out,
2419 m,
2420 h,
2421 self.cfg.expert_intermediate_size,
2422 self.cfg.num_experts,
2423 self.cfg.num_experts_per_tok,
2424 self.cfg.norm_topk_prob,
2425 &moe_layer.experts,
2426 &mut self.scratch.x_single,
2427 &mut self.scratch.acc_buf,
2428 &mut self.scratch.gate_up_buf,
2429 &mut self.scratch.silu_buf,
2430 &mut self.scratch.down_buf,
2431 &self.scratch.zero_hidden,
2432 )?;
2433 }
2434
2435 // 11. residual += moe_out [M, H]
2436 B::add_inplace(ctx, residual, &self.scratch.moe_out, m * h);
2437
2438 // Close MoE-block instrumentation (router + FFN + residual add).
2439 stage_end(moe_t0, ctx, &BD_MOE_US);
2440
2441 Ok(())
2442 }
2443}
2444
2445impl<B: Backend> DecoderOnlyLLM for Qwen3MoeModel<B> {
2446 fn config(&self) -> &LlmRuntimeConfig {
2447 &self.runtime_cfg
2448 }
2449
2450 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2451 // Eager scratch + KV cache grow + a 1-token forward warmup so
2452 // the first real prefill / decode doesn't pay the cold-start
2453 // ~25-MTLBuffer scratch alloc + ~96-MTLBuffer KV alloc + Metal
2454 // pipeline-state first-bind costs (~265 ms total on Qwen3-MoE
2455 // 30B-A3B / M1 Max). Mirrors what llama-bench's --warmup does
2456 // (which runs a same-shape forward before the timer).
2457 self.ensure_scratch(max_tokens);
2458 self.ensure_kv(cache_id);
2459
2460 // Warmup forward through all 48 layers under a scratch cache_id
2461 // so the real `cache_id` starts at pos_offset=0. Token 0 is
2462 // valid for any tokenizer (BOS or pad).
2463 const WARMUP_CACHE: &str = "__ferrum_warmup__";
2464 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2465 // Drop the warmup KV cache slot — real cache_id is unaffected.
2466 if let Some(caches) = self.kv_caches.remove(WARMUP_CACHE) {
2467 self.kv_free_pool.push(caches);
2468 }
2469 }
2470
2471 fn kv_capacity(&self) -> usize {
2472 // Mirror the bound `ensure_kv` will use when allocating the cache.
2473 let model_max = self.cfg.base.max_seq_len;
2474 const DEFAULT_KV_CAPACITY: usize = 512;
2475 std::env::var("FERRUM_KV_CAPACITY")
2476 .ok()
2477 .and_then(|s| s.parse::<usize>().ok())
2478 .map(|cap| cap.min(model_max))
2479 .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
2480 }
2481
2482 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2483 self.prefill_internal(cache_id, tokens)
2484 }
2485
2486 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2487 self.decode_internal(cache_id, token, pos)
2488 }
2489
2490 // decode_batch is gated to use the batched path only when it's a
2491 // measurable win. The crossover depends on M:
2492 //
2493 // - At low M (≤ ~8) the per-item `decode_internal` loop wins
2494 // because: (a) it stays at scratch offset 0 (no copy_slice
2495 // overhead), (b) it preserves the cross-layer rms_norm fusion
2496 // fast path (`weighted_sum_residual_norm_stacked`).
2497 // - At high M (≥ ~12) the batched path wins because the dense
2498 // GEMM batching (qkv_proj, o_proj, router, lm_head at m=M) and
2499 // the prefill-batched MoE dispatch (one `gemm_quant_moe_id` for
2500 // all tokens) amortise the ~48-dispatch lost-fusion penalty.
2501 //
2502 // Default opted out of FERRUM_MOE_BATCHED. When opted in, the
2503 // batched path engages only at M ≥ FERRUM_MOE_BATCH_THRESHOLD
2504 // (default 12). Below that we still go per-item.
2505 //
2506 // Empirical note 2026-05-02: a follow-up PR added a batched MoE
2507 // GEMV kernel (`gemv_quant_moe_id_batched`) that holds MoE
2508 // dispatch count flat as concurrency scales. Wiring it through
2509 // `decode_batch_internal` regressed throughput by 19% (c=4) /
2510 // 36% (c=16) — `forward_layer_batched_decode`'s per-item
2511 // attention plumbing (copy_slice × m × 6 dispatches) costs more
2512 // than the MoE save. The batched MoE kernel is shipped as opt-in
2513 // infrastructure (`FERRUM_MOE_BATCHED_DECODE=1`); flipping it on
2514 // by default has to wait until the attention plumbing is fixed.
2515 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2516 let m = batch.len();
2517 // Default ON in 0.7.2+. The threshold (default 8) keeps small-m
2518 // requests on the per-token loop where it still wins on this
2519 // hardware — see docs/bench/macos-2026-05-02 for the crossover
2520 // measurements (c=4 batched 39 < per_token 42; c=8 batched 59 >
2521 // per_token 47). `FERRUM_MOE_BATCHED=0` forces the legacy loop.
2522 let opted_in = std::env::var("FERRUM_MOE_BATCHED")
2523 .map(|v| v != "0")
2524 .unwrap_or(true);
2525 let threshold = std::env::var("FERRUM_MOE_BATCH_THRESHOLD")
2526 .ok()
2527 .and_then(|s| s.parse::<usize>().ok())
2528 .unwrap_or(8);
2529 if opted_in && m >= threshold {
2530 self.decode_batch_internal(batch)
2531 } else {
2532 batch
2533 .iter()
2534 .map(|(cid, tok, p)| self.decode(cid, *tok, *p))
2535 .collect()
2536 }
2537 }
2538
2539 fn release(&mut self, cache_id: &str) {
2540 let mut ctx = B::new_context();
2541 B::sync(&mut ctx);
2542 B::reset_graph(&mut ctx);
2543 B::sync(&mut ctx);
2544 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2545 // Paged mode: return the cache_id's blocks to the shared
2546 // allocator so other sequences can reuse them. Without this,
2547 // every request consumes max_blocks_per_seq blocks
2548 // permanently — pool exhausts after FERRUM_PAGED_MAX_SEQS
2549 // requests and subsequent ensure_kv panics with
2550 // "scratch residual missing" (the cascade panic from a
2551 // failed ensure_kv path leaving scratch poisoned).
2552 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2553 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2554 if let Some(c0) = caches.first() {
2555 if !c0.paged_block_indices.is_empty() {
2556 alloc.free(&c0.paged_block_indices);
2557 }
2558 }
2559 for c in caches.iter_mut() {
2560 c.paged_block_indices.clear();
2561 }
2562 }
2563 self.kv_free_pool.push(caches);
2564 }
2565 }
2566
2567 fn reset(&mut self) {
2568 let mut ctx = B::new_context();
2569 B::sync(&mut ctx);
2570 B::reset_graph(&mut ctx);
2571 B::sync(&mut ctx);
2572 self.kv_caches.clear();
2573 self.kv_free_pool.clear();
2574 }
2575}
2576
2577/// Batched MoE FFN — decode (m=1) and per-token-prefill (m>1 looped).
2578///
2579/// Three batched `gemv_quant_moe_id` dispatches per token: gate (broadcast
2580/// activation), up (broadcast activation), down (per-slot activation —
2581/// each expert sees its own silu·up). The per-(token, expert) outer loop
2582/// shrinks from `top_k * 4` dispatches per layer to **3 batched + 1
2583/// silu_mul_split + 1 weighted_sum_dispatch_loop**.
2584///
2585/// For prefill (m > 1) we loop over tokens externally — each token's
2586/// router output drives a single batched call. Still much faster than
2587/// the per-(token, expert) per-Linear path because the gemvs are batched.
2588///
2589/// Free function (not a method) so the caller can split the borrow on
2590/// `self` between `moe_layers[li]` (immutable) and `scratch` (mutable).
2591#[allow(clippy::too_many_arguments)]
2592fn moe_forward_stacked_decode_impl<B: Backend>(
2593 ctx: &mut B::Context,
2594 moe_layer: &Qwen3MoeLayerState<B>,
2595 scratch: &mut Qwen3MoeScratch<B>,
2596 h: usize,
2597 inter: usize,
2598 top_k: usize,
2599 n_exp: usize,
2600 norm_topk_prob: bool,
2601 tokens: usize,
2602 residual: &mut B::Buffer,
2603 // If `Some`, fold the NEXT layer's leading rms_norm into the
2604 // weighted-sum-residual tail using `weighted_sum_residual_norm_stacked`.
2605 next_norm_w: Option<&B::Buffer>,
2606 eps: f32,
2607) -> Result<()> {
2608 // GPU-side routing: one Metal launch reads router_logits and writes
2609 // selected ids + combine weights directly into device-side scratch
2610 // buffers. Eliminates the per-layer `B::sync + B::to_vec(router_logits)
2611 // + host route()` round trip — the dominant remaining cost in the
2612 // decode hot path (~10% of total decode latency).
2613 let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
2614 let stage_t0 = || -> Option<std::time::Instant> {
2615 if prof {
2616 Some(std::time::Instant::now())
2617 } else {
2618 None
2619 }
2620 };
2621 let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
2622 if let Some(t) = t0 {
2623 B::sync(ctx);
2624 c.fetch_add(
2625 t.elapsed().as_micros() as u64,
2626 std::sync::atomic::Ordering::Relaxed,
2627 );
2628 }
2629 };
2630
2631 let t0 = stage_t0();
2632 B::route_topk_softmax(
2633 ctx,
2634 &scratch.router_logits,
2635 &mut scratch.ids_buf,
2636 &mut scratch.weights_buf,
2637 tokens,
2638 n_exp,
2639 top_k,
2640 norm_topk_prob,
2641 )?;
2642 stage_end(t0, ctx, &DEC_ROUTE_US);
2643
2644 let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2645 let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2646 let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2647
2648 // moe_forward_stacked_decode_impl is only called when `tokens == 1`
2649 // (the branch in `forward_layer` routes prefill m>1 through
2650 // `moe_forward_batched_prefill_impl` instead). The for-b loop and
2651 // the copy norm_out[b*h] → x_single were vestigial scaffolding;
2652 // for tokens=1 norm_out[0..h] IS the activation row, and we can
2653 // pass it straight to the gemv kernel via src1_stride=0 broadcast.
2654 debug_assert_eq!(
2655 tokens, 1,
2656 "moe_forward_stacked_decode_impl expects tokens=1 (prefill goes through moe_forward_batched_prefill_impl)"
2657 );
2658 let _ = tokens; // silence unused-warning when assertion is compiled out
2659
2660 {
2661 // ids_buf and weights_buf populated by the GPU router above —
2662 // no host writes needed here in the decode path.
2663
2664 // Fused-vs-unfused gate+up+silu selection.
2665 //
2666 // Default: when the backend advertises support (Metal Q4KExperts),
2667 // run the single fused dispatch — saves 2 dispatches and the
2668 // entire round-trip through gate_out_stacked / up_out_stacked
2669 // scratch (≈4× [top_k, ffn] of intermediate bandwidth).
2670 //
2671 // Opt-out: `FERRUM_MOE_FUSED_GATE_UP_SILU=0` forces the legacy
2672 // 3-dispatch path. Used for A/B benchmarking and as a kill switch
2673 // if the fused kernel ever produces divergent outputs.
2674 // Cache the env-flag read once per process — the decode hot
2675 // path calls this fn ~48 layers × ~steps_per_run times.
2676 static FUSED_DISABLED: OnceLock<bool> = OnceLock::new();
2677 let fused_disabled = *FUSED_DISABLED
2678 .get_or_init(|| std::env::var("FERRUM_MOE_FUSED_GATE_UP_SILU").as_deref() == Ok("0"));
2679 let use_fused = B::supports_fused_moe_gate_up_silu() && !fused_disabled;
2680
2681 if use_fused {
2682 // 1+2+3 fused: silu_stacked = SiLU(gate · norm_out) * (up · norm_out)
2683 let t0 = stage_t0();
2684 B::gemv_quant_moe_id_gate_up_silu(
2685 ctx,
2686 &scratch.norm_out,
2687 gate_stacked,
2688 up_stacked,
2689 &scratch.ids_buf,
2690 &mut scratch.silu_stacked,
2691 top_k,
2692 )?;
2693 stage_end(t0, ctx, &DEC_SILU_US);
2694 } else {
2695 // 1. Batched gate gemv — broadcast input across top_k slots.
2696 // src1 = norm_out (which has hidden floats at offset 0),
2697 // src1_stride=0 → all slots read the same row.
2698 let t0 = stage_t0();
2699 B::gemv_quant_moe_id(
2700 ctx,
2701 &scratch.norm_out,
2702 gate_stacked,
2703 &scratch.ids_buf,
2704 &mut scratch.gate_out_stacked,
2705 top_k,
2706 0, // broadcast
2707 )?;
2708 stage_end(t0, ctx, &DEC_GATE_US);
2709
2710 // 2. Batched up gemv — also broadcast.
2711 let t0 = stage_t0();
2712 B::gemv_quant_moe_id(
2713 ctx,
2714 &scratch.norm_out,
2715 up_stacked,
2716 &scratch.ids_buf,
2717 &mut scratch.up_out_stacked,
2718 top_k,
2719 0,
2720 )?;
2721 stage_end(t0, ctx, &DEC_UP_US);
2722
2723 // 3. Stacked SiLU·gate → silu_stacked. Single dispatch covers
2724 // all top_k slots — replaces the per-slot loop's
2725 // (3 copy_slice + 1 silu_mul) × 8 = 32 dispatches.
2726 let t0 = stage_t0();
2727 B::silu_mul_stacked(
2728 ctx,
2729 &scratch.gate_out_stacked,
2730 &scratch.up_out_stacked,
2731 &mut scratch.silu_stacked,
2732 top_k,
2733 inter,
2734 )?;
2735 stage_end(t0, ctx, &DEC_SILU_US);
2736 }
2737
2738 // 4. Batched down gemv — per-slot input via src1_stride = inter.
2739 // silu_stacked[k * inter ..] is the activation row for slot k.
2740 let t0 = stage_t0();
2741 B::gemv_quant_moe_id(
2742 ctx,
2743 &scratch.silu_stacked,
2744 down_stacked,
2745 &scratch.ids_buf,
2746 &mut scratch.down_out_stacked,
2747 top_k,
2748 inter,
2749 )?;
2750 stage_end(t0, ctx, &DEC_DOWN_US);
2751
2752 // 5. Fused weighted-sum + residual-add (+ optional next-layer
2753 // rms_norm). Two paths:
2754 //
2755 // * `next_norm_w = Some(_)` (cross-layer fusion): one kernel
2756 // computes residual[i] += Σ_k w[k] · down[k, i] AND
2757 // norm_out[i] = residual[i] · scale · next_norm_w[i].
2758 // The next layer's leading rms_norm is skipped. Saves an
2759 // additional dispatch per layer transition.
2760 // * `next_norm_w = None` (last layer): just residual-add.
2761 let t0 = stage_t0();
2762 if let Some(nnw) = next_norm_w {
2763 B::weighted_sum_residual_norm_stacked(
2764 ctx,
2765 &scratch.down_out_stacked,
2766 &scratch.weights_buf,
2767 residual,
2768 nnw,
2769 &mut scratch.norm_out,
2770 top_k,
2771 h,
2772 eps,
2773 )?;
2774 } else {
2775 B::weighted_sum_residual_stacked(
2776 ctx,
2777 &scratch.down_out_stacked,
2778 &scratch.weights_buf,
2779 residual,
2780 top_k,
2781 h,
2782 )?;
2783 }
2784 stage_end(t0, ctx, &DEC_WSUM_US);
2785 }
2786
2787 Ok(())
2788}
2789
2790/// Batched MoE FFN for prefill (m > 1).
2791///
2792/// One pass through the expert dispatch — replaces the per-token loop
2793/// with three batched 2-D mul_mm_id dispatches (gate, up, down) where
2794/// each expert's slab of (token, slot) pairs runs as one gemm tile.
2795/// Per-layer dispatch count: ~6 (router + 3 mul_mm_id + silu + wsum)
2796/// independent of `tokens`. Compare to the decode-style stacked path
2797/// that emits ~10 per token.
2798///
2799/// Free function so the caller can split the borrow on `self` between
2800/// `moe_layers[li]` (immutable) and `scratch` (mutable).
2801#[allow(clippy::too_many_arguments)]
2802fn moe_forward_batched_prefill_impl<B: Backend>(
2803 ctx: &mut B::Context,
2804 moe_layer: &Qwen3MoeLayerState<B>,
2805 scratch: &mut Qwen3MoeScratch<B>,
2806 h: usize,
2807 inter: usize,
2808 top_k: usize,
2809 n_exp: usize,
2810 norm_topk_prob: bool,
2811 tokens: usize,
2812) -> Result<()> {
2813 let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
2814 let stage_t0 = || -> Option<std::time::Instant> {
2815 if prof {
2816 Some(std::time::Instant::now())
2817 } else {
2818 None
2819 }
2820 };
2821 let stage_end =
2822 |t0: Option<std::time::Instant>, ctx: &mut B::Context, us: &AtomicU64, n: &AtomicU64| {
2823 if let Some(t) = t0 {
2824 B::sync(ctx);
2825 us.fetch_add(
2826 t.elapsed().as_micros() as u64,
2827 std::sync::atomic::Ordering::Relaxed,
2828 );
2829 n.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2830 }
2831 };
2832
2833 // GPU-side routing: keep the whole pipeline device-resident. Two
2834 // dispatches replace the per-layer `B::sync + to_vec(router_logits)
2835 // + host route() + host compute_ids_tpe + write_back` round trip.
2836 //
2837 // 1. `route_topk_softmax` writes selected expert IDs (flat
2838 // `[batch, top_k]`) into `selected_ids_buf` and the post-renorm
2839 // combine weights directly into `weights_2d`.
2840 // 2. `compute_ids_tpe_gpu` buckets those pairs into `tpe_buf` and
2841 // `ids_2d` using device-side atomic_fetch_add slot claims. The
2842 // `ids_2d` row stride is the worst-case `tokens * top_k`; the
2843 // consumer GEMM stops at `tpe[e]` so the over-strided columns
2844 // cost only launch overhead, not real compute.
2845 //
2846 // `FERRUM_MOE_HOST_TOPK=1` → legacy CPU softmax+topk+bucket
2847 // `FERRUM_MOE_DIRECT_DISPATCH=1` → GPU topk but worst-case GEMM grid
2848 // (default) → GPU topk + indirect-dispatched GEMM
2849 // (grid sized from max(tpe[e]))
2850 let use_gpu_topk = std::env::var("FERRUM_MOE_HOST_TOPK").as_deref() != Ok("1");
2851 let use_indirect_dispatch =
2852 use_gpu_topk && std::env::var("FERRUM_MOE_DIRECT_DISPATCH").as_deref() != Ok("1");
2853 let max_per_expert = if use_gpu_topk {
2854 let t0 = stage_t0();
2855 B::route_topk_softmax(
2856 ctx,
2857 &scratch.router_logits,
2858 &mut scratch.selected_ids_buf,
2859 &mut scratch.weights_2d,
2860 tokens,
2861 n_exp,
2862 top_k,
2863 norm_topk_prob,
2864 )?;
2865 B::compute_ids_tpe_gpu(
2866 ctx,
2867 &scratch.selected_ids_buf,
2868 &mut scratch.tpe_buf,
2869 &mut scratch.ids_2d,
2870 &mut scratch.gate_up_args_buf,
2871 &mut scratch.down_args_buf,
2872 tokens,
2873 n_exp,
2874 top_k,
2875 inter,
2876 h,
2877 )?;
2878 stage_end(
2879 t0,
2880 ctx,
2881 &MOE_PREFILL_HOST_TOPK_US,
2882 &MOE_PREFILL_HOST_TOPK_CALLS,
2883 );
2884 // Worst-case ids row stride; matches `dispatch_compute_ids_tpe`.
2885 tokens * top_k
2886 } else {
2887 use ferrum_kernels::moe_host::compute_ids_tpe;
2888 let t0 = stage_t0();
2889 B::sync(ctx);
2890 let logits_host = B::to_vec(&scratch.router_logits, tokens * n_exp);
2891 let route = crate::moe::router::route(&logits_host, tokens, n_exp, top_k, norm_topk_prob);
2892 let (tpe_host, ids_host, max_per_expert) =
2893 compute_ids_tpe(&route.expert_ids, n_exp, tokens, top_k);
2894 B::write_i32_into(&mut scratch.tpe_buf, &tpe_host);
2895 B::write_i32_into(&mut scratch.ids_2d, &ids_host);
2896 B::write_f32_into(&mut scratch.weights_2d, &route.expert_weights);
2897 stage_end(
2898 t0,
2899 ctx,
2900 &MOE_PREFILL_HOST_TOPK_US,
2901 &MOE_PREFILL_HOST_TOPK_CALLS,
2902 );
2903 max_per_expert
2904 };
2905
2906 let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2907 let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2908 let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2909
2910 // 1. Batched gate gemm — one launch covers all (token, expert) pairs.
2911 // src1 layout: [batch, ne11=1, K] (broadcast: each pair reads its
2912 // token's row, slot index ignored).
2913 // dst layout: [batch, top_k, expert_inter] — natural.
2914 let t0 = stage_t0();
2915 if use_indirect_dispatch {
2916 B::gemm_quant_moe_id_indirect(
2917 ctx,
2918 &scratch.norm_out,
2919 gate_stacked,
2920 &scratch.ids_2d,
2921 &scratch.tpe_buf,
2922 &mut scratch.gate_out_stacked,
2923 &scratch.gate_up_args_buf,
2924 1, // ne11 = 1: broadcast
2925 top_k,
2926 max_per_expert,
2927 tokens,
2928 )?;
2929 } else {
2930 B::gemm_quant_moe_id(
2931 ctx,
2932 &scratch.norm_out,
2933 gate_stacked,
2934 &scratch.ids_2d,
2935 &scratch.tpe_buf,
2936 &mut scratch.gate_out_stacked,
2937 1,
2938 top_k,
2939 max_per_expert,
2940 tokens,
2941 )?;
2942 }
2943 stage_end(t0, ctx, &MOE_PREFILL_GATE_US, &MOE_PREFILL_GATE_CALLS);
2944
2945 // 2. Batched up gemm — same shape as gate.
2946 let t0 = stage_t0();
2947 if use_indirect_dispatch {
2948 B::gemm_quant_moe_id_indirect(
2949 ctx,
2950 &scratch.norm_out,
2951 up_stacked,
2952 &scratch.ids_2d,
2953 &scratch.tpe_buf,
2954 &mut scratch.up_out_stacked,
2955 &scratch.gate_up_args_buf,
2956 1,
2957 top_k,
2958 max_per_expert,
2959 tokens,
2960 )?;
2961 } else {
2962 B::gemm_quant_moe_id(
2963 ctx,
2964 &scratch.norm_out,
2965 up_stacked,
2966 &scratch.ids_2d,
2967 &scratch.tpe_buf,
2968 &mut scratch.up_out_stacked,
2969 1,
2970 top_k,
2971 max_per_expert,
2972 tokens,
2973 )?;
2974 }
2975 stage_end(t0, ctx, &MOE_PREFILL_UP_US, &MOE_PREFILL_UP_CALLS);
2976
2977 // 3. SiLU·gate over [tokens * top_k, expert_inter] flat layout.
2978 let total_pairs = tokens * top_k;
2979 let t0 = stage_t0();
2980 B::silu_mul_batched(
2981 ctx,
2982 &scratch.gate_out_stacked,
2983 &scratch.up_out_stacked,
2984 &mut scratch.silu_stacked,
2985 total_pairs,
2986 inter,
2987 )?;
2988 stage_end(t0, ctx, &MOE_PREFILL_SILU_US, &MOE_PREFILL_SILU_CALLS);
2989
2990 // 4. Batched down gemm — src1 is [batch, top_k, expert_inter] from
2991 // silu_stacked. ne11 = top_k → each pair reads its own row.
2992 let t0 = stage_t0();
2993 if use_indirect_dispatch {
2994 B::gemm_quant_moe_id_indirect(
2995 ctx,
2996 &scratch.silu_stacked,
2997 down_stacked,
2998 &scratch.ids_2d,
2999 &scratch.tpe_buf,
3000 &mut scratch.down_out_stacked,
3001 &scratch.down_args_buf,
3002 top_k, // ne11 = top_k: per-slot
3003 top_k,
3004 max_per_expert,
3005 tokens,
3006 )?;
3007 } else {
3008 B::gemm_quant_moe_id(
3009 ctx,
3010 &scratch.silu_stacked,
3011 down_stacked,
3012 &scratch.ids_2d,
3013 &scratch.tpe_buf,
3014 &mut scratch.down_out_stacked,
3015 top_k,
3016 top_k,
3017 max_per_expert,
3018 tokens,
3019 )?;
3020 }
3021 stage_end(t0, ctx, &MOE_PREFILL_DOWN_US, &MOE_PREFILL_DOWN_CALLS);
3022
3023 // 5. Per-batch weighted sum: moe_out[b, h] = Σ_k w[b,k] · down[b,k,h]
3024 let t0 = stage_t0();
3025 B::weighted_sum_batched(
3026 ctx,
3027 &scratch.down_out_stacked,
3028 &scratch.weights_2d,
3029 &mut scratch.moe_out,
3030 tokens,
3031 top_k,
3032 h,
3033 )?;
3034 stage_end(t0, ctx, &MOE_PREFILL_WSUM_US, &MOE_PREFILL_WSUM_CALLS);
3035
3036 Ok(())
3037}
3038
3039/// Batched MoE FFN for the **small-m decode** range (typically c=2..32).
3040///
3041/// Mirrors llama.cpp's `kernel_mul_mv_id` strategy: hold the dispatch
3042/// count flat as concurrency scales by emitting **one** batched GEMV
3043/// per linear (gate / up / down) that covers all `m * top_k`
3044/// (token, expert) pairs in a single Metal launch. Replaces the
3045/// per-token outer loop in `forward_layer` (which emitted ~5
3046/// dispatches × m tokens per layer) with a fixed-shape pipeline.
3047///
3048/// Compared to [`moe_forward_batched_prefill_impl`]:
3049/// * no `compute_ids_tpe_gpu` bucketing kernel (the new pair-indexed
3050/// GEMV reads `selected_ids_buf` directly)
3051/// * uses GEMV not GEMM (better tile utilisation when tokens-per-expert
3052/// is small — at c=16 with top_k=8 each expert sees ~1-3 token rows,
3053/// well below the simdgroup_matmul tile width)
3054/// * fewer Metal dispatches per layer (5: route + 3 gemv + silu + wsum)
3055///
3056/// Per-layer dispatch budget: 5 (independent of m). At c=16 / 48 layers
3057/// that's 240 dispatches per decode step vs the per-token loop's ~3,840.
3058#[allow(clippy::too_many_arguments)]
3059fn moe_forward_batched_decode_impl<B: Backend>(
3060 ctx: &mut B::Context,
3061 moe_layer: &Qwen3MoeLayerState<B>,
3062 scratch: &mut Qwen3MoeScratch<B>,
3063 h: usize,
3064 inter: usize,
3065 top_k: usize,
3066 n_exp: usize,
3067 norm_topk_prob: bool,
3068 tokens: usize,
3069) -> Result<()> {
3070 let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
3071 let stage_t0 = || -> Option<std::time::Instant> {
3072 if prof {
3073 Some(std::time::Instant::now())
3074 } else {
3075 None
3076 }
3077 };
3078 let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
3079 if let Some(t) = t0 {
3080 B::sync(ctx);
3081 c.fetch_add(
3082 t.elapsed().as_micros() as u64,
3083 std::sync::atomic::Ordering::Relaxed,
3084 );
3085 }
3086 };
3087
3088 let total_pairs = tokens * top_k;
3089
3090 // 1. Single batched router pass — fills selected_ids_buf [m * top_k]
3091 // and weights_2d [m * top_k] in one Metal dispatch.
3092 let t0 = stage_t0();
3093 B::route_topk_softmax(
3094 ctx,
3095 &scratch.router_logits,
3096 &mut scratch.selected_ids_buf,
3097 &mut scratch.weights_2d,
3098 tokens,
3099 n_exp,
3100 top_k,
3101 norm_topk_prob,
3102 )?;
3103 stage_end(t0, ctx, &MOE_BATCHED_DECODE_ROUTE_US);
3104
3105 let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
3106 let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
3107 let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
3108
3109 // 2+3+4. Fused gate+up+silu — single Metal dispatch covers all
3110 // m*top_k pairs. Falls back to the 3-dispatch sequence on backends
3111 // that don't have the fused-batched kernel.
3112 if B::supports_batched_moe_gate_up_silu() {
3113 let t0 = stage_t0();
3114 B::gemv_quant_moe_id_gate_up_silu_batched(
3115 ctx,
3116 &scratch.norm_out,
3117 gate_stacked,
3118 up_stacked,
3119 &scratch.selected_ids_buf,
3120 &mut scratch.silu_stacked,
3121 tokens,
3122 top_k,
3123 h, // outer stride: K floats per token
3124 0, // inner stride: 0 (slots within a token broadcast)
3125 )?;
3126 // Charge the whole fused step to the SiLU bucket — keeps the
3127 // profile counter additive with the unfused path's silu line.
3128 stage_end(t0, ctx, &MOE_BATCHED_DECODE_SILU_US);
3129 } else {
3130 // 2. Batched gate gemv — one launch covers all m*top_k pairs.
3131 let t0 = stage_t0();
3132 B::gemv_quant_moe_id_batched(
3133 ctx,
3134 &scratch.norm_out,
3135 gate_stacked,
3136 &scratch.selected_ids_buf,
3137 &mut scratch.gate_out_stacked,
3138 tokens,
3139 top_k,
3140 h,
3141 0,
3142 )?;
3143 stage_end(t0, ctx, &MOE_BATCHED_DECODE_GATE_US);
3144
3145 // 3. Batched up gemv.
3146 let t0 = stage_t0();
3147 B::gemv_quant_moe_id_batched(
3148 ctx,
3149 &scratch.norm_out,
3150 up_stacked,
3151 &scratch.selected_ids_buf,
3152 &mut scratch.up_out_stacked,
3153 tokens,
3154 top_k,
3155 h,
3156 0,
3157 )?;
3158 stage_end(t0, ctx, &MOE_BATCHED_DECODE_UP_US);
3159
3160 // 4. SiLU·gate.
3161 let t0 = stage_t0();
3162 B::silu_mul_batched(
3163 ctx,
3164 &scratch.gate_out_stacked,
3165 &scratch.up_out_stacked,
3166 &mut scratch.silu_stacked,
3167 total_pairs,
3168 inter,
3169 )?;
3170 stage_end(t0, ctx, &MOE_BATCHED_DECODE_SILU_US);
3171 }
3172
3173 // 5. Batched down gemv — src1 = silu_stacked [m, top_k, ffn]: each
3174 // pair has its own row, outer = top_k * ffn, inner = ffn.
3175 let t0 = stage_t0();
3176 B::gemv_quant_moe_id_batched(
3177 ctx,
3178 &scratch.silu_stacked,
3179 down_stacked,
3180 &scratch.selected_ids_buf,
3181 &mut scratch.down_out_stacked,
3182 tokens,
3183 top_k,
3184 top_k * inter, // outer: top_k * ffn floats per token
3185 inter, // inner: ffn floats per slot
3186 )?;
3187 stage_end(t0, ctx, &MOE_BATCHED_DECODE_DOWN_US);
3188
3189 // 6. Per-token weighted sum across slots → moe_out [m, h]. Caller
3190 // does residual += moe_out at the end of forward_layer.
3191 let t0 = stage_t0();
3192 B::weighted_sum_batched(
3193 ctx,
3194 &scratch.down_out_stacked,
3195 &scratch.weights_2d,
3196 &mut scratch.moe_out,
3197 tokens,
3198 top_k,
3199 h,
3200 )?;
3201 stage_end(t0, ctx, &MOE_BATCHED_DECODE_WSUM_US);
3202
3203 Ok(())
3204}
3205
3206/// Build a stub Linear<B> with the given shape but zero weights. Used to
3207/// fill the dense `gate_up_proj` / `down_proj` slots in `LlamaFamilyLayer`
3208/// for MoE models — those slots are never invoked because the MoE FFN
3209/// path runs through `moe_layer.experts` instead. The stub's only purpose
3210/// is to satisfy the struct's type signature with minimal memory cost.
3211fn stub_linear<B: Backend>(
3212 out_features: usize,
3213 in_features: usize,
3214) -> Box<dyn ferrum_quantization::Linear<B>> {
3215 // Zero-init: out_features * in_features f32. For a 30B-A3B layer
3216 // this is 2*768*2048 = 3.1M f32 = 12 MB → fine; per-layer overhead
3217 // ≈ 12 MB × 48 = 576 MB. Marginal vs the experts (~16 GB).
3218 let zeros = vec![0.0f32; out_features * in_features];
3219 Box::new(ferrum_quantization::DenseLinear::<B>::from_rows(
3220 &zeros,
3221 out_features,
3222 in_features,
3223 ))
3224}
3225
3226fn build_rope_cache<B: Backend>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
3227 let hd = cfg.head_dim;
3228 let half = hd / 2;
3229 let max = cfg.max_seq_len;
3230 let mut cos = vec![0.0f32; max * half];
3231 let mut sin = vec![0.0f32; max * half];
3232 for pos in 0..max {
3233 for i in 0..half {
3234 let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
3235 let angle = pos as f64 * freq;
3236 cos[pos * half + i] = angle.cos() as f32;
3237 sin[pos * half + i] = angle.sin() as f32;
3238 }
3239 }
3240 RopeCache {
3241 cos: B::from_slice(&cos),
3242 sin: B::from_slice(&sin),
3243 }
3244}