ferrum_models/models/qwen3_moe.rs
1//! `Qwen3MoeModel<B>` — Qwen3-MoE family decoder (Qwen3-30B-A3B and friends).
2//!
3//! Architectural delta vs [`LlamaFamilyModel`]:
4//! * Each transformer layer's FFN is a top-K MoE block instead of a
5//! fused `gate_up_proj → silu → down_proj` MLP.
6//! - One small router linear (`[hidden] → [num_experts]`) picks
7//! top-K experts per token.
8//! - Each expert is itself a fused `gate_up + down` MLP with the
9//! same SwiGLU + RMSNorm structure as the dense path, just with
10//! `expert_intermediate_size` (typically much smaller than the
11//! dense `intermediate_size`).
12//! - Output is the weight-summed combination of the K selected
13//! expert outputs.
14//! * Attention path is unchanged from dense Qwen3 (GQA + QK-norm + RoPE).
15//!
16//! Implementation re-uses the dense layer's attention machinery
17//! verbatim — RMSNorm, fused QKV, QK-norm + RoPE, KV cache append,
18//! flash attention, O-projection, residual + post-norm. The only new
19//! code is the MoE FFN block at the tail of each layer's forward.
20//!
21//! Memory model: experts are loaded as `QuantLinear<B>` per expert,
22//! slicing the on-disk 3-D `ffn_{gate,up,down}_exps.weight` tensors
23//! byte-wise so weights stay compressed (Q4_K / Q6_K). For a 32 GB
24//! Mac to run Qwen3-30B-A3B at all, this is non-negotiable: an
25//! eager-fp32 expert stack would weigh ~110 GB.
26
27use std::collections::HashMap;
28use std::sync::atomic::AtomicU64;
29use std::sync::OnceLock;
30
31use ferrum_kernels::backend::{Backend, KvCache};
32use ferrum_quantization::WeightLoader;
33use ferrum_types::{FerrumError, Result};
34
35use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
36use crate::models::llama_family::{LlamaFamilyConfig, LlamaFamilyLayer, RopeCache};
37use crate::moe::{moe_forward, ExpertStack};
38use crate::moe_config::Qwen3MoeConfig;
39
// Declares one zero-initialised `AtomicU64` static per identifier. Every
// profiling counter in this file has the same type and initial value, so
// the sections below only list names.
macro_rules! decode_profile_counters {
    ($($name:ident),+ $(,)?) => {
        $(static $name: AtomicU64 = AtomicU64::new(0);)+
    };
}

// Decode-side per-op profile counters. The names match the dense path so
// existing tooling (`FERRUM_DECODE_OP_PROFILE=1` log scrapers) keeps
// working without a separate switch for MoE.
decode_profile_counters!(ATTN_TIME_US, ATTN_CALLS, MOE_TIME_US, MOE_CALLS);

// Fine-grained decode-only counters, populated by
// `moe_forward_stacked_decode_impl` when FERRUM_DECODE_OP_PROFILE is set.
// Each is summed over the layers of one decode token and drained at the
// bottom of `decode_internal`.
decode_profile_counters!(
    DEC_ROUTE_US,
    DEC_GATE_US,
    DEC_UP_US,
    DEC_SILU_US,
    DEC_DOWN_US,
    DEC_WSUM_US,
);

// Single-shot per decode token (not per-layer).
decode_profile_counters!(DEC_EMBED_US, DEC_FINAL_NORM_US, DEC_LM_HEAD_US);

// MoE batched-prefill sub-stage counters (gate / up / down mul_mm_id +
// silu + weighted_sum + host topk). Same FERRUM_DECODE_OP_PROFILE gate.
decode_profile_counters!(
    MOE_PREFILL_HOST_TOPK_US,
    MOE_PREFILL_HOST_TOPK_CALLS,
    MOE_PREFILL_GATE_US,
    MOE_PREFILL_GATE_CALLS,
    MOE_PREFILL_UP_US,
    MOE_PREFILL_UP_CALLS,
    MOE_PREFILL_SILU_US,
    MOE_PREFILL_SILU_CALLS,
    MOE_PREFILL_DOWN_US,
    MOE_PREFILL_DOWN_CALLS,
    MOE_PREFILL_WSUM_US,
    MOE_PREFILL_WSUM_CALLS,
);

// MoE batched-DECODE sub-stage counters (small-m path that uses the
// batched-pair GEMV in place of the per-token loop).
decode_profile_counters!(
    MOE_BATCHED_DECODE_ROUTE_US,
    MOE_BATCHED_DECODE_GATE_US,
    MOE_BATCHED_DECODE_UP_US,
    MOE_BATCHED_DECODE_SILU_US,
    MOE_BATCHED_DECODE_DOWN_US,
    MOE_BATCHED_DECODE_WSUM_US,
);

// Coarse stage counters for `forward_layer_batched_decode`, so the time
// split is visible without per-op instrumentation. Summed across all
// layers in one decode_batch_internal call.
decode_profile_counters!(
    // rms_norm + qkv_proj + split_qkv + o_proj + fused_add_rms_norm
    BD_DENSE_US,
    // the for-i in 0..m attention loop (incl. plumbing)
    BD_ATTN_PERITEM_US,
    // router + MoE FFN + residual add
    BD_MOE_US,
    BD_LAYER_CALLS,
);
94
/// Per-layer MoE state: router linear (small) + per-expert MLP stack.
///
/// `Qwen3MoeModel::new` builds one of these per transformer layer and
/// stores them in `Qwen3MoeModel::moe_layers`, indexed by layer.
pub struct Qwen3MoeLayerState<B: Backend> {
    /// Router projection `[hidden] → [num_experts]` — tiny, never sparse,
    /// always runs the full GEMV. `Qwen3MoeModel::new` rejects a router
    /// whose `in_features` / `out_features` do not match `hidden_size` /
    /// `num_experts`.
    pub router: Box<dyn ferrum_quantization::Linear<B>>,
    /// Per-expert weight stack. Each entry's `gate_up` is the fused
    /// `[gate; up]` projection; `down` is the post-SwiGLU output proj.
    /// Loaded from the GGUF file via `ExpertStack::load_from_gguf`.
    pub experts: ExpertStack<B>,
}
104
/// Reusable scratch buffers for the MoE forward path. The non-`Option`
/// buffers are sized at allocation time (see `alloc`) from the model
/// config and `max_tokens`, and reused across layers / forward calls.
/// The `Option` groups near the bottom start as `None` and are filled in
/// later by the `enable_*` methods.
///
/// Sizes quoted below are element counts as passed to the backend
/// allocator in `alloc`; `q_dim = num_heads * head_dim` and
/// `kv_dim = num_kv_heads * head_dim`.
pub struct Qwen3MoeScratch<B: Backend> {
    // ── Attention scratch ────────────────────────────────────────────
    /// See [`crate::models::llama_family::LlamaFamilyScratch`] for the
    /// attention scratch — we re-use those names verbatim.
    /// `alloc` creates this as `Some` with `max_tokens * hidden` elements.
    pub residual: Option<B::Buffer>,
    /// RMSNorm output, `max_tokens * hidden`.
    pub norm_out: B::Buffer,
    /// Fused QKV projection output, `max_tokens * (q_dim + 2 * kv_dim)`.
    pub qkv_out: B::Buffer,
    /// Split Q, `max_tokens * q_dim`.
    pub q_buf: B::Buffer,
    /// Split K, `max_tokens * kv_dim`.
    pub k_buf: B::Buffer,
    /// Split V, `max_tokens * kv_dim`.
    pub v_buf: B::Buffer,
    /// Head-major Q, `num_heads * max_tokens * head_dim`.
    pub q_head_major: B::Buffer,
    /// Head-major K, `num_kv_heads * max_tokens * head_dim`.
    pub k_head_major: B::Buffer,
    /// Head-major V, `num_kv_heads * max_tokens * head_dim`.
    pub v_head_major: B::Buffer,
    /// Head-major attention output, `num_heads * max_tokens * head_dim`.
    pub attn_head_major_out: B::Buffer,
    /// Token-major attention output, `max_tokens * q_dim`.
    pub attn_flat: B::Buffer,
    /// O-projection output, `max_tokens * hidden`.
    pub o_proj_out: B::Buffer,

    // ── MoE-specific scratch ─────────────────────────────────────────
    /// Router logits for the whole batch: `[max_tokens, num_experts]`.
    pub router_logits: B::Buffer,
    /// Per-(token, expert) gate||up projection output — `[2 * expert_inter]`.
    pub gate_up_buf: B::Buffer,
    /// SiLU(gate) * up scratch — `[expert_inter]`.
    pub silu_buf: B::Buffer,
    /// Per-(token, expert) down-projection output — `[hidden]`.
    pub down_buf: B::Buffer,
    /// Per-token input row scratch — `[hidden]`. Holds the post-RMSNorm
    /// activation slice that the per-(expert) gate_up gemv reads, kept
    /// stable across the entire top_k loop for one token.
    pub x_single: B::Buffer,
    /// Per-token output accumulator — `[hidden]`. Holds the running
    /// `Σ_k weight_k · expert_k(x[b])` sum that grows across the top_k
    /// loop and is flushed to `moe_out[b]` once per token.
    pub acc_buf: B::Buffer,
    /// MoE output `[max_tokens, hidden]`. Zeroed each forward.
    pub moe_out: B::Buffer,
    /// Pre-allocated `[hidden]` zero scratch — `acc_buf` is reset to
    /// this each token without going through `B::from_slice` on the
    /// hot path. `alloc` fills it from a slice of `0.0f32`.
    pub zero_hidden: B::Buffer,

    // ── MoE batched-fast-path scratch (Metal `gemv_q*kw_moe_id_f32` /
    //    `gemm_q*kw_moe_id_f32`) ─────────────────────────────────────
    //
    // Sized for `max_tokens * top_k * X` so the same buffers cover both
    // decode (m=1, uses the first `top_k * X` slice) and prefill
    // (m>1, uses the full `max_tokens * top_k * X`). Decode-only
    // workloads pay no extra memory because `max_tokens` was 1 there.
    // `top_k` is `cfg.num_experts_per_tok`.
    /// `[max_tokens * top_k * expert_inter]` — gate gemm output per pair.
    pub gate_out_stacked: B::Buffer,
    /// `[max_tokens * top_k * expert_inter]` — up gemm output per pair.
    pub up_out_stacked: B::Buffer,
    /// `[max_tokens * top_k * expert_inter]` — SiLU(gate)·up per pair.
    pub silu_stacked: B::Buffer,
    /// `[max_tokens * top_k * hidden]` — down gemm output per pair.
    pub down_out_stacked: B::Buffer,
    /// `[top_k]` i32 expert IDs for the current token (decode reuses;
    /// prefill writes per-pair indices into `ids_2d` instead).
    pub ids_buf: B::Buffer,
    /// `[top_k]` f32 router combine weights for the current decode
    /// token. Decode hot-path uses `write_f32_into` to update.
    pub weights_buf: B::Buffer,
    /// `[max_tokens * top_k]` i32 — flat selected-expert IDs from the
    /// GPU router for the prefill batch. Consumed by `compute_ids_tpe_gpu`
    /// to bucket pairs by expert into `tpe_buf` / `ids_2d`.
    pub selected_ids_buf: B::Buffer,
    /// `[3]` u32 indirect-dispatch args (`grid_x, grid_y, grid_z`) for
    /// the gate / up MoE GEMM. Written by `compute_ids_tpe_gpu` so the
    /// consumer GEMM grid covers exactly `max(tpe[e])` columns instead
    /// of the worst-case `tokens * top_k`. Allocated in `alloc` as three
    /// zeroed i32s so `from_slice_i32` can be reused.
    pub gate_up_args_buf: B::Buffer,
    /// Same shape as `gate_up_args_buf` but for the down MoE GEMM
    /// (different `grid_y` because down's `M = hidden_size` vs gate/up's
    /// `M = expert_intermediate_size`).
    pub down_args_buf: B::Buffer,
    /// `[num_experts * max_per_expert_max]` i32 — per-expert pair
    /// index lists for prefill 2-D mul_mm_id. `max_per_expert_max`
    /// is bounded by `max_tokens * top_k` (worst-case: one expert
    /// gets every pair). Sized at scratch alloc time as
    /// `num_experts * max_tokens * top_k`.
    pub ids_2d: B::Buffer,
    /// `[num_experts]` i32 — `tpe[e]` = number of pairs assigned to
    /// expert `e`. Companion to `ids_2d`.
    pub tpe_buf: B::Buffer,
    /// `[max_tokens * top_k]` f32 — combine weights per pair, in
    /// natural `[batch, top_k]` layout for `weighted_sum_batched`.
    pub weights_2d: B::Buffer,

    // ── Final-token / lm_head outputs ────────────────────────────────
    /// `[hidden]`.
    pub last_hidden: B::Buffer,
    /// `[hidden]`.
    pub last_normed: B::Buffer,
    /// `[vocab_size]`.
    pub logits: B::Buffer,
    /// `[max_tokens * vocab_size]`.
    pub batch_logits: B::Buffer,

    // ── Per-item single-token buffers for decode_batch (Phase 4b) ────
    //
    // The batched-decode path runs M GEMMs at m=M (qkv_proj / o_proj /
    // router / MoE expert mul_mm_id) but attention stays a per-item loop
    // (each cache_id has its own contiguous K/V buffer — no way to fan
    // M items into a single attention dispatch without paged KV). These
    // 1-token-shaped scratches hold the per-item slice during the loop:
    // `copy_slice` extracts q/k/v from the batched buffers, qk_norm_rope
    // writes head-major into _single, kv_cache_append + flash_attention
    // run on it, then copy_slice writes back into attn_flat[i*q_dim].
    //
    // `None` until `enable_batched_decode_scratch` populates them. The
    // q-shaped buffers hold `q_dim` elements, the k/v-shaped ones `kv_dim`.
    // NOTE(review): an earlier version of this comment said the call comes
    // from `ensure_kv`; the `ensure_kv` in this file does not call it, so
    // the caller is presumably the batched-decode entry point — confirm.
    pub q_single: Option<B::Buffer>,
    pub k_single: Option<B::Buffer>,
    pub v_single: Option<B::Buffer>,
    pub q_head_major_single: Option<B::Buffer>,
    pub k_head_major_single: Option<B::Buffer>,
    pub v_head_major_single: Option<B::Buffer>,
    pub attn_head_major_single: Option<B::Buffer>,

    // ── Paged batched dispatch scratch ──────────────────────────────────
    //
    // Mirrors the same fields on `LlamaFamilyScratch`. `Some` only when
    // `FERRUM_METAL_PAGED_KV=1` and `enable_paged_batch` was called once
    // we know the pool dimensions. Sized for `FERRUM_PAGED_MAX_SEQS ×
    // q_dim` so the multi-seq decode path can fan in M items' Q into a
    // single batched buffer for one `paged_decode_attention(num_seqs=M)`
    // call instead of running M sequential m=1 attentions.
    /// `[max_seqs * q_dim]`.
    pub paged_batch_q: Option<B::Buffer>,
    /// `[max_seqs * q_dim]`.
    pub paged_batch_o: Option<B::Buffer>,
    /// u32 buffer, `[max_seqs * max_blocks_per_seq]`.
    pub paged_batch_block_tables: Option<B::Buffer>,
    /// u32 buffer, `[max_seqs]`.
    pub paged_batch_context_lens: Option<B::Buffer>,
    /// `max_blocks_per_seq` recorded by `enable_paged_batch`; `0` until
    /// that method has run.
    pub paged_max_blocks_per_seq: usize,

    /// Token capacity this scratch was allocated for (the `max_tokens`
    /// argument of `alloc`).
    pub max_tokens: usize,
}
236
237impl<B: Backend> Qwen3MoeScratch<B> {
238 fn alloc(cfg: &Qwen3MoeConfig, max_tokens: usize) -> Self {
239 let h = cfg.base.hidden_size;
240 let q_dim = cfg.base.num_heads * cfg.base.head_dim;
241 let kv_dim = cfg.base.num_kv_heads * cfg.base.head_dim;
242 let qkv_dim = q_dim + 2 * kv_dim;
243 let t = max_tokens;
244 let inter = cfg.expert_intermediate_size;
245 let n_exp = cfg.num_experts;
246 let vocab = cfg.base.vocab_size;
247 Self {
248 residual: Some(B::alloc(t * h)),
249 norm_out: B::alloc(t * h),
250 qkv_out: B::alloc(t * qkv_dim),
251 q_buf: B::alloc(t * q_dim),
252 k_buf: B::alloc(t * kv_dim),
253 v_buf: B::alloc(t * kv_dim),
254 q_head_major: B::alloc(cfg.base.num_heads * t * cfg.base.head_dim),
255 k_head_major: B::alloc(cfg.base.num_kv_heads * t * cfg.base.head_dim),
256 v_head_major: B::alloc(cfg.base.num_kv_heads * t * cfg.base.head_dim),
257 attn_head_major_out: B::alloc(cfg.base.num_heads * t * cfg.base.head_dim),
258 attn_flat: B::alloc(t * q_dim),
259 o_proj_out: B::alloc(t * h),
260 router_logits: B::alloc(t * n_exp),
261 gate_up_buf: B::alloc(2 * inter),
262 silu_buf: B::alloc(inter),
263 down_buf: B::alloc(h),
264 x_single: B::alloc(h),
265 acc_buf: B::alloc(h),
266 moe_out: B::alloc(t * h),
267 zero_hidden: B::from_slice(&vec![0.0f32; h]),
268 gate_out_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
269 up_out_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
270 silu_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
271 down_out_stacked: B::alloc(t * cfg.num_experts_per_tok * h),
272 ids_buf: B::from_slice_i32(&vec![0i32; cfg.num_experts_per_tok]),
273 weights_buf: B::from_slice(&vec![0.0f32; cfg.num_experts_per_tok]),
274 selected_ids_buf: B::from_slice_i32(&vec![0i32; t * cfg.num_experts_per_tok]),
275 // 3 u32s per indirect args buffer; allocated as 3 i32s so we
276 // can reuse `from_slice_i32`. The kernel writes them as
277 // `device uint *` and the bit pattern is consumed by
278 // `dispatch_thread_groups_indirect`.
279 gate_up_args_buf: B::from_slice_i32(&[0i32, 0, 0]),
280 down_args_buf: B::from_slice_i32(&[0i32, 0, 0]),
281 ids_2d: B::from_slice_i32(&vec![0i32; n_exp * t * cfg.num_experts_per_tok]),
282 tpe_buf: B::from_slice_i32(&vec![0i32; n_exp]),
283 weights_2d: B::from_slice(&vec![0.0f32; t * cfg.num_experts_per_tok]),
284 last_hidden: B::alloc(h),
285 last_normed: B::alloc(h),
286 logits: B::alloc(vocab),
287 batch_logits: B::alloc(t * vocab),
288 // Lazily-allocated; `enable_batched_decode_scratch` populates
289 // these the first time decode_batch is called with M > 1.
290 q_single: None,
291 k_single: None,
292 v_single: None,
293 q_head_major_single: None,
294 k_head_major_single: None,
295 v_head_major_single: None,
296 attn_head_major_single: None,
297 // Lazily-allocated; `enable_paged_batch` populates these when
298 // FERRUM_METAL_PAGED_KV=1 + we know the pool dimensions.
299 paged_batch_q: None,
300 paged_batch_o: None,
301 paged_batch_block_tables: None,
302 paged_batch_context_lens: None,
303 paged_max_blocks_per_seq: 0,
304 max_tokens: t,
305 }
306 }
307
308 /// Allocate scratch for paged batched dispatch. Mirrors
309 /// `LlamaFamilyScratch::enable_paged_batch`. Idempotent.
310 fn enable_paged_batch(
311 &mut self,
312 cfg: &Qwen3MoeConfig,
313 max_seqs: usize,
314 max_blocks_per_seq: usize,
315 ) {
316 if self.paged_batch_q.is_some() {
317 return;
318 }
319 let q_dim = cfg.base.num_heads * cfg.base.head_dim;
320 self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
321 self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
322 self.paged_batch_block_tables = Some(B::alloc_u32(max_seqs * max_blocks_per_seq));
323 self.paged_batch_context_lens = Some(B::alloc_u32(max_seqs));
324 self.paged_max_blocks_per_seq = max_blocks_per_seq;
325 }
326
327 /// Allocate the per-item single-token scratch buffers used by
328 /// `forward_layer_batched_decode`. Idempotent.
329 fn enable_batched_decode_scratch(&mut self, cfg: &Qwen3MoeConfig) {
330 if self.q_single.is_some() {
331 return;
332 }
333 let q_dim = cfg.base.num_heads * cfg.base.head_dim;
334 let kv_dim = cfg.base.num_kv_heads * cfg.base.head_dim;
335 self.q_single = Some(B::alloc(q_dim));
336 self.k_single = Some(B::alloc(kv_dim));
337 self.v_single = Some(B::alloc(kv_dim));
338 self.q_head_major_single = Some(B::alloc(q_dim));
339 self.k_head_major_single = Some(B::alloc(kv_dim));
340 self.v_head_major_single = Some(B::alloc(kv_dim));
341 self.attn_head_major_single = Some(B::alloc(q_dim));
342 }
343}
344
/// Qwen3-MoE decoder model.
///
/// Holds the same per-layer attention weights as [`LlamaFamilyModel`]
/// plus a [`Qwen3MoeLayerState`] per layer for the MoE FFN. Routing,
/// expert dispatch, and weighted combine all happen inside
/// [`moe_forward`]; this struct only owns the storage and orchestrates
/// the per-layer call sequence.
pub struct Qwen3MoeModel<B: Backend> {
    /// MoE model configuration; `cfg.base` carries the dense-model fields
    /// (hidden size, head counts, layer count, ...).
    pub cfg: Qwen3MoeConfig,
    /// Runtime view derived from `cfg.base` via `to_runtime()` in `new`.
    pub runtime_cfg: LlmRuntimeConfig,

    /// Token embedding table, loaded from `model.embed_tokens.weight`.
    pub embed: B::Buffer,
    /// Per-layer attention weights (re-uses dense `LlamaFamilyLayer`).
    /// The dense `gate_up_proj` / `down_proj` slots hold stub linears.
    pub attn_layers: Vec<LlamaFamilyLayer<B>>,
    /// Per-layer MoE state (router + expert stack).
    pub moe_layers: Vec<Qwen3MoeLayerState<B>>,
    /// Final norm weight, loaded from `model.norm.weight`.
    pub final_norm_w: B::Buffer,
    /// Output projection. Loaded from `lm_head` when the loader has
    /// `lm_head.weight`, otherwise from `model.embed_tokens` (tied
    /// embeddings).
    pub lm_head: Box<dyn ferrum_quantization::Linear<B>>,

    /// RoPE cos/sin tables built from `cfg.base`.
    pub rope: RopeCache<B>,
    /// Forward-pass scratch; starts at `max_tokens = 1` and is regrown by
    /// `ensure_scratch`.
    pub scratch: Qwen3MoeScratch<B>,

    /// Live KV caches keyed by cache id, one `KvCache` per layer.
    pub kv_caches: HashMap<String, Vec<KvCache<B>>>,
    /// Spare per-sequence cache sets; `ensure_kv` pops from here before
    /// allocating fresh caches.
    kv_free_pool: Vec<Vec<KvCache<B>>>,

    // ── Paged-KV multi-seq state ────────────────────────────────────────
    //
    // Mirrors `LlamaFamilyModel`. Only populated when
    // `FERRUM_METAL_PAGED_KV=1`. Kv_caches entries become metadata-only
    // views (block_table + context_lens) into the shared `paged_pools`.
    /// One `(K pool, V pool)` pair per layer, allocated lazily by the first
    /// paged `ensure_kv` call.
    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
    /// Allocator handing out physical block indices from the shared pools.
    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
}
378
379impl<B: Backend> Qwen3MoeModel<B> {
380 /// Build a Qwen3-MoE model from a generic `WeightLoader<B>` plus a
381 /// GGUF reader for the experts (which `WeightLoader` doesn't model
382 /// directly — its API is rank-2 only).
383 ///
384 /// `loader` provides: token embedding, attention projections, layer
385 /// norms, lm_head — all the rank-2 weights.
386 /// `gguf` provides: the rank-3 expert tensors, sliced per-expert
387 /// inside [`ExpertStack::load_from_gguf`].
388 pub fn new(
389 cfg: Qwen3MoeConfig,
390 loader: &dyn WeightLoader<B>,
391 gguf: &ferrum_quantization::gguf::GgufFile,
392 ) -> Result<Self> {
393 {
394 let mut ctx = B::new_context();
395 B::reset_graph(&mut ctx);
396 }
397 let rope = build_rope_cache::<B>(&cfg.base);
398 let scratch = Qwen3MoeScratch::alloc(&cfg, 1);
399
400 let embed = loader.load_tensor("model.embed_tokens.weight")?;
401
402 let mut attn_layers = Vec::with_capacity(cfg.base.num_layers);
403 let mut moe_layers = Vec::with_capacity(cfg.base.num_layers);
404 for li in 0..cfg.base.num_layers {
405 let prefix = format!("model.layers.{li}");
406 let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
407 let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
408 let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
409 let post_ln_w =
410 loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
411
412 // Dense gate_up_proj / down_proj are absent in MoE GGUFs —
413 // we synthesise stub Linears so the LlamaFamilyLayer struct
414 // type-checks. They're never invoked because forward_layer
415 // calls the MoE path. Cheap: tiny zero-sized DenseLinears.
416 let gate_up_proj: Box<dyn ferrum_quantization::Linear<B>> =
417 stub_linear::<B>(2 * cfg.expert_intermediate_size, cfg.base.hidden_size);
418 let down_proj: Box<dyn ferrum_quantization::Linear<B>> =
419 stub_linear::<B>(cfg.base.hidden_size, cfg.expert_intermediate_size);
420
421 let (q_norm_w, k_norm_w) = if cfg.base.has_qk_norm {
422 let q = loader
423 .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
424 .ok();
425 let k = loader
426 .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
427 .ok();
428 (q, k)
429 } else {
430 (None, None)
431 };
432
433 attn_layers.push(LlamaFamilyLayer {
434 input_ln_w,
435 qkv_proj,
436 q_norm_w,
437 k_norm_w,
438 o_proj,
439 post_ln_w,
440 gate_up_proj,
441 down_proj,
442 });
443
444 // Router lives at `model.layers.{li}.mlp.router.weight` in
445 // ferrum-name space (see ferrum_to_gguf mapping). It's a
446 // plain rank-2 linear so the standard loader path covers
447 // it without going through the MoE-specific GGUF helper.
448 let router = loader.load_linear(&format!("{prefix}.mlp.router"))?;
449 if router.in_features() != cfg.base.hidden_size {
450 return Err(FerrumError::model(format!(
451 "router layer {li}: in_features {} != hidden {}",
452 router.in_features(),
453 cfg.base.hidden_size
454 )));
455 }
456 if router.out_features() != cfg.num_experts {
457 return Err(FerrumError::model(format!(
458 "router layer {li}: out_features {} != num_experts {}",
459 router.out_features(),
460 cfg.num_experts
461 )));
462 }
463
464 let experts = ExpertStack::<B>::load_from_gguf(
465 gguf,
466 li,
467 cfg.num_experts,
468 cfg.base.hidden_size,
469 cfg.expert_intermediate_size,
470 )?;
471
472 moe_layers.push(Qwen3MoeLayerState { router, experts });
473 }
474
475 let final_norm_w = loader.load_tensor("model.norm.weight")?;
476 let lm_head = if loader.has_tensor("lm_head.weight") {
477 loader.load_linear("lm_head")?
478 } else {
479 // Tied embeddings — same as dense path.
480 tracing::info!(
481 "Qwen3MoeModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
482 );
483 loader.load_linear("model.embed_tokens")?
484 };
485
486 let runtime_cfg = cfg.base.to_runtime();
487 Ok(Self {
488 cfg,
489 runtime_cfg,
490 embed,
491 attn_layers,
492 moe_layers,
493 final_norm_w,
494 lm_head,
495 rope,
496 scratch,
497 kv_caches: HashMap::new(),
498 kv_free_pool: Vec::new(),
499 paged_pools: None,
500 paged_block_alloc: None,
501 })
502 }
503
504 pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
505 if self.scratch.max_tokens < tokens {
506 {
507 let mut ctx = B::new_context();
508 B::reset_graph(&mut ctx);
509 }
510 self.scratch = Qwen3MoeScratch::alloc(&self.cfg, tokens);
511 }
512 }
513
514 pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
515 if self.kv_caches.contains_key(cache_id) {
516 return;
517 }
518 let nkv = self.cfg.base.num_kv_heads;
519 let hd = self.cfg.base.head_dim;
520 // See `LlamaFamilyModel::ensure_kv` for the rationale on the 4096
521 // chat-friendly default and how `FERRUM_KV_CAPACITY` overrides it.
522 let model_max = self.cfg.base.max_seq_len;
523 const DEFAULT_KV_CAPACITY: usize = 4096;
524 let max = std::env::var("FERRUM_KV_CAPACITY")
525 .ok()
526 .and_then(|s| s.parse::<usize>().ok())
527 .map(|cap| cap.min(model_max))
528 .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY));
529
530 // Paged-KV mode: `FERRUM_METAL_PAGED_KV=1` switches caches into
531 // block-table-indirect layout. Mirrors LlamaFamilyModel's path so
532 // the existing `paged_decode_attention` Metal kernel can fire
533 // once at num_seqs=m for batched decode (replacing the per-item
534 // attention loop that currently dominates `attn_peritem` in the
535 // c=16 profile).
536 let paged = std::env::var("FERRUM_METAL_PAGED_KV")
537 .map(|v| v == "1")
538 .unwrap_or(false);
539 const PAGED_BLOCK_SIZE: usize = 16;
540
541 let max_seqs = std::env::var("FERRUM_PAGED_MAX_SEQS")
542 .ok()
543 .and_then(|s| s.parse::<usize>().ok())
544 .unwrap_or(16);
545 let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
546 let total_pool_blocks = max_seqs * max_blocks_per_seq;
547
548 // Lazy-allocate the shared paged pools on the first paged
549 // ensure_kv call.
550 if paged && self.paged_pools.is_none() {
551 let mut pools = Vec::with_capacity(self.cfg.base.num_layers);
552 for _ in 0..self.cfg.base.num_layers {
553 let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
554 pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
555 }
556 self.paged_pools = Some(pools);
557 self.paged_block_alloc = Some(std::sync::Mutex::new(
558 crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
559 ));
560 }
561 if paged {
562 self.scratch
563 .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
564 }
565
566 let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
567 (0..self.cfg.base.num_layers)
568 .map(|_| {
569 if paged {
570 // Paged mode: cache holds metadata only. K/V are
571 // 1-element placeholders. Real data lives in
572 // `self.paged_pools[li].{k,v}`.
573 let mut block_table = B::alloc_u32(max_blocks_per_seq);
574 let _ = &mut block_table; // suppress unused-mut on backends that no-op write_u32
575 let mut context_lens = B::alloc_u32(1);
576 let mut bt_ctx = B::new_context();
577 B::write_u32(&mut bt_ctx, &mut context_lens, &[0u32]);
578 B::sync(&mut bt_ctx);
579 KvCache {
580 k: B::alloc(1),
581 v: B::alloc(1),
582 len: 0,
583 capacity: max_blocks_per_seq * PAGED_BLOCK_SIZE,
584 num_kv_heads: nkv,
585 head_dim: hd,
586 block_size: PAGED_BLOCK_SIZE,
587 block_table: Some(block_table),
588 context_lens: Some(context_lens),
589 paged_block_indices: Vec::new(),
590 }
591 } else {
592 KvCache {
593 k: B::alloc(nkv * max * hd),
594 v: B::alloc(nkv * max * hd),
595 len: 0,
596 capacity: max,
597 num_kv_heads: nkv,
598 head_dim: hd,
599 block_size: 0,
600 block_table: None,
601 context_lens: None,
602 paged_block_indices: Vec::new(),
603 }
604 }
605 })
606 .collect()
607 });
608
609 // Allocate physical blocks for THIS cache_id from the shared pool.
610 if paged {
611 let alloc_arc = self
612 .paged_block_alloc
613 .as_ref()
614 .expect("paged_block_alloc must be initialised when paged=true");
615 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
616 let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
617 Ok(idx) => idx,
618 Err(e) => {
619 drop(alloc);
620 self.kv_free_pool.push(caches);
621 eprintln!(
622 "[ferrum] paged KV pool exhausted on ensure_kv for \
623 cache_id={cache_id:?}: {e}. Increase \
624 FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
625 throttle concurrent requests.",
626 );
627 return;
628 }
629 };
630 let mut padded = block_indices.clone();
631 padded.resize(max_blocks_per_seq, 0);
632 let mut ctx_tmp = B::new_context();
633 for c in caches.iter_mut() {
634 if let Some(bt) = c.block_table.as_mut() {
635 B::write_u32(&mut ctx_tmp, bt, &padded);
636 }
637 c.paged_block_indices = block_indices.clone();
638 }
639 B::sync(&mut ctx_tmp);
640 }
641
642 for c in caches.iter_mut() {
643 c.len = 0;
644 if let Some(cl) = c.context_lens.as_mut() {
645 let mut ctx_tmp = B::new_context();
646 B::write_u32(&mut ctx_tmp, cl, &[0u32]);
647 B::sync(&mut ctx_tmp);
648 }
649 }
650 self.kv_caches.insert(cache_id.to_string(), caches);
651 }
652
653 /// Run one full transformer layer (attention + MoE FFN).
654 pub(crate) fn forward_layer(
655 &mut self,
656 ctx: &mut B::Context,
657 li: usize,
658 cache_id: &str,
659 residual: &mut B::Buffer,
660 pos_offset: usize,
661 tokens: usize,
662 // If `Some(idx)` and we land on the decode fast path, fold the
663 // next layer's leading rms_norm into this layer's MoE tail
664 // (cross-layer norm fusion). The next layer's caller must pass
665 // `prev_did_norm_fusion = true` so it skips its own rms_norm.
666 next_layer_idx: Option<usize>,
667 // If `true`, skip step 1's input rms_norm — the previous
668 // layer's tail already populated `scratch.norm_out`.
669 prev_did_norm_fusion: bool,
670 ) -> Result<bool> {
671 let cfg_base = &self.cfg.base;
672 let h = cfg_base.hidden_size;
673 let nh = cfg_base.num_heads;
674 let nkv = cfg_base.num_kv_heads;
675 let hd = cfg_base.head_dim;
676 let eps = cfg_base.rms_norm_eps;
677 let q_dim = nh * hd;
678 let kv_dim = nkv * hd;
679 let attn_layer = &self.attn_layers[li];
680 let moe_layer = &self.moe_layers[li];
681
682 let attn_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
683 B::sync(ctx);
684 Some(std::time::Instant::now())
685 } else {
686 None
687 };
688
689 // 1. Input RMSNorm — skipped when the previous layer's MoE tail
690 // fused this norm via `weighted_sum_residual_norm_stacked`.
691 if !prev_did_norm_fusion {
692 B::rms_norm(
693 ctx,
694 residual,
695 &attn_layer.input_ln_w,
696 eps,
697 &mut self.scratch.norm_out,
698 tokens,
699 h,
700 );
701 }
702
703 // 2. Fused QKV
704 attn_layer.qkv_proj.forward(
705 ctx,
706 &self.scratch.norm_out,
707 &mut self.scratch.qkv_out,
708 tokens,
709 );
710
711 // 3-4. Fused split-QKV + QK-norm + RoPE + head-major transpose.
712 //
713 // One Metal dispatch replaces (split_qkv → 3× qk_norm_rope), the
714 // four-launch chain that used to dominate the attention prelude.
715 // Reads qkv_out once, writes head-major Q/K (norm+RoPE) and V
716 // (transpose only) directly into attention scratch. Saves 3
717 // dispatches per layer (×48 = 144 dispatches per decode token).
718 //
719 // CPU and other backends without the fused kernel return
720 // Unsupported and we fall through to the original four-launch
721 // path. q_buf / k_buf / v_buf stay in scratch because that path
722 // and the per-expert MoE fallback still want them.
723 let qk_mode: i32 = if cfg_base.has_qk_norm { 1 } else { 2 };
724 let dummy = &attn_layer.input_ln_w;
725 let q_norm_w = attn_layer.q_norm_w.as_ref().unwrap_or(dummy);
726 let k_norm_w = attn_layer.k_norm_w.as_ref().unwrap_or(dummy);
727
728 // 5. Grab the per-layer KV cache up front — the deepest fused
729 // variant writes K/V straight into it, avoiding a trailing
730 // `kv_cache_append_head_major` dispatch.
731 //
732 // Paged mode: extract a raw pointer to the layer's pool buffers
733 // BEFORE the &mut cache borrow, so we can pass &mut to the
734 // paged kernel below without holding two simultaneous mutable
735 // borrows on `self`. Safety: `paged_pools` is allocated once at
736 // first ensure_kv call and never resized; the only concurrent
737 // mutation is the pool's own kernel writes (sequenced via
738 // command buffers), so the raw pointer remains valid for the
739 // duration of this layer call.
740 let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
741 if let Some(pools) = self.paged_pools.as_mut() {
742 let pool = &mut pools[li];
743 Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
744 } else {
745 None
746 };
747 let caches = self
748 .kv_caches
749 .get_mut(cache_id)
750 .expect("ensure_kv must be called before forward_layer");
751 let cache = &mut caches[li];
752 let cache_len_before = cache.len;
753 let cache_capacity = cache.capacity;
754
755 // Defense in depth: refuse to write past the KV buffer. Silent
756 // overflow has visible failure modes (garbage output, stale token
757 // attention, slowdowns from reading uninitialised memory). The
758 // graceful path is the caller pre-checking via `kv_capacity()` and
759 // either compacting or refusing the request; this panic only
760 // fires when that contract is broken.
761 if cache_len_before + tokens > cache_capacity {
762 panic!(
763 "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
764 cache_len_before + tokens
765 );
766 }
767
768 // Try the deepest fusion: fused split-QKV-norm-rope that writes
769 // K/V directly into the cache slot. Paged mode writes into the
770 // shared pool via block_table indirection; contiguous mode
771 // writes into the per-cache_id k/v buffers directly.
772 let used_qkv_into_cache = if cache.block_size > 0 {
773 // Paged path.
774 let bt = cache
775 .block_table
776 .as_ref()
777 .expect("paged cache missing block_table");
778 let num_blocks_per_seq = cache.capacity / cache.block_size;
779 let (pool_k_ptr, pool_v_ptr) =
780 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
781 // SAFETY: pools allocated-once, see paged_pool_ptr setup above.
782 let pool_k = unsafe { &mut *pool_k_ptr };
783 let pool_v = unsafe { &mut *pool_v_ptr };
784 B::split_qkv_norm_rope_into_paged_cache(
785 ctx,
786 &self.scratch.qkv_out,
787 0,
788 q_norm_w,
789 k_norm_w,
790 &self.rope.cos,
791 &self.rope.sin,
792 &mut self.scratch.q_head_major,
793 0,
794 pool_k,
795 pool_v,
796 bt,
797 tokens,
798 nh,
799 nkv,
800 hd,
801 pos_offset,
802 eps,
803 qk_mode,
804 cache_len_before,
805 cache.block_size,
806 num_blocks_per_seq,
807 )
808 .is_ok()
809 } else {
810 B::split_qkv_norm_rope_into_cache(
811 ctx,
812 &self.scratch.qkv_out,
813 q_norm_w,
814 k_norm_w,
815 &self.rope.cos,
816 &self.rope.sin,
817 &mut self.scratch.q_head_major,
818 &mut cache.k,
819 &mut cache.v,
820 tokens,
821 nh,
822 nkv,
823 hd,
824 pos_offset,
825 eps,
826 qk_mode,
827 cache_len_before,
828 cache_capacity,
829 )
830 .is_ok()
831 };
832 if !used_qkv_into_cache {
833 // Fallback 1: fused split-QKV-norm-rope to head-major scratch
834 // (Metal pre-decode-fusion path), then explicit cache append.
835 let used_fused_qkv = B::split_qkv_norm_rope(
836 ctx,
837 &self.scratch.qkv_out,
838 q_norm_w,
839 k_norm_w,
840 &self.rope.cos,
841 &self.rope.sin,
842 &mut self.scratch.q_head_major,
843 &mut self.scratch.k_head_major,
844 &mut self.scratch.v_head_major,
845 tokens,
846 nh,
847 nkv,
848 hd,
849 pos_offset,
850 eps,
851 qk_mode,
852 )
853 .is_ok();
854 if !used_fused_qkv {
855 // Fallback 2: original four-launch chain.
856 B::split_qkv(
857 ctx,
858 &self.scratch.qkv_out,
859 &mut self.scratch.q_buf,
860 &mut self.scratch.k_buf,
861 &mut self.scratch.v_buf,
862 tokens,
863 q_dim,
864 kv_dim,
865 );
866 B::qk_norm_rope(
867 ctx,
868 &self.scratch.q_buf,
869 q_norm_w,
870 &self.rope.cos,
871 &self.rope.sin,
872 &mut self.scratch.q_head_major,
873 tokens,
874 nh,
875 hd,
876 pos_offset,
877 eps,
878 qk_mode,
879 );
880 B::qk_norm_rope(
881 ctx,
882 &self.scratch.k_buf,
883 k_norm_w,
884 &self.rope.cos,
885 &self.rope.sin,
886 &mut self.scratch.k_head_major,
887 tokens,
888 nkv,
889 hd,
890 pos_offset,
891 eps,
892 qk_mode,
893 );
894 B::qk_norm_rope(
895 ctx,
896 &self.scratch.v_buf,
897 dummy,
898 &self.rope.cos,
899 &self.rope.sin,
900 &mut self.scratch.v_head_major,
901 tokens,
902 nkv,
903 hd,
904 pos_offset,
905 eps,
906 0,
907 );
908 }
909 B::kv_cache_append_head_major(
910 ctx,
911 &mut cache.k,
912 &mut cache.v,
913 cache.len,
914 cache.capacity,
915 &self.scratch.k_head_major,
916 &self.scratch.v_head_major,
917 tokens,
918 nkv,
919 hd,
920 );
921 }
922 cache.len += tokens;
923 let kv_len = cache.len;
924 let kv_stride = cache.capacity;
925
926 if cache.block_size > 0 {
927 // Paged decode: read from the shared pool via block_table.
928 let bt = cache
929 .block_table
930 .as_ref()
931 .expect("paged cache missing block_table");
932 let cl_buf = cache
933 .context_lens
934 .as_mut()
935 .expect("paged cache missing context_lens");
936 let num_blocks_per_seq = cache.capacity / cache.block_size;
937 let (pool_k_ptr, pool_v_ptr) =
938 paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
939 // SAFETY: see paged_pool_ptr setup above.
940 let pool_k = unsafe { &*pool_k_ptr };
941 let pool_v = unsafe { &*pool_v_ptr };
942 let final_kv_len = cache.len as u32;
943 B::write_u32(ctx, cl_buf, &[final_kv_len]);
944 B::paged_decode_attention(
945 ctx,
946 &self.scratch.q_head_major,
947 pool_k,
948 pool_v,
949 &mut self.scratch.attn_head_major_out,
950 bt,
951 cl_buf,
952 1, // num_seqs (single-seq m=1 path)
953 nh,
954 nkv,
955 hd,
956 cache.block_size,
957 num_blocks_per_seq,
958 tokens,
959 )
960 .expect("paged_decode_attention");
961 let _ = kv_stride; // consumed by contig path only
962 } else {
963 let attn_cfg = ferrum_kernels::backend::AttnConfig {
964 num_heads: nh,
965 num_kv_heads: nkv,
966 head_dim: hd,
967 causal: true,
968 scale: 1.0 / (hd as f32).sqrt(),
969 kv_seq_stride: kv_stride,
970 sliding_window: cfg_base.sliding_window,
971 };
972 B::flash_attention(
973 ctx,
974 &self.scratch.q_head_major,
975 &cache.k,
976 &cache.v,
977 &mut self.scratch.attn_head_major_out,
978 1,
979 tokens,
980 kv_len,
981 pos_offset,
982 &attn_cfg,
983 );
984 }
985
986 if let Some(t0) = attn_t0 {
987 B::sync(ctx);
988 ATTN_TIME_US.fetch_add(
989 t0.elapsed().as_micros() as u64,
990 std::sync::atomic::Ordering::Relaxed,
991 );
992 ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
993 }
994
995 // 7. transpose head-major → token-major.
996 //
997 // For tokens=1 the two layouts are byte-identical: both
998 // collapse to the flat [heads * head_dim] vector at offset
999 // `head*hd + d`. Skip the dispatch and point o_proj at
1000 // attn_head_major_out directly. Saves 1 dispatch per layer
1001 // (×48 = 48 dispatches per decode token) on Qwen3-30B-A3B.
1002 let attn_token_major = if tokens == 1 {
1003 &self.scratch.attn_head_major_out
1004 } else {
1005 B::transpose_head_to_token(
1006 ctx,
1007 &self.scratch.attn_head_major_out,
1008 &mut self.scratch.attn_flat,
1009 tokens,
1010 nh,
1011 hd,
1012 );
1013 &self.scratch.attn_flat
1014 };
1015
1016 // 8. O-proj.
1017 attn_layer
1018 .o_proj
1019 .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1020
1021 // 9. fused residual-add + post-attention RMSNorm.
1022 B::fused_add_rms_norm(
1023 ctx,
1024 residual,
1025 &self.scratch.o_proj_out,
1026 &attn_layer.post_ln_w,
1027 eps,
1028 &mut self.scratch.norm_out,
1029 tokens,
1030 h,
1031 );
1032
1033 // ── MoE FFN block ────────────────────────────────────────────
1034 let moe_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1035 B::sync(ctx);
1036 Some(std::time::Instant::now())
1037 } else {
1038 None
1039 };
1040
1041 // 10. Router gemv: norm_out [tokens, hidden] → router_logits [tokens, num_experts]
1042 moe_layer.router.forward(
1043 ctx,
1044 &self.scratch.norm_out,
1045 &mut self.scratch.router_logits,
1046 tokens,
1047 );
1048
1049 // 11. Per-(token, expert) MLP dispatch + weighted combine.
1050 //
1051 // Two paths:
1052 // - **Batched fast path** (decode m=1, all stacked variants
1053 // present): single `gemv_quant_moe_id` dispatch covers all
1054 // 8 selected expert × 1 token gate gemvs in parallel; same
1055 // for up and down. Cuts per-layer expert dispatches from
1056 // ~32 (8 × 4 ops/pair) to 4 (gate + up + silu + down + 1 acc).
1057 // Routes Qwen3-30B-A3B decode close to llama.cpp's
1058 // `kernel_mul_mm_id`.
1059 // - **Per-(token, expert) fallback** via `moe_forward` —
1060 // used for prefill (m > 1), or when the backend doesn't
1061 // populate stacked variants (CPU, synthetic-MoE tests).
1062 let stacked_path_available = moe_layer.experts.gate_stacked.is_some()
1063 && moe_layer.experts.up_stacked.is_some()
1064 && moe_layer.experts.down_stacked.is_some();
1065
1066 // Fast path for decode (tokens=1): the stacked decode impl
1067 // writes the weighted-sum result *directly* into `residual` via
1068 // `weighted_sum_residual_stacked`, skipping the moe_out scratch
1069 // and the trailing `add_inplace`. Saves 1 dispatch per layer.
1070 // Prefill (m>1) and the per-expert fallback still go through
1071 // moe_out + add_inplace.
1072 let decode_fast_path = stacked_path_available && tokens == 1;
1073 // Cross-layer fusion: when on the decode fast path AND there is
1074 // a next layer, fold its leading rms_norm into this layer's
1075 // tail (`weighted_sum_residual_norm_stacked`). Returns whether
1076 // the fusion ran so the caller can signal the next layer to
1077 // skip its standalone rms_norm.
1078 let did_norm_fusion = decode_fast_path && next_layer_idx.is_some();
1079
1080 if stacked_path_available {
1081 if tokens > 1 {
1082 // Prefill: one batched 2-D mul_mm_id covers all
1083 // (token, expert) pairs in parallel.
1084 self.moe_forward_batched_prefill(ctx, li, tokens)?;
1085 } else {
1086 // Decode m=1: dedicated per-token path that fuses
1087 // residual-add into the final weighted-sum, and
1088 // optionally folds the next layer's rms_norm in too.
1089 self.moe_forward_stacked(ctx, li, tokens, residual, next_layer_idx)?;
1090 }
1091 } else {
1092 moe_forward::<B>(
1093 ctx,
1094 &self.scratch.norm_out,
1095 &self.scratch.router_logits,
1096 &mut self.scratch.moe_out,
1097 tokens,
1098 h,
1099 self.cfg.expert_intermediate_size,
1100 self.cfg.num_experts,
1101 self.cfg.num_experts_per_tok,
1102 self.cfg.norm_topk_prob,
1103 &moe_layer.experts,
1104 &mut self.scratch.x_single,
1105 &mut self.scratch.acc_buf,
1106 &mut self.scratch.gate_up_buf,
1107 &mut self.scratch.silu_buf,
1108 &mut self.scratch.down_buf,
1109 &self.scratch.zero_hidden,
1110 )?;
1111 }
1112
1113 // 12. residual += moe_out (skipped on decode fast path — already
1114 // accumulated by `weighted_sum_residual_stacked`).
1115 if !decode_fast_path {
1116 B::add_inplace(ctx, residual, &self.scratch.moe_out, tokens * h);
1117 }
1118
1119 if let Some(t0) = moe_t0 {
1120 B::sync(ctx);
1121 MOE_TIME_US.fetch_add(
1122 t0.elapsed().as_micros() as u64,
1123 std::sync::atomic::Ordering::Relaxed,
1124 );
1125 MOE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1126 }
1127
1128 Ok(did_norm_fusion)
1129 }
1130
1131 fn moe_forward_stacked(
1132 &mut self,
1133 ctx: &mut B::Context,
1134 li: usize,
1135 tokens: usize,
1136 residual: &mut B::Buffer,
1137 next_layer_idx: Option<usize>,
1138 ) -> Result<()> {
1139 let cfg = &self.cfg;
1140 // `next_norm_w` is the next layer's `attn_layer.input_ln_w`.
1141 // We can't borrow `self.attn_layers[idx]` and pass &mut
1142 // self.scratch to the impl simultaneously, so collect the raw
1143 // pointer here. Safety: forward_layer holds &mut self for the
1144 // call; the borrow scopes are fully sequential.
1145 let next_norm_w_ptr: Option<*const B::Buffer> =
1146 next_layer_idx.map(|idx| &self.attn_layers[idx].input_ln_w as *const _);
1147 // SAFETY: pointer dereference is valid because:
1148 // * The buffer lives in `self.attn_layers[idx]` which we
1149 // borrowed immutably to take the pointer. We do not mutate
1150 // `self.attn_layers` while `next_norm_w_ptr` is in use.
1151 // * `&mut self.scratch` and `&self.moe_layers[li]` are disjoint
1152 // fields from `self.attn_layers` so this is safe.
1153 let next_norm_w: Option<&B::Buffer> = next_norm_w_ptr.map(|p| unsafe { &*p });
1154 moe_forward_stacked_decode_impl::<B>(
1155 ctx,
1156 &self.moe_layers[li],
1157 &mut self.scratch,
1158 cfg.base.hidden_size,
1159 cfg.expert_intermediate_size,
1160 cfg.num_experts_per_tok,
1161 cfg.num_experts,
1162 cfg.norm_topk_prob,
1163 tokens,
1164 residual,
1165 next_norm_w,
1166 cfg.base.rms_norm_eps,
1167 )
1168 }
1169
1170 fn moe_forward_batched_prefill(
1171 &mut self,
1172 ctx: &mut B::Context,
1173 li: usize,
1174 tokens: usize,
1175 ) -> Result<()> {
1176 let cfg = &self.cfg;
1177 moe_forward_batched_prefill_impl::<B>(
1178 ctx,
1179 &self.moe_layers[li],
1180 &mut self.scratch,
1181 cfg.base.hidden_size,
1182 cfg.expert_intermediate_size,
1183 cfg.num_experts_per_tok,
1184 cfg.num_experts,
1185 cfg.norm_topk_prob,
1186 tokens,
1187 )
1188 }
1189
1190 /// Prefill: process `tokens` prompt tokens, return last-token logits.
1191 pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1192 let seq_len = tokens.len();
1193 assert!(seq_len > 0);
1194 self.ensure_scratch(seq_len);
1195 self.ensure_kv(cache_id);
1196
1197 let pos_offset = self
1198 .kv_caches
1199 .get(cache_id)
1200 .and_then(|layers| layers.first())
1201 .map(|c| c.len)
1202 .unwrap_or(0);
1203
1204 let h = self.cfg.base.hidden_size;
1205 let vocab = self.cfg.base.vocab_size;
1206 let mut ctx = B::new_context();
1207
1208 // FERRUM_DECODE_OP_PROFILE doubles as the prefill-profile gate
1209 // for Qwen3-MoE: when set, dump (attn-us, moe-us, total-us) at
1210 // the end of prefill so we can attribute the prefill bottleneck
1211 // between attention and MoE.
1212 let prefill_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1213 B::sync(&mut ctx);
1214 for c in [
1215 &ATTN_TIME_US,
1216 &ATTN_CALLS,
1217 &MOE_TIME_US,
1218 &MOE_CALLS,
1219 &MOE_PREFILL_HOST_TOPK_US,
1220 &MOE_PREFILL_HOST_TOPK_CALLS,
1221 &MOE_PREFILL_GATE_US,
1222 &MOE_PREFILL_GATE_CALLS,
1223 &MOE_PREFILL_UP_US,
1224 &MOE_PREFILL_UP_CALLS,
1225 &MOE_PREFILL_SILU_US,
1226 &MOE_PREFILL_SILU_CALLS,
1227 &MOE_PREFILL_DOWN_US,
1228 &MOE_PREFILL_DOWN_CALLS,
1229 &MOE_PREFILL_WSUM_US,
1230 &MOE_PREFILL_WSUM_CALLS,
1231 ] {
1232 c.store(0, std::sync::atomic::Ordering::Relaxed);
1233 }
1234 Some(std::time::Instant::now())
1235 } else {
1236 None
1237 };
1238
1239 let mut residual = self
1240 .scratch
1241 .residual
1242 .take()
1243 .expect("scratch residual missing (previous call didn't restore)");
1244 B::embedding_lookup(&mut ctx, &self.embed, tokens, &mut residual, h);
1245
1246 // For prefill (seq_len > 1) the cross-layer norm fusion does
1247 // not apply (it lives on the decode fast path). We still pass
1248 // `next_layer_idx = None` so forward_layer emits the regular
1249 // tail.
1250 let mut prev_did_norm_fusion = false;
1251 let num_layers = self.cfg.base.num_layers;
1252 for li in 0..num_layers {
1253 let next_layer_idx = if li + 1 < num_layers {
1254 Some(li + 1)
1255 } else {
1256 None
1257 };
1258 prev_did_norm_fusion = self
1259 .forward_layer(
1260 &mut ctx,
1261 li,
1262 cache_id,
1263 &mut residual,
1264 pos_offset,
1265 seq_len,
1266 next_layer_idx,
1267 prev_did_norm_fusion,
1268 )
1269 .expect("forward_layer");
1270 }
1271
1272 // Last-token slice → final RMSNorm → lm_head.
1273 B::copy_slice(
1274 &mut ctx,
1275 &residual,
1276 (seq_len - 1) * h,
1277 &mut self.scratch.last_hidden,
1278 0,
1279 h,
1280 );
1281 B::rms_norm(
1282 &mut ctx,
1283 &self.scratch.last_hidden,
1284 &self.final_norm_w,
1285 self.cfg.base.rms_norm_eps,
1286 &mut self.scratch.last_normed,
1287 1,
1288 h,
1289 );
1290 self.lm_head.forward(
1291 &mut ctx,
1292 &self.scratch.last_normed,
1293 &mut self.scratch.logits,
1294 1,
1295 );
1296
1297 B::sync(&mut ctx);
1298 if let Some(t0) = prefill_t0 {
1299 let total_us = t0.elapsed().as_micros() as u64;
1300 let attn_us = ATTN_TIME_US.load(std::sync::atomic::Ordering::Relaxed);
1301 let attn_n = ATTN_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1302 let moe_us = MOE_TIME_US.load(std::sync::atomic::Ordering::Relaxed);
1303 let moe_n = MOE_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1304 let other_us = total_us.saturating_sub(attn_us).saturating_sub(moe_us);
1305 eprintln!(
1306 "[prefill-profile] tokens={seq_len} total={} ms ({:.0} t/s)",
1307 total_us / 1000,
1308 seq_len as f64 * 1e6 / total_us as f64
1309 );
1310 let bucket = |label: &str, n: u64, us: u64| {
1311 if n > 0 {
1312 eprintln!(
1313 " {label:>6}: {:7} ms ({:5.1}%) over {n:4} calls",
1314 us / 1000,
1315 us as f64 * 100.0 / total_us as f64
1316 );
1317 }
1318 };
1319 bucket("attn", attn_n, attn_us);
1320 bucket("moe", moe_n, moe_us);
1321 bucket("other", 1, other_us);
1322 // MoE sub-stages — show as % of total prefill time so they
1323 // reconcile against the `moe` bucket above.
1324 let host_us = MOE_PREFILL_HOST_TOPK_US.load(std::sync::atomic::Ordering::Relaxed);
1325 let gate_us = MOE_PREFILL_GATE_US.load(std::sync::atomic::Ordering::Relaxed);
1326 let up_us = MOE_PREFILL_UP_US.load(std::sync::atomic::Ordering::Relaxed);
1327 let silu_us = MOE_PREFILL_SILU_US.load(std::sync::atomic::Ordering::Relaxed);
1328 let down_us = MOE_PREFILL_DOWN_US.load(std::sync::atomic::Ordering::Relaxed);
1329 let wsum_us = MOE_PREFILL_WSUM_US.load(std::sync::atomic::Ordering::Relaxed);
1330 let host_n = MOE_PREFILL_HOST_TOPK_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1331 let gate_n = MOE_PREFILL_GATE_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1332 let up_n = MOE_PREFILL_UP_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1333 let silu_n = MOE_PREFILL_SILU_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1334 let down_n = MOE_PREFILL_DOWN_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1335 let wsum_n = MOE_PREFILL_WSUM_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1336 bucket(" host", host_n, host_us);
1337 bucket(" gate", gate_n, gate_us);
1338 bucket(" up", up_n, up_us);
1339 bucket(" silu", silu_n, silu_us);
1340 bucket(" down", down_n, down_us);
1341 bucket(" wsum", wsum_n, wsum_us);
1342 }
1343 self.scratch.residual = Some(residual);
1344 B::to_vec(&self.scratch.logits, vocab)
1345 }
1346
1347 /// Decode: 1 token at position `pos`, return next-step logits.
1348 pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1349 self.ensure_scratch(1);
1350 self.ensure_kv(cache_id);
1351
1352 let h = self.cfg.base.hidden_size;
1353 let vocab = self.cfg.base.vocab_size;
1354 let mut ctx = B::new_context();
1355
1356 let decode_t0 = if std::env::var("FERRUM_MOE_PROFILE").is_ok() {
1357 Some(std::time::Instant::now())
1358 } else {
1359 None
1360 };
1361
1362 // FERRUM_DECODE_OP_PROFILE gates the per-stage breakdown emitted
1363 // at the bottom of every decode token. Reuses the same atomic
1364 // counters that `forward_layer` already populates (ATTN_TIME_US,
1365 // MOE_TIME_US — drained here per-token instead of per-prefill).
1366 let stage_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1367 B::sync(&mut ctx);
1368 for c in [
1369 &ATTN_TIME_US,
1370 &ATTN_CALLS,
1371 &MOE_TIME_US,
1372 &MOE_CALLS,
1373 &DEC_ROUTE_US,
1374 &DEC_GATE_US,
1375 &DEC_UP_US,
1376 &DEC_SILU_US,
1377 &DEC_DOWN_US,
1378 &DEC_WSUM_US,
1379 &DEC_EMBED_US,
1380 &DEC_FINAL_NORM_US,
1381 &DEC_LM_HEAD_US,
1382 ] {
1383 c.store(0, std::sync::atomic::Ordering::Relaxed);
1384 }
1385 Some(std::time::Instant::now())
1386 } else {
1387 None
1388 };
1389 let prof = stage_t0.is_some();
1390 let mark = |ctx: &mut B::Context, c: &AtomicU64, t0: std::time::Instant| {
1391 if prof {
1392 B::sync(ctx);
1393 c.fetch_add(
1394 t0.elapsed().as_micros() as u64,
1395 std::sync::atomic::Ordering::Relaxed,
1396 );
1397 }
1398 };
1399 let mt0 = std::time::Instant::now();
1400
1401 let mut residual = self
1402 .scratch
1403 .residual
1404 .take()
1405 .expect("scratch residual missing (previous call didn't restore)");
1406 let t0 = std::time::Instant::now();
1407 B::embedding_lookup(&mut ctx, &self.embed, &[token], &mut residual, h);
1408 mark(&mut ctx, &DEC_EMBED_US, t0);
1409 let _ = mt0; // silence if unused on non-profile builds
1410
1411 // Cross-layer rms_norm fusion: layer L's MoE tail folds the
1412 // next layer's leading rms_norm into its weighted-sum-residual
1413 // when the decode fast path applies. The flag carries forward.
1414 let mut prev_did_norm_fusion = false;
1415 let num_layers = self.cfg.base.num_layers;
1416 for li in 0..num_layers {
1417 let next_layer_idx = if li + 1 < num_layers {
1418 Some(li + 1)
1419 } else {
1420 None
1421 };
1422 prev_did_norm_fusion = self
1423 .forward_layer(
1424 &mut ctx,
1425 li,
1426 cache_id,
1427 &mut residual,
1428 pos as usize,
1429 1,
1430 next_layer_idx,
1431 prev_did_norm_fusion,
1432 )
1433 .expect("forward_layer");
1434 }
1435
1436 let t0 = std::time::Instant::now();
1437 B::rms_norm(
1438 &mut ctx,
1439 &residual,
1440 &self.final_norm_w,
1441 self.cfg.base.rms_norm_eps,
1442 &mut self.scratch.last_normed,
1443 1,
1444 h,
1445 );
1446 mark(&mut ctx, &DEC_FINAL_NORM_US, t0);
1447
1448 let t0 = std::time::Instant::now();
1449 self.lm_head.forward(
1450 &mut ctx,
1451 &self.scratch.last_normed,
1452 &mut self.scratch.logits,
1453 1,
1454 );
1455 mark(&mut ctx, &DEC_LM_HEAD_US, t0);
1456
1457 B::sync(&mut ctx);
1458 self.scratch.residual = Some(residual);
1459
1460 // FERRUM_DECODE_OP_PROFILE: per-token decode breakdown.
1461 if let Some(t0) = stage_t0 {
1462 use std::sync::atomic::Ordering;
1463 let total_us = t0.elapsed().as_micros() as u64;
1464 let attn_us = ATTN_TIME_US.swap(0, Ordering::Relaxed);
1465 let moe_us = MOE_TIME_US.swap(0, Ordering::Relaxed);
1466 let route = DEC_ROUTE_US.swap(0, Ordering::Relaxed);
1467 let gate = DEC_GATE_US.swap(0, Ordering::Relaxed);
1468 let up = DEC_UP_US.swap(0, Ordering::Relaxed);
1469 let silu = DEC_SILU_US.swap(0, Ordering::Relaxed);
1470 let down = DEC_DOWN_US.swap(0, Ordering::Relaxed);
1471 let wsum = DEC_WSUM_US.swap(0, Ordering::Relaxed);
1472 let embed = DEC_EMBED_US.swap(0, Ordering::Relaxed);
1473 let fnorm = DEC_FINAL_NORM_US.swap(0, Ordering::Relaxed);
1474 let lmhead = DEC_LM_HEAD_US.swap(0, Ordering::Relaxed);
1475 let other = total_us.saturating_sub(attn_us + moe_us + embed + fnorm + lmhead);
1476 let pct = |us: u64| -> f64 {
1477 if total_us == 0 {
1478 0.0
1479 } else {
1480 100.0 * us as f64 / total_us as f64
1481 }
1482 };
1483 eprintln!(
1484 "[decode-prof] total={} ms | attn={} ({:.1}%) | moe={} ({:.1}%) [route={} gate={} up={} silu={} down={} wsum={}] | embed={} fnorm={} lmhead={} other={} ({:.1}%)",
1485 total_us / 1000,
1486 attn_us / 1000, pct(attn_us),
1487 moe_us / 1000, pct(moe_us),
1488 route / 1000, gate / 1000, up / 1000, silu / 1000, down / 1000, wsum / 1000,
1489 embed / 1000, fnorm / 1000, lmhead / 1000,
1490 other / 1000, pct(other),
1491 );
1492 }
1493
1494 // Drain MoE per-op counters every decode step. The counters
1495 // accumulate across all 48 layers; printing per-step gives a
1496 // per-token breakdown.
1497 if let Some(t0) = decode_t0 {
1498 use crate::moe::dispatch::*;
1499 use std::sync::atomic::Ordering;
1500 let total_us = t0.elapsed().as_micros() as u64;
1501 let sync_us = MOE_SYNC_US.swap(0, Ordering::Relaxed);
1502 let sync_n = MOE_SYNC_CALLS.swap(0, Ordering::Relaxed);
1503 let topk_us = MOE_HOST_TOPK_US.swap(0, Ordering::Relaxed);
1504 let topk_n = MOE_HOST_TOPK_CALLS.swap(0, Ordering::Relaxed);
1505 let gu_us = MOE_GEMV_GATE_UP_US.swap(0, Ordering::Relaxed);
1506 let gu_n = MOE_GEMV_GATE_UP_CALLS.swap(0, Ordering::Relaxed);
1507 let silu_us = MOE_SILU_US.swap(0, Ordering::Relaxed);
1508 let silu_n = MOE_SILU_CALLS.swap(0, Ordering::Relaxed);
1509 let dn_us = MOE_GEMV_DOWN_US.swap(0, Ordering::Relaxed);
1510 let dn_n = MOE_GEMV_DOWN_CALLS.swap(0, Ordering::Relaxed);
1511 let sa_us = MOE_SCALED_ADD_US.swap(0, Ordering::Relaxed);
1512 let sa_n = MOE_SCALED_ADD_CALLS.swap(0, Ordering::Relaxed);
1513 let cp_us = MOE_COPY_US.swap(0, Ordering::Relaxed);
1514 let cp_n = MOE_COPY_CALLS.swap(0, Ordering::Relaxed);
1515 eprintln!(
1516 "[moe-prof] decode total={} ms | sync={} ms ({}x) | host_topk={} ms ({}x) | gate_up={} ms ({}x) | silu={} ms ({}x) | down={} ms ({}x) | scaled_add={} ms ({}x) | copy={} ms ({}x)",
1517 total_us / 1000,
1518 sync_us / 1000, sync_n,
1519 topk_us / 1000, topk_n,
1520 gu_us / 1000, gu_n,
1521 silu_us / 1000, silu_n,
1522 dn_us / 1000, dn_n,
1523 sa_us / 1000, sa_n,
1524 cp_us / 1000, cp_n,
1525 );
1526 }
1527
1528 B::to_vec(&self.scratch.logits, vocab)
1529 }
1530
1531 /// Multi-sequence batched decode (Phase 4b for MoE).
1532 ///
1533 /// Mirrors `LlamaFamilyModel::decode_batch_internal` but adapted to
1534 /// the MoE forward. The wins come from running the GEMM-heavy ops
1535 /// (qkv_proj, o_proj, router, MoE expert mul_mm_id, lm_head) at
1536 /// m=M, even though attention stays a per-item loop because
1537 /// Qwen3-MoE uses contiguous KV — no paged path here.
1538 ///
1539 /// Cross-layer rms_norm fusion (the `weighted_sum_residual_norm_stacked`
1540 /// fast path) is disabled in batched mode: the prefill MoE path
1541 /// (`moe_forward_batched_prefill_impl`) writes to `moe_out` and we
1542 /// add to residual explicitly. Each layer's leading rms_norm runs
1543 /// at m=M, which is one fused dispatch on M rows — cheap.
1544 pub fn decode_batch_internal(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1545 let m = batch.len();
1546 if m == 0 {
1547 return Vec::new();
1548 }
1549 if m == 1 {
1550 let (cid, tok, pos) = &batch[0];
1551 return vec![self.decode_internal(cid, *tok, *pos)];
1552 }
1553
1554 let prof_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1555 Some(std::time::Instant::now())
1556 } else {
1557 None
1558 };
1559
1560 for (cid, _, _) in batch {
1561 self.ensure_kv(cid);
1562 }
1563 self.ensure_scratch(m);
1564 self.scratch.enable_batched_decode_scratch(&self.cfg);
1565
1566 let h = self.cfg.base.hidden_size;
1567 let vocab = self.cfg.base.vocab_size;
1568 let mut ctx = B::new_context();
1569
1570 // 0. Embed all M tokens into residual [M, H]
1571 let tokens: Vec<u32> = batch.iter().map(|(_, t, _)| *t).collect();
1572 let mut residual = self
1573 .scratch
1574 .residual
1575 .take()
1576 .expect("scratch residual missing (previous call didn't restore)");
1577 B::embedding_lookup(&mut ctx, &self.embed, &tokens, &mut residual, h);
1578
1579 // 1..num_layers: batched forward for each layer
1580 for li in 0..self.cfg.base.num_layers {
1581 self.forward_layer_batched_decode(&mut ctx, li, batch, &mut residual, m)
1582 .expect("forward_layer_batched_decode");
1583 }
1584
1585 // Final RMSNorm on [M, H] → norm_out [M, H]
1586 B::rms_norm(
1587 &mut ctx,
1588 &residual,
1589 &self.final_norm_w,
1590 self.cfg.base.rms_norm_eps,
1591 &mut self.scratch.norm_out,
1592 m,
1593 h,
1594 );
1595
1596 // LM head with m=M → batch_logits [M, vocab]
1597 self.lm_head.forward(
1598 &mut ctx,
1599 &self.scratch.norm_out,
1600 &mut self.scratch.batch_logits,
1601 m,
1602 );
1603
1604 B::sync(&mut ctx);
1605 self.scratch.residual = Some(residual);
1606
1607 let all = B::to_vec(&self.scratch.batch_logits, m * vocab);
1608
1609 // Profile dump (one decode_batch_internal call = one decode step
1610 // covering all m tokens).
1611 if let Some(t0) = prof_t0 {
1612 use std::sync::atomic::Ordering;
1613 let total_us = t0.elapsed().as_micros() as u64;
1614 let dense = BD_DENSE_US.swap(0, Ordering::Relaxed);
1615 let attn = BD_ATTN_PERITEM_US.swap(0, Ordering::Relaxed);
1616 let moe = BD_MOE_US.swap(0, Ordering::Relaxed);
1617 let layers = BD_LAYER_CALLS.swap(0, Ordering::Relaxed);
1618 let other = total_us.saturating_sub(dense + attn + moe);
1619 let pct = |us: u64| -> f64 {
1620 if total_us == 0 {
1621 0.0
1622 } else {
1623 100.0 * us as f64 / total_us as f64
1624 }
1625 };
1626 // MoE sub-stage breakdown — meaningful when
1627 // moe_forward_batched_decode_impl was used.
1628 let moe_route = MOE_BATCHED_DECODE_ROUTE_US.swap(0, Ordering::Relaxed);
1629 let moe_gate = MOE_BATCHED_DECODE_GATE_US.swap(0, Ordering::Relaxed);
1630 let moe_up = MOE_BATCHED_DECODE_UP_US.swap(0, Ordering::Relaxed);
1631 let moe_silu = MOE_BATCHED_DECODE_SILU_US.swap(0, Ordering::Relaxed);
1632 let moe_down = MOE_BATCHED_DECODE_DOWN_US.swap(0, Ordering::Relaxed);
1633 let moe_wsum = MOE_BATCHED_DECODE_WSUM_US.swap(0, Ordering::Relaxed);
1634 eprintln!(
1635 "[batched-decode-prof] m={} layers={} total={} ms | dense={} ({:.1}%) | attn_peritem={} ({:.1}%) | moe={} ({:.1}%) [route={} gate={} up={} silu={} down={} wsum={}] | other={} ({:.1}%)",
1636 m, layers, total_us / 1000,
1637 dense / 1000, pct(dense),
1638 attn / 1000, pct(attn),
1639 moe / 1000, pct(moe),
1640 moe_route / 1000, moe_gate / 1000, moe_up / 1000,
1641 moe_silu / 1000, moe_down / 1000, moe_wsum / 1000,
1642 other / 1000, pct(other),
1643 );
1644 }
1645
1646 (0..m)
1647 .map(|i| all[i * vocab..(i + 1) * vocab].to_vec())
1648 .collect()
1649 }
1650
1651 /// One transformer layer over M items: GEMMs at m=M, per-item
1652 /// attention loop, MoE FFN at m=M via the prefill batched path.
1653 /// Mirrors `LlamaFamilyModel::forward_layer_batched_decode` minus
1654 /// the paged branch.
1655 fn forward_layer_batched_decode(
1656 &mut self,
1657 ctx: &mut B::Context,
1658 li: usize,
1659 batch: &[(String, u32, u32)],
1660 residual: &mut B::Buffer,
1661 m: usize,
1662 ) -> Result<()> {
1663 let cfg_base = &self.cfg.base;
1664 let h = cfg_base.hidden_size;
1665 let nh = cfg_base.num_heads;
1666 let nkv = cfg_base.num_kv_heads;
1667 let hd = cfg_base.head_dim;
1668 let eps = cfg_base.rms_norm_eps;
1669 let q_dim = nh * hd;
1670 let kv_dim = nkv * hd;
1671
1672 let attn_layer = &self.attn_layers[li];
1673 let qk_mode: i32 = if cfg_base.has_qk_norm { 1 } else { 2 };
1674 let dummy_w = &attn_layer.input_ln_w;
1675 let q_norm_w = attn_layer.q_norm_w.as_ref().unwrap_or(dummy_w);
1676 let k_norm_w = attn_layer.k_norm_w.as_ref().unwrap_or(dummy_w);
1677
1678 let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
1679 let stage_t0 = || -> Option<std::time::Instant> {
1680 if prof {
1681 Some(std::time::Instant::now())
1682 } else {
1683 None
1684 }
1685 };
1686 let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
1687 if let Some(t) = t0 {
1688 B::sync(ctx);
1689 c.fetch_add(
1690 t.elapsed().as_micros() as u64,
1691 std::sync::atomic::Ordering::Relaxed,
1692 );
1693 }
1694 };
1695 if prof {
1696 BD_LAYER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1697 }
1698
1699 let dense_t0 = stage_t0();
1700
1701 // 1. rms_norm [M, H] → norm_out
1702 B::rms_norm(
1703 ctx,
1704 residual,
1705 &attn_layer.input_ln_w,
1706 eps,
1707 &mut self.scratch.norm_out,
1708 m,
1709 h,
1710 );
1711
1712 // 2. qkv_proj GEMM at m=M: norm_out [M, H] → qkv_out [M, QKV]
1713 attn_layer
1714 .qkv_proj
1715 .forward(ctx, &self.scratch.norm_out, &mut self.scratch.qkv_out, m);
1716
1717 // ── Paged batched attention path ───────────────────────────────
1718 //
1719 // Mirrors LlamaFamilyModel's Phase 4b paged batched-decode. When
1720 // `FERRUM_METAL_PAGED_KV=1` was set at ensure_kv time, each
1721 // cache_id has paged metadata (block_table + context_lens) and
1722 // K/V live in the shared `paged_pools[layer]` pool. This path:
1723 // 1. m × `split_qkv_norm_rope_into_paged_cache` writes K/V into
1724 // the pool at each item's allocated blocks AND fills
1725 // `paged_batch_q[i*q_dim ..]` with that item's head-major Q.
1726 // 2. Build `paged_batch_block_tables [m, max_blocks_per_seq]`
1727 // and `paged_batch_context_lens [m]` host-side, upload.
1728 // 3. ONE `paged_decode_attention(num_seqs=m)` call reads all m
1729 // sequences' K/V from the pool via per-seq block_tables,
1730 // writes outputs to `paged_batch_o [m, q_dim]`.
1731 // 4. Per-item copy_slice paged_batch_o[i] → attn_flat[i*q_dim].
1732 //
1733 // This is the structural fix for the c=16 attn_peritem cliff
1734 // (~55 ms / round of 16 sequential m=1 flash_attn + plumbing).
1735 let is_paged = self.paged_pools.is_some();
1736 if is_paged {
1737 stage_end(dense_t0, ctx, &BD_DENSE_US);
1738 let attn_t0 = stage_t0();
1739
1740 let max_blocks_per_seq = self.scratch.paged_max_blocks_per_seq;
1741 let block_size = 16; // matches PAGED_BLOCK_SIZE in ensure_kv
1742 let qkv_stride = q_dim + 2 * kv_dim;
1743
1744 // Step 1: per-item paged write. Read each item's qkv out of
1745 // the batched qkv_out buffer; write its head-major Q into
1746 // paged_batch_q[i*q_dim ..]; write its K/V into the pool at
1747 // its allocated blocks via block_table.
1748 let q_head_major_size_bytes = (q_dim * std::mem::size_of::<f32>()) as u64;
1749 let qkv_stride_bytes = (qkv_stride * std::mem::size_of::<f32>()) as u64;
1750 let pool_ptr = {
1751 let pools = self.paged_pools.as_mut().unwrap();
1752 (
1753 &mut pools[li].0 as *mut B::Buffer,
1754 &mut pools[li].1 as *mut B::Buffer,
1755 )
1756 };
1757 // SAFETY: pools allocated-once, see paged_pools field comment.
1758 let (pool_k, pool_v) = unsafe { (&mut *pool_ptr.0, &mut *pool_ptr.1) };
1759 for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
1760 let pos_i = *pos as usize;
1761 let caches = self
1762 .kv_caches
1763 .get(cache_id)
1764 .expect("paged batched: cache not present");
1765 let cache = &caches[li];
1766 let bt = cache
1767 .block_table
1768 .as_ref()
1769 .expect("paged batched: block_table missing");
1770 let cache_len_before = cache.len;
1771 let bt_ptr = bt as *const B::Buffer;
1772 // SAFETY: bt is read-only during the dispatch; we don't
1773 // touch self.kv_caches between this raw deref and the
1774 // call below.
1775 let bt_safe: &B::Buffer = unsafe { &*bt_ptr };
1776 B::split_qkv_norm_rope_into_paged_cache(
1777 ctx,
1778 &self.scratch.qkv_out,
1779 (i as u64) * qkv_stride_bytes,
1780 q_norm_w,
1781 k_norm_w,
1782 &self.rope.cos,
1783 &self.rope.sin,
1784 self.scratch
1785 .paged_batch_q
1786 .as_mut()
1787 .expect("paged_batch_q missing"),
1788 (i as u64) * q_head_major_size_bytes,
1789 pool_k,
1790 pool_v,
1791 bt_safe,
1792 1,
1793 nh,
1794 nkv,
1795 hd,
1796 pos_i,
1797 eps,
1798 qk_mode,
1799 cache_len_before,
1800 block_size,
1801 max_blocks_per_seq,
1802 )
1803 .expect("split_qkv_norm_rope_into_paged_cache (batched)");
1804 }
1805
1806 // Step 2: bump cache.len and stack block_tables + context_lens
1807 // host-side, then upload to device scratch.
1808 let mut stacked_bt: Vec<u32> = vec![0u32; m * max_blocks_per_seq];
1809 let mut stacked_cl: Vec<u32> = vec![0u32; m];
1810 for (i, (cache_id, _, _)) in batch.iter().enumerate() {
1811 let caches = self
1812 .kv_caches
1813 .get_mut(cache_id)
1814 .expect("paged batched: cache not present");
1815 let cache = &mut caches[li];
1816 cache.len += 1;
1817 let len = cache.len as u32;
1818 stacked_cl[i] = len;
1819 let blocks = &cache.paged_block_indices;
1820 let n_to_copy = blocks.len().min(max_blocks_per_seq);
1821 stacked_bt[i * max_blocks_per_seq..i * max_blocks_per_seq + n_to_copy]
1822 .copy_from_slice(&blocks[..n_to_copy]);
1823 }
1824 let bt_buf = self
1825 .scratch
1826 .paged_batch_block_tables
1827 .as_mut()
1828 .expect("paged_batch_block_tables missing");
1829 B::write_u32(ctx, bt_buf, &stacked_bt);
1830 let cl_buf = self
1831 .scratch
1832 .paged_batch_context_lens
1833 .as_mut()
1834 .expect("paged_batch_context_lens missing");
1835 B::write_u32(ctx, cl_buf, &stacked_cl);
1836
1837 // Step 3: one batched paged_decode_attention(num_seqs=m).
1838 let bt_ptr =
1839 self.scratch.paged_batch_block_tables.as_ref().unwrap() as *const B::Buffer;
1840 let cl_ptr =
1841 self.scratch.paged_batch_context_lens.as_ref().unwrap() as *const B::Buffer;
1842 let q_ptr = self.scratch.paged_batch_q.as_ref().unwrap() as *const B::Buffer;
1843 let o_ptr = self.scratch.paged_batch_o.as_mut().unwrap() as *mut B::Buffer;
1844 // SAFETY: scratch buffers are not aliased; we hold &mut self
1845 // through this entire block.
1846 let bt_safe = unsafe { &*bt_ptr };
1847 let cl_safe = unsafe { &*cl_ptr };
1848 let q_safe = unsafe { &*q_ptr };
1849 let o_safe = unsafe { &mut *o_ptr };
1850 B::paged_decode_attention(
1851 ctx,
1852 q_safe,
1853 pool_k,
1854 pool_v,
1855 o_safe,
1856 bt_safe,
1857 cl_safe,
1858 m,
1859 nh,
1860 nkv,
1861 hd,
1862 block_size,
1863 max_blocks_per_seq,
1864 1, // q_len
1865 )
1866 .expect("paged batched decode");
1867
1868 // Step 4: per-item copy paged_batch_o[i] → attn_flat[i*q_dim].
1869 for i in 0..m {
1870 B::copy_slice(
1871 ctx,
1872 self.scratch.paged_batch_o.as_ref().unwrap(),
1873 i * q_dim,
1874 &mut self.scratch.attn_flat,
1875 i * q_dim,
1876 q_dim,
1877 );
1878 }
1879
1880 stage_end(attn_t0, ctx, &BD_ATTN_PERITEM_US);
1881 } else {
1882 // 3. split_qkv [M, QKV] → q_buf [M, Q], k_buf [M, KV], v_buf [M, KV]
1883 B::split_qkv(
1884 ctx,
1885 &self.scratch.qkv_out,
1886 &mut self.scratch.q_buf,
1887 &mut self.scratch.k_buf,
1888 &mut self.scratch.v_buf,
1889 m,
1890 q_dim,
1891 kv_dim,
1892 );
1893
1894 // 4-6. Per-item loop: rope + kv_append + attention.
1895 // Each item has its own cache_id + pos + kv_len.
1896 let q_single = self
1897 .scratch
1898 .q_single
1899 .as_ref()
1900 .expect("q_single missing — enable_batched_decode_scratch not called")
1901 as *const B::Buffer;
1902 let k_single =
1903 self.scratch.k_single.as_ref().expect("k_single missing") as *const B::Buffer;
1904 let v_single =
1905 self.scratch.v_single.as_ref().expect("v_single missing") as *const B::Buffer;
1906 let q_hm_single =
1907 self.scratch
1908 .q_head_major_single
1909 .as_mut()
1910 .expect("q_head_major_single missing") as *mut B::Buffer;
1911 let k_hm_single =
1912 self.scratch
1913 .k_head_major_single
1914 .as_mut()
1915 .expect("k_head_major_single missing") as *mut B::Buffer;
1916 let v_hm_single =
1917 self.scratch
1918 .v_head_major_single
1919 .as_mut()
1920 .expect("v_head_major_single missing") as *mut B::Buffer;
1921 let attn_hm_single =
1922 self.scratch
1923 .attn_head_major_single
1924 .as_mut()
1925 .expect("attn_head_major_single missing") as *mut B::Buffer;
1926 // SAFETY: each Option holds a stable B::Buffer; we don't mutate
1927 // self.scratch in a way that would invalidate them inside the loop
1928 // (the kv_caches mutation is on a disjoint field).
1929
1930 // End of dense block (rms_norm + qkv_proj + split_qkv); start
1931 // per-item attention loop instrumentation.
1932 stage_end(dense_t0, ctx, &BD_DENSE_US);
1933 let attn_t0 = stage_t0();
1934
1935 for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
1936 let pos_i = *pos as usize;
1937
1938 // SAFETY: borrows of disjoint scratch fields, see above.
1939 let q_single_ref = unsafe { &*q_single };
1940 let k_single_ref = unsafe { &*k_single };
1941 let v_single_ref = unsafe { &*v_single };
1942 let q_hm_single_mut = unsafe { &mut *q_hm_single };
1943 let k_hm_single_mut = unsafe { &mut *k_hm_single };
1944 let v_hm_single_mut = unsafe { &mut *v_hm_single };
1945 let attn_hm_single_mut = unsafe { &mut *attn_hm_single };
1946
1947 // Extract item i's Q/K/V slice from the batched buffers.
1948 B::copy_slice(
1949 ctx,
1950 &self.scratch.q_buf,
1951 i * q_dim,
1952 // copy_slice signature wants &mut for dst, but q_single
1953 // is shared; we need a *mut variant — since enable_*
1954 // gives us Option, we can do as_mut() here.
1955 self.scratch.q_single.as_mut().unwrap(),
1956 0,
1957 q_dim,
1958 );
1959 B::copy_slice(
1960 ctx,
1961 &self.scratch.k_buf,
1962 i * kv_dim,
1963 self.scratch.k_single.as_mut().unwrap(),
1964 0,
1965 kv_dim,
1966 );
1967 B::copy_slice(
1968 ctx,
1969 &self.scratch.v_buf,
1970 i * kv_dim,
1971 self.scratch.v_single.as_mut().unwrap(),
1972 0,
1973 kv_dim,
1974 );
1975
1976 // qk_norm_rope with tokens=1, per-item pos.
1977 B::qk_norm_rope(
1978 ctx,
1979 q_single_ref,
1980 q_norm_w,
1981 &self.rope.cos,
1982 &self.rope.sin,
1983 q_hm_single_mut,
1984 1,
1985 nh,
1986 hd,
1987 pos_i,
1988 eps,
1989 qk_mode,
1990 );
1991 B::qk_norm_rope(
1992 ctx,
1993 k_single_ref,
1994 k_norm_w,
1995 &self.rope.cos,
1996 &self.rope.sin,
1997 k_hm_single_mut,
1998 1,
1999 nkv,
2000 hd,
2001 pos_i,
2002 eps,
2003 qk_mode,
2004 );
2005 B::qk_norm_rope(
2006 ctx,
2007 v_single_ref,
2008 dummy_w,
2009 &self.rope.cos,
2010 &self.rope.sin,
2011 v_hm_single_mut,
2012 1,
2013 nkv,
2014 hd,
2015 pos_i,
2016 eps,
2017 0,
2018 );
2019
2020 // KV append + attention for item i's cache.
2021 let caches = self
2022 .kv_caches
2023 .get_mut(cache_id)
2024 .expect("ensure_kv must be called before forward_layer_batched");
2025 let cache = &mut caches[li];
2026 B::kv_cache_append_head_major(
2027 ctx,
2028 &mut cache.k,
2029 &mut cache.v,
2030 cache.len,
2031 cache.capacity,
2032 k_hm_single_mut,
2033 v_hm_single_mut,
2034 1,
2035 nkv,
2036 hd,
2037 );
2038 cache.len += 1;
2039 let kv_len = cache.len;
2040 let kv_stride = cache.capacity;
2041
2042 let attn_cfg = ferrum_kernels::backend::AttnConfig {
2043 num_heads: nh,
2044 num_kv_heads: nkv,
2045 head_dim: hd,
2046 causal: true,
2047 scale: 1.0 / (hd as f32).sqrt(),
2048 kv_seq_stride: kv_stride,
2049 sliding_window: cfg_base.sliding_window,
2050 };
2051 B::flash_attention(
2052 ctx,
2053 q_hm_single_mut,
2054 &cache.k,
2055 &cache.v,
2056 attn_hm_single_mut,
2057 1,
2058 1,
2059 kv_len,
2060 pos_i,
2061 &attn_cfg,
2062 );
2063
2064 // Untranspose head-major → token-major: for tokens=1 the
2065 // layouts are byte-identical, so copy_slice straight into
2066 // attn_flat at the per-item offset (saves a transpose).
2067 B::copy_slice(
2068 ctx,
2069 attn_hm_single_mut,
2070 0,
2071 &mut self.scratch.attn_flat,
2072 i * q_dim,
2073 q_dim,
2074 );
2075 }
2076
2077 // End of per-item attention loop.
2078 stage_end(attn_t0, ctx, &BD_ATTN_PERITEM_US);
2079 } // end of `else` for non-paged path
2080
2081 let post_attn_t0 = stage_t0();
2082
2083 // 7. o_proj GEMM at m=M: attn_flat [M, Q] → o_proj_out [M, H]
2084 attn_layer.o_proj.forward(
2085 ctx,
2086 &self.scratch.attn_flat,
2087 &mut self.scratch.o_proj_out,
2088 m,
2089 );
2090
2091 // 8. fused residual_add + post_attention_layernorm
2092 B::fused_add_rms_norm(
2093 ctx,
2094 residual,
2095 &self.scratch.o_proj_out,
2096 &attn_layer.post_ln_w,
2097 eps,
2098 &mut self.scratch.norm_out,
2099 m,
2100 h,
2101 );
2102
2103 // o_proj + post-norm count under DENSE.
2104 stage_end(post_attn_t0, ctx, &BD_DENSE_US);
2105 let moe_t0 = stage_t0();
2106
2107 // 9. Router gemv: norm_out [M, H] → router_logits [M, n_exp]
2108 let moe_layer = &self.moe_layers[li];
2109 moe_layer.router.forward(
2110 ctx,
2111 &self.scratch.norm_out,
2112 &mut self.scratch.router_logits,
2113 m,
2114 );
2115
2116 // 10. MoE expert dispatch — per-item loop using the cheap
2117 // stacked decode kernels (gemv_quant_moe_id + silu_mul_stacked
2118 // + weighted_sum_batched). NOT the batched prefill path:
2119 // `moe_forward_batched_prefill_impl` is tuned for large M
2120 // (prefill) and the GPU bucketing overhead
2121 // (`compute_ids_tpe_gpu` + indirect-dispatch arg-buffer
2122 // setup) costs more than M sequential gemv calls at small M.
2123 //
2124 // Strategy: route ALL M tokens once via batched
2125 // `route_topk_softmax`, then loop M iterations of the stacked
2126 // decode kernels. Each iteration:
2127 // - extract item i's selected ids + weights from the M-batch
2128 // buffers via copy_slice
2129 // - copy norm_out[i*h..(i+1)*h] → x_single
2130 // - 3× gemv_quant_moe_id (gate/up/down) reading from x_single
2131 // - silu_mul_stacked
2132 // - weighted_sum_batched(batch=1) → acc_buf (fresh write,
2133 // no residual fusion)
2134 // - copy_slice acc_buf → moe_out[i*h..(i+1)*h]
2135 // After the loop, single add_inplace residual += moe_out [M, H].
2136 let stacked_path_available = moe_layer.experts.gate_stacked.is_some()
2137 && moe_layer.experts.up_stacked.is_some()
2138 && moe_layer.experts.down_stacked.is_some();
2139 // MoE FFN dispatch tiers (m = batch size of this layer call):
2140 //
2141 // m = 1 : `moe_forward_stacked_decode_impl`
2142 // (decode m=1 fast path, fused gate+up+silu)
2143 // m ≥ 8 (default): `moe_forward_batched_prefill_impl`
2144 // (GEMM with simdgroup_matmul + GPU bucketing)
2145 // else (m=2..7) : per-item stacked decode loop
2146 //
2147 // EXPERIMENTAL — opt-in `FERRUM_MOE_BATCHED_DECODE=1` engages the
2148 // new `moe_forward_batched_decode_impl` for 2 ≤ m < 32. The
2149 // kernel itself is bitwise correct and ports llama.cpp's
2150 // `kernel_mul_mv_id` strategy to ferrum (one indirect-dispatch
2151 // GEMV per linear covering all m*top_k pairs). Empirically OFF
2152 // by default because the existing `forward_layer_batched_decode`
2153 // attention plumbing (per-item copy_slice × m × 6 dispatches)
2154 // scales linearly with m and overshadows the FFN savings —
2155 // regression measured at -19% (c=4) and -36% (c=16) on
2156 // Qwen3-30B-A3B Q4_K_M / M1 Max. Closing that gap requires a
2157 // batched attention path with offset-aware QKV slicing, which
2158 // is the next PR's job. Until then the kernel sits as
2159 // infrastructure.
2160 // Two independent thresholds:
2161 // * `FERRUM_MOE_BATCH_THRESHOLD` (default 8) — m above which
2162 // the LEGACY non-experimental path uses the prefill GEMM.
2163 // Shared with `decode_batch`'s engine-level gate, so users
2164 // who set it to a small value to engage batched decode
2165 // don't accidentally also push the inner FFN to GEMM.
2166 // * `FERRUM_MOE_PREFILL_THRESHOLD` (default 32) — m above
2167 // which the EXPERIMENTAL batched-decode path defers to the
2168 // prefill GEMM path. Mirrors llama.cpp's `ne21_mm_id_min=32`
2169 // GEMV→GEMM boundary.
2170 let legacy_prefill_threshold: usize = std::env::var("FERRUM_MOE_BATCH_THRESHOLD")
2171 .ok()
2172 .and_then(|s| s.parse().ok())
2173 .unwrap_or(8);
2174 let new_prefill_threshold: usize = std::env::var("FERRUM_MOE_PREFILL_THRESHOLD")
2175 .ok()
2176 .and_then(|s| s.parse().ok())
2177 .unwrap_or(32);
2178 let new_batched_enabled = stacked_path_available
2179 && B::supports_batched_moe_gemv()
2180 && std::env::var("FERRUM_MOE_BATCHED_DECODE").as_deref() == Ok("1");
2181
2182 // When the new path is opted in, it owns the m=2..new_prefill_threshold
2183 // range; the legacy threshold is overridden upward.
2184 let use_prefill_batched = if new_batched_enabled {
2185 stacked_path_available && m >= new_prefill_threshold
2186 } else {
2187 stacked_path_available && m >= legacy_prefill_threshold
2188 };
2189 let use_batched_decode = new_batched_enabled && !use_prefill_batched && m >= 2;
2190
2191 if use_prefill_batched {
2192 moe_forward_batched_prefill_impl::<B>(
2193 ctx,
2194 moe_layer,
2195 &mut self.scratch,
2196 h,
2197 self.cfg.expert_intermediate_size,
2198 self.cfg.num_experts_per_tok,
2199 self.cfg.num_experts,
2200 self.cfg.norm_topk_prob,
2201 m,
2202 )?;
2203 } else if use_batched_decode {
2204 moe_forward_batched_decode_impl::<B>(
2205 ctx,
2206 moe_layer,
2207 &mut self.scratch,
2208 h,
2209 self.cfg.expert_intermediate_size,
2210 self.cfg.num_experts_per_tok,
2211 self.cfg.num_experts,
2212 self.cfg.norm_topk_prob,
2213 m,
2214 )?;
2215 } else if stacked_path_available {
2216 let inter = self.cfg.expert_intermediate_size;
2217 let top_k = self.cfg.num_experts_per_tok;
2218 let n_exp = self.cfg.num_experts;
2219 let norm_topk_prob = self.cfg.norm_topk_prob;
2220 let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2221 let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2222 let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2223
2224 // Single batched router pass: writes selected_ids_buf [M, top_k]
2225 // and weights_2d [M, top_k]. Replaces M individual route calls.
2226 B::route_topk_softmax(
2227 ctx,
2228 &self.scratch.router_logits,
2229 &mut self.scratch.selected_ids_buf,
2230 &mut self.scratch.weights_2d,
2231 m,
2232 n_exp,
2233 top_k,
2234 norm_topk_prob,
2235 )?;
2236
2237 // Per-item loop using offset-aware kernel APIs — eliminates
2238 // the 4 copy_slice round-trips per iteration that the
2239 // earlier implementation needed (ids, weights, x_single,
2240 // moe_out). At c=16 / 48 layers that's ~3,072 dispatches
2241 // saved per token. Uses `gemv_quant_moe_id_offset` to read
2242 // `selected_ids_buf` at the i-th `top_k` block and
2243 // `norm_out` at the i-th hidden row directly. Falls back
2244 // to copy_slice path if backend doesn't support offsets.
2245 for i in 0..m {
2246 let ids_offset = i * top_k;
2247 let activation_offset = i * h;
2248 let weights_offset = i * top_k;
2249 let moe_out_offset = i * h;
2250
2251 // Stacked gate / up gemvs — broadcast item i's row of
2252 // norm_out across top_k slots, read item i's ids.
2253 let gate_res = B::gemv_quant_moe_id_offset(
2254 ctx,
2255 &self.scratch.norm_out,
2256 activation_offset,
2257 gate_stacked,
2258 &self.scratch.selected_ids_buf,
2259 ids_offset,
2260 &mut self.scratch.gate_out_stacked,
2261 top_k,
2262 0,
2263 );
2264 if gate_res.is_err() {
2265 // Backend doesn't support offset variants — fall back
2266 // to the legacy copy_slice path. Same as before.
2267 B::copy_slice(
2268 ctx,
2269 &self.scratch.selected_ids_buf,
2270 ids_offset,
2271 &mut self.scratch.ids_buf,
2272 0,
2273 top_k,
2274 );
2275 B::copy_slice(
2276 ctx,
2277 &self.scratch.weights_2d,
2278 weights_offset,
2279 &mut self.scratch.weights_buf,
2280 0,
2281 top_k,
2282 );
2283 B::copy_slice(
2284 ctx,
2285 &self.scratch.norm_out,
2286 activation_offset,
2287 &mut self.scratch.x_single,
2288 0,
2289 h,
2290 );
2291 B::gemv_quant_moe_id(
2292 ctx,
2293 &self.scratch.x_single,
2294 gate_stacked,
2295 &self.scratch.ids_buf,
2296 &mut self.scratch.gate_out_stacked,
2297 top_k,
2298 0,
2299 )?;
2300 B::gemv_quant_moe_id(
2301 ctx,
2302 &self.scratch.x_single,
2303 up_stacked,
2304 &self.scratch.ids_buf,
2305 &mut self.scratch.up_out_stacked,
2306 top_k,
2307 0,
2308 )?;
2309 B::silu_mul_stacked(
2310 ctx,
2311 &self.scratch.gate_out_stacked,
2312 &self.scratch.up_out_stacked,
2313 &mut self.scratch.silu_stacked,
2314 top_k,
2315 inter,
2316 )?;
2317 B::gemv_quant_moe_id(
2318 ctx,
2319 &self.scratch.silu_stacked,
2320 down_stacked,
2321 &self.scratch.ids_buf,
2322 &mut self.scratch.down_out_stacked,
2323 top_k,
2324 inter,
2325 )?;
2326 B::weighted_sum_batched(
2327 ctx,
2328 &self.scratch.down_out_stacked,
2329 &self.scratch.weights_buf,
2330 &mut self.scratch.acc_buf,
2331 1,
2332 top_k,
2333 h,
2334 )?;
2335 B::copy_slice(
2336 ctx,
2337 &self.scratch.acc_buf,
2338 0,
2339 &mut self.scratch.moe_out,
2340 moe_out_offset,
2341 h,
2342 );
2343 continue;
2344 }
2345 // Fast path: offset-aware all the way through.
2346 B::gemv_quant_moe_id_offset(
2347 ctx,
2348 &self.scratch.norm_out,
2349 activation_offset,
2350 up_stacked,
2351 &self.scratch.selected_ids_buf,
2352 ids_offset,
2353 &mut self.scratch.up_out_stacked,
2354 top_k,
2355 0,
2356 )?;
2357 B::silu_mul_stacked(
2358 ctx,
2359 &self.scratch.gate_out_stacked,
2360 &self.scratch.up_out_stacked,
2361 &mut self.scratch.silu_stacked,
2362 top_k,
2363 inter,
2364 )?;
2365 B::gemv_quant_moe_id_offset(
2366 ctx,
2367 &self.scratch.silu_stacked,
2368 0, // silu_stacked itself stays at offset 0 each iter
2369 down_stacked,
2370 &self.scratch.selected_ids_buf,
2371 ids_offset,
2372 &mut self.scratch.down_out_stacked,
2373 top_k,
2374 inter,
2375 )?;
2376 // Write directly into moe_out at the per-item offset —
2377 // skips the copy_slice from acc_buf.
2378 B::weighted_sum_batched_offset(
2379 ctx,
2380 &self.scratch.down_out_stacked,
2381 &self.scratch.weights_2d,
2382 weights_offset,
2383 &mut self.scratch.moe_out,
2384 moe_out_offset,
2385 1,
2386 top_k,
2387 h,
2388 )?;
2389 }
2390 } else {
2391 // Backend without stacked variants — fall back to the legacy
2392 // per-(token, expert) host-routed path.
2393 moe_forward::<B>(
2394 ctx,
2395 &self.scratch.norm_out,
2396 &self.scratch.router_logits,
2397 &mut self.scratch.moe_out,
2398 m,
2399 h,
2400 self.cfg.expert_intermediate_size,
2401 self.cfg.num_experts,
2402 self.cfg.num_experts_per_tok,
2403 self.cfg.norm_topk_prob,
2404 &moe_layer.experts,
2405 &mut self.scratch.x_single,
2406 &mut self.scratch.acc_buf,
2407 &mut self.scratch.gate_up_buf,
2408 &mut self.scratch.silu_buf,
2409 &mut self.scratch.down_buf,
2410 &self.scratch.zero_hidden,
2411 )?;
2412 }
2413
2414 // 11. residual += moe_out [M, H]
2415 B::add_inplace(ctx, residual, &self.scratch.moe_out, m * h);
2416
2417 // Close MoE-block instrumentation (router + FFN + residual add).
2418 stage_end(moe_t0, ctx, &BD_MOE_US);
2419
2420 Ok(())
2421 }
2422}
2423
2424impl<B: Backend> DecoderOnlyLLM for Qwen3MoeModel<B> {
2425 fn config(&self) -> &LlmRuntimeConfig {
2426 &self.runtime_cfg
2427 }
2428
2429 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2430 // Eager scratch + KV cache grow + a 1-token forward warmup so
2431 // the first real prefill / decode doesn't pay the cold-start
2432 // ~25-MTLBuffer scratch alloc + ~96-MTLBuffer KV alloc + Metal
2433 // pipeline-state first-bind costs (~265 ms total on Qwen3-MoE
2434 // 30B-A3B / M1 Max). Mirrors what llama-bench's --warmup does
2435 // (which runs a same-shape forward before the timer).
2436 self.ensure_scratch(max_tokens);
2437 self.ensure_kv(cache_id);
2438
2439 // Warmup forward through all 48 layers under a scratch cache_id
2440 // so the real `cache_id` starts at pos_offset=0. Token 0 is
2441 // valid for any tokenizer (BOS or pad).
2442 const WARMUP_CACHE: &str = "__ferrum_warmup__";
2443 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2444 // Drop the warmup KV cache slot — real cache_id is unaffected.
2445 if let Some(caches) = self.kv_caches.remove(WARMUP_CACHE) {
2446 self.kv_free_pool.push(caches);
2447 }
2448 }
2449
2450 fn kv_capacity(&self) -> usize {
2451 // Mirror the bound `ensure_kv` will use when allocating the cache.
2452 let model_max = self.cfg.base.max_seq_len;
2453 const DEFAULT_KV_CAPACITY: usize = 4096;
2454 std::env::var("FERRUM_KV_CAPACITY")
2455 .ok()
2456 .and_then(|s| s.parse::<usize>().ok())
2457 .map(|cap| cap.min(model_max))
2458 .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
2459 }
2460
2461 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2462 self.prefill_internal(cache_id, tokens)
2463 }
2464
2465 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2466 self.decode_internal(cache_id, token, pos)
2467 }
2468
2469 // decode_batch is gated to use the batched path only when it's a
2470 // measurable win. The crossover depends on M:
2471 //
2472 // - At low M (≤ ~8) the per-item `decode_internal` loop wins
2473 // because: (a) it stays at scratch offset 0 (no copy_slice
2474 // overhead), (b) it preserves the cross-layer rms_norm fusion
2475 // fast path (`weighted_sum_residual_norm_stacked`).
2476 // - At high M (≥ ~12) the batched path wins because the dense
2477 // GEMM batching (qkv_proj, o_proj, router, lm_head at m=M) and
2478 // the prefill-batched MoE dispatch (one `gemm_quant_moe_id` for
2479 // all tokens) amortise the ~48-dispatch lost-fusion penalty.
2480 //
2481 // Default opted out of FERRUM_MOE_BATCHED. When opted in, the
2482 // batched path engages only at M ≥ FERRUM_MOE_BATCH_THRESHOLD
2483 // (default 12). Below that we still go per-item.
2484 //
2485 // Empirical note 2026-05-02: a follow-up PR added a batched MoE
2486 // GEMV kernel (`gemv_quant_moe_id_batched`) that holds MoE
2487 // dispatch count flat as concurrency scales. Wiring it through
2488 // `decode_batch_internal` regressed throughput by 19% (c=4) /
2489 // 36% (c=16) — `forward_layer_batched_decode`'s per-item
2490 // attention plumbing (copy_slice × m × 6 dispatches) costs more
2491 // than the MoE save. The batched MoE kernel is shipped as opt-in
2492 // infrastructure (`FERRUM_MOE_BATCHED_DECODE=1`); flipping it on
2493 // by default has to wait until the attention plumbing is fixed.
2494 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2495 let m = batch.len();
2496 let opted_in = std::env::var("FERRUM_MOE_BATCHED").as_deref() == Ok("1");
2497 let threshold = std::env::var("FERRUM_MOE_BATCH_THRESHOLD")
2498 .ok()
2499 .and_then(|s| s.parse::<usize>().ok())
2500 .unwrap_or(12);
2501 if opted_in && m >= threshold {
2502 self.decode_batch_internal(batch)
2503 } else {
2504 batch
2505 .iter()
2506 .map(|(cid, tok, p)| self.decode(cid, *tok, *p))
2507 .collect()
2508 }
2509 }
2510
2511 fn release(&mut self, cache_id: &str) {
2512 let mut ctx = B::new_context();
2513 B::sync(&mut ctx);
2514 B::reset_graph(&mut ctx);
2515 B::sync(&mut ctx);
2516 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2517 // Paged mode: return the cache_id's blocks to the shared
2518 // allocator so other sequences can reuse them. Without this,
2519 // every request consumes max_blocks_per_seq blocks
2520 // permanently — pool exhausts after FERRUM_PAGED_MAX_SEQS
2521 // requests and subsequent ensure_kv panics with
2522 // "scratch residual missing" (the cascade panic from a
2523 // failed ensure_kv path leaving scratch poisoned).
2524 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2525 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2526 if let Some(c0) = caches.first() {
2527 if !c0.paged_block_indices.is_empty() {
2528 alloc.free(&c0.paged_block_indices);
2529 }
2530 }
2531 for c in caches.iter_mut() {
2532 c.paged_block_indices.clear();
2533 }
2534 }
2535 self.kv_free_pool.push(caches);
2536 }
2537 }
2538
2539 fn reset(&mut self) {
2540 let mut ctx = B::new_context();
2541 B::sync(&mut ctx);
2542 B::reset_graph(&mut ctx);
2543 B::sync(&mut ctx);
2544 self.kv_caches.clear();
2545 self.kv_free_pool.clear();
2546 }
2547}
2548
2549/// Batched MoE FFN — decode (m=1) and per-token-prefill (m>1 looped).
2550///
2551/// Three batched `gemv_quant_moe_id` dispatches per token: gate (broadcast
2552/// activation), up (broadcast activation), down (per-slot activation —
2553/// each expert sees its own silu·up). The per-(token, expert) outer loop
2554/// shrinks from `top_k * 4` dispatches per layer to **3 batched + 1
2555/// silu_mul_split + 1 weighted_sum_dispatch_loop**.
2556///
2557/// For prefill (m > 1) we loop over tokens externally — each token's
2558/// router output drives a single batched call. Still much faster than
2559/// the per-(token, expert) per-Linear path because the gemvs are batched.
2560///
2561/// Free function (not a method) so the caller can split the borrow on
2562/// `self` between `moe_layers[li]` (immutable) and `scratch` (mutable).
2563#[allow(clippy::too_many_arguments)]
2564fn moe_forward_stacked_decode_impl<B: Backend>(
2565 ctx: &mut B::Context,
2566 moe_layer: &Qwen3MoeLayerState<B>,
2567 scratch: &mut Qwen3MoeScratch<B>,
2568 h: usize,
2569 inter: usize,
2570 top_k: usize,
2571 n_exp: usize,
2572 norm_topk_prob: bool,
2573 tokens: usize,
2574 residual: &mut B::Buffer,
2575 // If `Some`, fold the NEXT layer's leading rms_norm into the
2576 // weighted-sum-residual tail using `weighted_sum_residual_norm_stacked`.
2577 next_norm_w: Option<&B::Buffer>,
2578 eps: f32,
2579) -> Result<()> {
2580 // GPU-side routing: one Metal launch reads router_logits and writes
2581 // selected ids + combine weights directly into device-side scratch
2582 // buffers. Eliminates the per-layer `B::sync + B::to_vec(router_logits)
2583 // + host route()` round trip — the dominant remaining cost in the
2584 // decode hot path (~10% of total decode latency).
2585 let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
2586 let stage_t0 = || -> Option<std::time::Instant> {
2587 if prof {
2588 Some(std::time::Instant::now())
2589 } else {
2590 None
2591 }
2592 };
2593 let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
2594 if let Some(t) = t0 {
2595 B::sync(ctx);
2596 c.fetch_add(
2597 t.elapsed().as_micros() as u64,
2598 std::sync::atomic::Ordering::Relaxed,
2599 );
2600 }
2601 };
2602
2603 let t0 = stage_t0();
2604 B::route_topk_softmax(
2605 ctx,
2606 &scratch.router_logits,
2607 &mut scratch.ids_buf,
2608 &mut scratch.weights_buf,
2609 tokens,
2610 n_exp,
2611 top_k,
2612 norm_topk_prob,
2613 )?;
2614 stage_end(t0, ctx, &DEC_ROUTE_US);
2615
2616 let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2617 let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2618 let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2619
2620 // moe_forward_stacked_decode_impl is only called when `tokens == 1`
2621 // (the branch in `forward_layer` routes prefill m>1 through
2622 // `moe_forward_batched_prefill_impl` instead). The for-b loop and
2623 // the copy norm_out[b*h] → x_single were vestigial scaffolding;
2624 // for tokens=1 norm_out[0..h] IS the activation row, and we can
2625 // pass it straight to the gemv kernel via src1_stride=0 broadcast.
2626 debug_assert_eq!(
2627 tokens, 1,
2628 "moe_forward_stacked_decode_impl expects tokens=1 (prefill goes through moe_forward_batched_prefill_impl)"
2629 );
2630 let _ = tokens; // silence unused-warning when assertion is compiled out
2631
2632 {
2633 // ids_buf and weights_buf populated by the GPU router above —
2634 // no host writes needed here in the decode path.
2635
2636 // Fused-vs-unfused gate+up+silu selection.
2637 //
2638 // Default: when the backend advertises support (Metal Q4KExperts),
2639 // run the single fused dispatch — saves 2 dispatches and the
2640 // entire round-trip through gate_out_stacked / up_out_stacked
2641 // scratch (≈4× [top_k, ffn] of intermediate bandwidth).
2642 //
2643 // Opt-out: `FERRUM_MOE_FUSED_GATE_UP_SILU=0` forces the legacy
2644 // 3-dispatch path. Used for A/B benchmarking and as a kill switch
2645 // if the fused kernel ever produces divergent outputs.
2646 // Cache the env-flag read once per process — the decode hot
2647 // path calls this fn ~48 layers × ~steps_per_run times.
2648 static FUSED_DISABLED: OnceLock<bool> = OnceLock::new();
2649 let fused_disabled = *FUSED_DISABLED
2650 .get_or_init(|| std::env::var("FERRUM_MOE_FUSED_GATE_UP_SILU").as_deref() == Ok("0"));
2651 let use_fused = B::supports_fused_moe_gate_up_silu() && !fused_disabled;
2652
2653 if use_fused {
2654 // 1+2+3 fused: silu_stacked = SiLU(gate · norm_out) * (up · norm_out)
2655 let t0 = stage_t0();
2656 B::gemv_quant_moe_id_gate_up_silu(
2657 ctx,
2658 &scratch.norm_out,
2659 gate_stacked,
2660 up_stacked,
2661 &scratch.ids_buf,
2662 &mut scratch.silu_stacked,
2663 top_k,
2664 )?;
2665 stage_end(t0, ctx, &DEC_SILU_US);
2666 } else {
2667 // 1. Batched gate gemv — broadcast input across top_k slots.
2668 // src1 = norm_out (which has hidden floats at offset 0),
2669 // src1_stride=0 → all slots read the same row.
2670 let t0 = stage_t0();
2671 B::gemv_quant_moe_id(
2672 ctx,
2673 &scratch.norm_out,
2674 gate_stacked,
2675 &scratch.ids_buf,
2676 &mut scratch.gate_out_stacked,
2677 top_k,
2678 0, // broadcast
2679 )?;
2680 stage_end(t0, ctx, &DEC_GATE_US);
2681
2682 // 2. Batched up gemv — also broadcast.
2683 let t0 = stage_t0();
2684 B::gemv_quant_moe_id(
2685 ctx,
2686 &scratch.norm_out,
2687 up_stacked,
2688 &scratch.ids_buf,
2689 &mut scratch.up_out_stacked,
2690 top_k,
2691 0,
2692 )?;
2693 stage_end(t0, ctx, &DEC_UP_US);
2694
2695 // 3. Stacked SiLU·gate → silu_stacked. Single dispatch covers
2696 // all top_k slots — replaces the per-slot loop's
2697 // (3 copy_slice + 1 silu_mul) × 8 = 32 dispatches.
2698 let t0 = stage_t0();
2699 B::silu_mul_stacked(
2700 ctx,
2701 &scratch.gate_out_stacked,
2702 &scratch.up_out_stacked,
2703 &mut scratch.silu_stacked,
2704 top_k,
2705 inter,
2706 )?;
2707 stage_end(t0, ctx, &DEC_SILU_US);
2708 }
2709
2710 // 4. Batched down gemv — per-slot input via src1_stride = inter.
2711 // silu_stacked[k * inter ..] is the activation row for slot k.
2712 let t0 = stage_t0();
2713 B::gemv_quant_moe_id(
2714 ctx,
2715 &scratch.silu_stacked,
2716 down_stacked,
2717 &scratch.ids_buf,
2718 &mut scratch.down_out_stacked,
2719 top_k,
2720 inter,
2721 )?;
2722 stage_end(t0, ctx, &DEC_DOWN_US);
2723
2724 // 5. Fused weighted-sum + residual-add (+ optional next-layer
2725 // rms_norm). Two paths:
2726 //
2727 // * `next_norm_w = Some(_)` (cross-layer fusion): one kernel
2728 // computes residual[i] += Σ_k w[k] · down[k, i] AND
2729 // norm_out[i] = residual[i] · scale · next_norm_w[i].
2730 // The next layer's leading rms_norm is skipped. Saves an
2731 // additional dispatch per layer transition.
2732 // * `next_norm_w = None` (last layer): just residual-add.
2733 let t0 = stage_t0();
2734 if let Some(nnw) = next_norm_w {
2735 B::weighted_sum_residual_norm_stacked(
2736 ctx,
2737 &scratch.down_out_stacked,
2738 &scratch.weights_buf,
2739 residual,
2740 nnw,
2741 &mut scratch.norm_out,
2742 top_k,
2743 h,
2744 eps,
2745 )?;
2746 } else {
2747 B::weighted_sum_residual_stacked(
2748 ctx,
2749 &scratch.down_out_stacked,
2750 &scratch.weights_buf,
2751 residual,
2752 top_k,
2753 h,
2754 )?;
2755 }
2756 stage_end(t0, ctx, &DEC_WSUM_US);
2757 }
2758
2759 Ok(())
2760}
2761
2762/// Batched MoE FFN for prefill (m > 1).
2763///
2764/// One pass through the expert dispatch — replaces the per-token loop
2765/// with three batched 2-D mul_mm_id dispatches (gate, up, down) where
2766/// each expert's slab of (token, slot) pairs runs as one gemm tile.
2767/// Per-layer dispatch count: ~6 (router + 3 mul_mm_id + silu + wsum)
2768/// independent of `tokens`. Compare to the decode-style stacked path
2769/// that emits ~10 per token.
2770///
2771/// Free function so the caller can split the borrow on `self` between
2772/// `moe_layers[li]` (immutable) and `scratch` (mutable).
2773#[allow(clippy::too_many_arguments)]
2774fn moe_forward_batched_prefill_impl<B: Backend>(
2775 ctx: &mut B::Context,
2776 moe_layer: &Qwen3MoeLayerState<B>,
2777 scratch: &mut Qwen3MoeScratch<B>,
2778 h: usize,
2779 inter: usize,
2780 top_k: usize,
2781 n_exp: usize,
2782 norm_topk_prob: bool,
2783 tokens: usize,
2784) -> Result<()> {
2785 let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
2786 let stage_t0 = || -> Option<std::time::Instant> {
2787 if prof {
2788 Some(std::time::Instant::now())
2789 } else {
2790 None
2791 }
2792 };
2793 let stage_end =
2794 |t0: Option<std::time::Instant>, ctx: &mut B::Context, us: &AtomicU64, n: &AtomicU64| {
2795 if let Some(t) = t0 {
2796 B::sync(ctx);
2797 us.fetch_add(
2798 t.elapsed().as_micros() as u64,
2799 std::sync::atomic::Ordering::Relaxed,
2800 );
2801 n.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2802 }
2803 };
2804
2805 // GPU-side routing: keep the whole pipeline device-resident. Two
2806 // dispatches replace the per-layer `B::sync + to_vec(router_logits)
2807 // + host route() + host compute_ids_tpe + write_back` round trip.
2808 //
2809 // 1. `route_topk_softmax` writes selected expert IDs (flat
2810 // `[batch, top_k]`) into `selected_ids_buf` and the post-renorm
2811 // combine weights directly into `weights_2d`.
2812 // 2. `compute_ids_tpe_gpu` buckets those pairs into `tpe_buf` and
2813 // `ids_2d` using device-side atomic_fetch_add slot claims. The
2814 // `ids_2d` row stride is the worst-case `tokens * top_k`; the
2815 // consumer GEMM stops at `tpe[e]` so the over-strided columns
2816 // cost only launch overhead, not real compute.
2817 //
2818 // `FERRUM_MOE_HOST_TOPK=1` → legacy CPU softmax+topk+bucket
2819 // `FERRUM_MOE_DIRECT_DISPATCH=1` → GPU topk but worst-case GEMM grid
2820 // (default) → GPU topk + indirect-dispatched GEMM
2821 // (grid sized from max(tpe[e]))
2822 let use_gpu_topk = std::env::var("FERRUM_MOE_HOST_TOPK").as_deref() != Ok("1");
2823 let use_indirect_dispatch =
2824 use_gpu_topk && std::env::var("FERRUM_MOE_DIRECT_DISPATCH").as_deref() != Ok("1");
2825 let max_per_expert = if use_gpu_topk {
2826 let t0 = stage_t0();
2827 B::route_topk_softmax(
2828 ctx,
2829 &scratch.router_logits,
2830 &mut scratch.selected_ids_buf,
2831 &mut scratch.weights_2d,
2832 tokens,
2833 n_exp,
2834 top_k,
2835 norm_topk_prob,
2836 )?;
2837 B::compute_ids_tpe_gpu(
2838 ctx,
2839 &scratch.selected_ids_buf,
2840 &mut scratch.tpe_buf,
2841 &mut scratch.ids_2d,
2842 &mut scratch.gate_up_args_buf,
2843 &mut scratch.down_args_buf,
2844 tokens,
2845 n_exp,
2846 top_k,
2847 inter,
2848 h,
2849 )?;
2850 stage_end(
2851 t0,
2852 ctx,
2853 &MOE_PREFILL_HOST_TOPK_US,
2854 &MOE_PREFILL_HOST_TOPK_CALLS,
2855 );
2856 // Worst-case ids row stride; matches `dispatch_compute_ids_tpe`.
2857 tokens * top_k
2858 } else {
2859 use ferrum_kernels::moe_host::compute_ids_tpe;
2860 let t0 = stage_t0();
2861 B::sync(ctx);
2862 let logits_host = B::to_vec(&scratch.router_logits, tokens * n_exp);
2863 let route = crate::moe::router::route(&logits_host, tokens, n_exp, top_k, norm_topk_prob);
2864 let (tpe_host, ids_host, max_per_expert) =
2865 compute_ids_tpe(&route.expert_ids, n_exp, tokens, top_k);
2866 B::write_i32_into(&mut scratch.tpe_buf, &tpe_host);
2867 B::write_i32_into(&mut scratch.ids_2d, &ids_host);
2868 B::write_f32_into(&mut scratch.weights_2d, &route.expert_weights);
2869 stage_end(
2870 t0,
2871 ctx,
2872 &MOE_PREFILL_HOST_TOPK_US,
2873 &MOE_PREFILL_HOST_TOPK_CALLS,
2874 );
2875 max_per_expert
2876 };
2877
2878 let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2879 let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2880 let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2881
2882 // 1. Batched gate gemm — one launch covers all (token, expert) pairs.
2883 // src1 layout: [batch, ne11=1, K] (broadcast: each pair reads its
2884 // token's row, slot index ignored).
2885 // dst layout: [batch, top_k, expert_inter] — natural.
2886 let t0 = stage_t0();
2887 if use_indirect_dispatch {
2888 B::gemm_quant_moe_id_indirect(
2889 ctx,
2890 &scratch.norm_out,
2891 gate_stacked,
2892 &scratch.ids_2d,
2893 &scratch.tpe_buf,
2894 &mut scratch.gate_out_stacked,
2895 &scratch.gate_up_args_buf,
2896 1, // ne11 = 1: broadcast
2897 top_k,
2898 max_per_expert,
2899 tokens,
2900 )?;
2901 } else {
2902 B::gemm_quant_moe_id(
2903 ctx,
2904 &scratch.norm_out,
2905 gate_stacked,
2906 &scratch.ids_2d,
2907 &scratch.tpe_buf,
2908 &mut scratch.gate_out_stacked,
2909 1,
2910 top_k,
2911 max_per_expert,
2912 tokens,
2913 )?;
2914 }
2915 stage_end(t0, ctx, &MOE_PREFILL_GATE_US, &MOE_PREFILL_GATE_CALLS);
2916
2917 // 2. Batched up gemm — same shape as gate.
2918 let t0 = stage_t0();
2919 if use_indirect_dispatch {
2920 B::gemm_quant_moe_id_indirect(
2921 ctx,
2922 &scratch.norm_out,
2923 up_stacked,
2924 &scratch.ids_2d,
2925 &scratch.tpe_buf,
2926 &mut scratch.up_out_stacked,
2927 &scratch.gate_up_args_buf,
2928 1,
2929 top_k,
2930 max_per_expert,
2931 tokens,
2932 )?;
2933 } else {
2934 B::gemm_quant_moe_id(
2935 ctx,
2936 &scratch.norm_out,
2937 up_stacked,
2938 &scratch.ids_2d,
2939 &scratch.tpe_buf,
2940 &mut scratch.up_out_stacked,
2941 1,
2942 top_k,
2943 max_per_expert,
2944 tokens,
2945 )?;
2946 }
2947 stage_end(t0, ctx, &MOE_PREFILL_UP_US, &MOE_PREFILL_UP_CALLS);
2948
2949 // 3. SiLU·gate over [tokens * top_k, expert_inter] flat layout.
2950 let total_pairs = tokens * top_k;
2951 let t0 = stage_t0();
2952 B::silu_mul_batched(
2953 ctx,
2954 &scratch.gate_out_stacked,
2955 &scratch.up_out_stacked,
2956 &mut scratch.silu_stacked,
2957 total_pairs,
2958 inter,
2959 )?;
2960 stage_end(t0, ctx, &MOE_PREFILL_SILU_US, &MOE_PREFILL_SILU_CALLS);
2961
2962 // 4. Batched down gemm — src1 is [batch, top_k, expert_inter] from
2963 // silu_stacked. ne11 = top_k → each pair reads its own row.
2964 let t0 = stage_t0();
2965 if use_indirect_dispatch {
2966 B::gemm_quant_moe_id_indirect(
2967 ctx,
2968 &scratch.silu_stacked,
2969 down_stacked,
2970 &scratch.ids_2d,
2971 &scratch.tpe_buf,
2972 &mut scratch.down_out_stacked,
2973 &scratch.down_args_buf,
2974 top_k, // ne11 = top_k: per-slot
2975 top_k,
2976 max_per_expert,
2977 tokens,
2978 )?;
2979 } else {
2980 B::gemm_quant_moe_id(
2981 ctx,
2982 &scratch.silu_stacked,
2983 down_stacked,
2984 &scratch.ids_2d,
2985 &scratch.tpe_buf,
2986 &mut scratch.down_out_stacked,
2987 top_k,
2988 top_k,
2989 max_per_expert,
2990 tokens,
2991 )?;
2992 }
2993 stage_end(t0, ctx, &MOE_PREFILL_DOWN_US, &MOE_PREFILL_DOWN_CALLS);
2994
2995 // 5. Per-batch weighted sum: moe_out[b, h] = Σ_k w[b,k] · down[b,k,h]
2996 let t0 = stage_t0();
2997 B::weighted_sum_batched(
2998 ctx,
2999 &scratch.down_out_stacked,
3000 &scratch.weights_2d,
3001 &mut scratch.moe_out,
3002 tokens,
3003 top_k,
3004 h,
3005 )?;
3006 stage_end(t0, ctx, &MOE_PREFILL_WSUM_US, &MOE_PREFILL_WSUM_CALLS);
3007
3008 Ok(())
3009}
3010
3011/// Batched MoE FFN for the **small-m decode** range (typically c=2..32).
3012///
3013/// Mirrors llama.cpp's `kernel_mul_mv_id` strategy: hold the dispatch
3014/// count flat as concurrency scales by emitting **one** batched GEMV
3015/// per linear (gate / up / down) that covers all `m * top_k`
3016/// (token, expert) pairs in a single Metal launch. Replaces the
3017/// per-token outer loop in `forward_layer` (which emitted ~5
3018/// dispatches × m tokens per layer) with a fixed-shape pipeline.
3019///
3020/// Compared to [`moe_forward_batched_prefill_impl`]:
3021/// * no `compute_ids_tpe_gpu` bucketing kernel (the new pair-indexed
3022/// GEMV reads `selected_ids_buf` directly)
3023/// * uses GEMV not GEMM (better tile utilisation when tokens-per-expert
3024/// is small — at c=16 with top_k=8 each expert sees ~1-3 token rows,
3025/// well below the simdgroup_matmul tile width)
3026/// * fewer Metal dispatches per layer (5: route + 3 gemv + silu + wsum)
3027///
3028/// Per-layer dispatch budget: 5 (independent of m). At c=16 / 48 layers
3029/// that's 240 dispatches per decode step vs the per-token loop's ~3,840.
3030#[allow(clippy::too_many_arguments)]
3031fn moe_forward_batched_decode_impl<B: Backend>(
3032 ctx: &mut B::Context,
3033 moe_layer: &Qwen3MoeLayerState<B>,
3034 scratch: &mut Qwen3MoeScratch<B>,
3035 h: usize,
3036 inter: usize,
3037 top_k: usize,
3038 n_exp: usize,
3039 norm_topk_prob: bool,
3040 tokens: usize,
3041) -> Result<()> {
3042 let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
3043 let stage_t0 = || -> Option<std::time::Instant> {
3044 if prof {
3045 Some(std::time::Instant::now())
3046 } else {
3047 None
3048 }
3049 };
3050 let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
3051 if let Some(t) = t0 {
3052 B::sync(ctx);
3053 c.fetch_add(
3054 t.elapsed().as_micros() as u64,
3055 std::sync::atomic::Ordering::Relaxed,
3056 );
3057 }
3058 };
3059
3060 let total_pairs = tokens * top_k;
3061
3062 // 1. Single batched router pass — fills selected_ids_buf [m * top_k]
3063 // and weights_2d [m * top_k] in one Metal dispatch.
3064 let t0 = stage_t0();
3065 B::route_topk_softmax(
3066 ctx,
3067 &scratch.router_logits,
3068 &mut scratch.selected_ids_buf,
3069 &mut scratch.weights_2d,
3070 tokens,
3071 n_exp,
3072 top_k,
3073 norm_topk_prob,
3074 )?;
3075 stage_end(t0, ctx, &MOE_BATCHED_DECODE_ROUTE_US);
3076
3077 let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
3078 let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
3079 let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
3080
3081 // 2+3+4. Fused gate+up+silu — single Metal dispatch covers all
3082 // m*top_k pairs. Falls back to the 3-dispatch sequence on backends
3083 // that don't have the fused-batched kernel.
3084 if B::supports_batched_moe_gate_up_silu() {
3085 let t0 = stage_t0();
3086 B::gemv_quant_moe_id_gate_up_silu_batched(
3087 ctx,
3088 &scratch.norm_out,
3089 gate_stacked,
3090 up_stacked,
3091 &scratch.selected_ids_buf,
3092 &mut scratch.silu_stacked,
3093 tokens,
3094 top_k,
3095 h, // outer stride: K floats per token
3096 0, // inner stride: 0 (slots within a token broadcast)
3097 )?;
3098 // Charge the whole fused step to the SiLU bucket — keeps the
3099 // profile counter additive with the unfused path's silu line.
3100 stage_end(t0, ctx, &MOE_BATCHED_DECODE_SILU_US);
3101 } else {
3102 // 2. Batched gate gemv — one launch covers all m*top_k pairs.
3103 let t0 = stage_t0();
3104 B::gemv_quant_moe_id_batched(
3105 ctx,
3106 &scratch.norm_out,
3107 gate_stacked,
3108 &scratch.selected_ids_buf,
3109 &mut scratch.gate_out_stacked,
3110 tokens,
3111 top_k,
3112 h,
3113 0,
3114 )?;
3115 stage_end(t0, ctx, &MOE_BATCHED_DECODE_GATE_US);
3116
3117 // 3. Batched up gemv.
3118 let t0 = stage_t0();
3119 B::gemv_quant_moe_id_batched(
3120 ctx,
3121 &scratch.norm_out,
3122 up_stacked,
3123 &scratch.selected_ids_buf,
3124 &mut scratch.up_out_stacked,
3125 tokens,
3126 top_k,
3127 h,
3128 0,
3129 )?;
3130 stage_end(t0, ctx, &MOE_BATCHED_DECODE_UP_US);
3131
3132 // 4. SiLU·gate.
3133 let t0 = stage_t0();
3134 B::silu_mul_batched(
3135 ctx,
3136 &scratch.gate_out_stacked,
3137 &scratch.up_out_stacked,
3138 &mut scratch.silu_stacked,
3139 total_pairs,
3140 inter,
3141 )?;
3142 stage_end(t0, ctx, &MOE_BATCHED_DECODE_SILU_US);
3143 }
3144
3145 // 5. Batched down gemv — src1 = silu_stacked [m, top_k, ffn]: each
3146 // pair has its own row, outer = top_k * ffn, inner = ffn.
3147 let t0 = stage_t0();
3148 B::gemv_quant_moe_id_batched(
3149 ctx,
3150 &scratch.silu_stacked,
3151 down_stacked,
3152 &scratch.selected_ids_buf,
3153 &mut scratch.down_out_stacked,
3154 tokens,
3155 top_k,
3156 top_k * inter, // outer: top_k * ffn floats per token
3157 inter, // inner: ffn floats per slot
3158 )?;
3159 stage_end(t0, ctx, &MOE_BATCHED_DECODE_DOWN_US);
3160
3161 // 6. Per-token weighted sum across slots → moe_out [m, h]. Caller
3162 // does residual += moe_out at the end of forward_layer.
3163 let t0 = stage_t0();
3164 B::weighted_sum_batched(
3165 ctx,
3166 &scratch.down_out_stacked,
3167 &scratch.weights_2d,
3168 &mut scratch.moe_out,
3169 tokens,
3170 top_k,
3171 h,
3172 )?;
3173 stage_end(t0, ctx, &MOE_BATCHED_DECODE_WSUM_US);
3174
3175 Ok(())
3176}
3177
3178/// Build a stub Linear<B> with the given shape but zero weights. Used to
3179/// fill the dense `gate_up_proj` / `down_proj` slots in `LlamaFamilyLayer`
3180/// for MoE models — those slots are never invoked because the MoE FFN
3181/// path runs through `moe_layer.experts` instead. The stub's only purpose
3182/// is to satisfy the struct's type signature with minimal memory cost.
3183fn stub_linear<B: Backend>(
3184 out_features: usize,
3185 in_features: usize,
3186) -> Box<dyn ferrum_quantization::Linear<B>> {
3187 // Zero-init: out_features * in_features f32. For a 30B-A3B layer
3188 // this is 2*768*2048 = 3.1M f32 = 12 MB → fine; per-layer overhead
3189 // ≈ 12 MB × 48 = 576 MB. Marginal vs the experts (~16 GB).
3190 let zeros = vec![0.0f32; out_features * in_features];
3191 Box::new(ferrum_quantization::DenseLinear::<B>::from_rows(
3192 &zeros,
3193 out_features,
3194 in_features,
3195 ))
3196}
3197
3198fn build_rope_cache<B: Backend>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
3199 let hd = cfg.head_dim;
3200 let half = hd / 2;
3201 let max = cfg.max_seq_len;
3202 let mut cos = vec![0.0f32; max * half];
3203 let mut sin = vec![0.0f32; max * half];
3204 for pos in 0..max {
3205 for i in 0..half {
3206 let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
3207 let angle = pos as f64 * freq;
3208 cos[pos * half + i] = angle.cos() as f32;
3209 sin[pos * half + i] = angle.sin() as f32;
3210 }
3211 }
3212 RopeCache {
3213 cos: B::from_slice(&cos),
3214 sin: B::from_slice(&sin),
3215 }
3216}