Skip to main content

ferrum_models/models/
qwen3_moe.rs

1//! `Qwen3MoeModel<B>` — Qwen3-MoE family decoder (Qwen3-30B-A3B and friends).
2//!
3//! Architectural delta vs [`LlamaFamilyModel`]:
4//!   * Each transformer layer's FFN is a top-K MoE block instead of a
5//!     fused `gate_up_proj → silu → down_proj` MLP.
6//!     - One small router linear (`[hidden] → [num_experts]`) picks
7//!       top-K experts per token.
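//!     - Each expert is itself a SwiGLU `gate`/`up` + `down` MLP with the
//!       same structure as the dense MLP, just with
//!       `expert_intermediate_size` (typically much smaller than the
//!       dense `intermediate_size`). The RMSNorm is not part of each
//!       expert: it runs once per layer, before routing, and its output
//!       feeds both the router and the selected experts.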
12//!     - Output is the weight-summed combination of the K selected
13//!       expert outputs.
14//!   * Attention path is unchanged from dense Qwen3 (GQA + QK-norm + RoPE).
15//!
16//! Implementation re-uses the dense layer's attention machinery
17//! verbatim — RMSNorm, fused QKV, QK-norm + RoPE, KV cache append,
18//! flash attention, O-projection, residual + post-norm. The only new
19//! code is the MoE FFN block at the tail of each layer's forward.
20//!
21//! Memory model: experts are loaded as `QuantLinear<B>` per expert,
22//! slicing the on-disk 3-D `ffn_{gate,up,down}_exps.weight` tensors
23//! byte-wise so weights stay compressed (Q4_K / Q6_K). For a 32 GB
24//! Mac to run Qwen3-30B-A3B at all, this is non-negotiable: an
25//! eager-fp32 expert stack would weigh ~110 GB.
26
27use std::collections::HashMap;
28use std::sync::atomic::AtomicU64;
29use std::sync::OnceLock;
30
31use ferrum_kernels::backend::{Backend, KvCache};
32use ferrum_quantization::WeightLoader;
33use ferrum_types::{FerrumError, Result};
34
35use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
36use crate::models::llama_family::{LlamaFamilyConfig, LlamaFamilyLayer, RopeCache};
37use crate::moe::{moe_forward, ExpertStack};
38use crate::moe_config::Qwen3MoeConfig;
39
40// Decode-side per-op profile counters — same names as the dense path
41// so existing tooling (`FERRUM_DECODE_OP_PROFILE=1` log scrapers) keeps
42// working without a separate switch for MoE.
43static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
44static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
45static MOE_TIME_US: AtomicU64 = AtomicU64::new(0);
46static MOE_CALLS: AtomicU64 = AtomicU64::new(0);
47
48// Fine-grained decode-only counters, populated by
49// `moe_forward_stacked_decode_impl` when FERRUM_DECODE_OP_PROFILE is set.
50// Each is per-layer summed over the layers in one decode token; drained
51// at the bottom of `decode_internal`.
52static DEC_ROUTE_US: AtomicU64 = AtomicU64::new(0);
53static DEC_GATE_US: AtomicU64 = AtomicU64::new(0);
54static DEC_UP_US: AtomicU64 = AtomicU64::new(0);
55static DEC_SILU_US: AtomicU64 = AtomicU64::new(0);
56static DEC_DOWN_US: AtomicU64 = AtomicU64::new(0);
57static DEC_WSUM_US: AtomicU64 = AtomicU64::new(0);
58// Single-shot per decode token (not per-layer).
59static DEC_EMBED_US: AtomicU64 = AtomicU64::new(0);
60static DEC_FINAL_NORM_US: AtomicU64 = AtomicU64::new(0);
61static DEC_LM_HEAD_US: AtomicU64 = AtomicU64::new(0);
62
63// MoE batched-prefill sub-stage counters (gate / up / down mul_mm_id +
64// silu + weighted_sum + host topk). Same FERRUM_DECODE_OP_PROFILE gate.
65static MOE_PREFILL_HOST_TOPK_US: AtomicU64 = AtomicU64::new(0);
66static MOE_PREFILL_HOST_TOPK_CALLS: AtomicU64 = AtomicU64::new(0);
67static MOE_PREFILL_GATE_US: AtomicU64 = AtomicU64::new(0);
68static MOE_PREFILL_GATE_CALLS: AtomicU64 = AtomicU64::new(0);
69static MOE_PREFILL_UP_US: AtomicU64 = AtomicU64::new(0);
70static MOE_PREFILL_UP_CALLS: AtomicU64 = AtomicU64::new(0);
71static MOE_PREFILL_SILU_US: AtomicU64 = AtomicU64::new(0);
72static MOE_PREFILL_SILU_CALLS: AtomicU64 = AtomicU64::new(0);
73static MOE_PREFILL_DOWN_US: AtomicU64 = AtomicU64::new(0);
74static MOE_PREFILL_DOWN_CALLS: AtomicU64 = AtomicU64::new(0);
75static MOE_PREFILL_WSUM_US: AtomicU64 = AtomicU64::new(0);
76static MOE_PREFILL_WSUM_CALLS: AtomicU64 = AtomicU64::new(0);
77
78// MoE batched-DECODE sub-stage counters (small-m path that uses the
79// batched-pair GEMV in place of the per-token loop).
80static MOE_BATCHED_DECODE_ROUTE_US: AtomicU64 = AtomicU64::new(0);
81static MOE_BATCHED_DECODE_GATE_US: AtomicU64 = AtomicU64::new(0);
82static MOE_BATCHED_DECODE_UP_US: AtomicU64 = AtomicU64::new(0);
83static MOE_BATCHED_DECODE_SILU_US: AtomicU64 = AtomicU64::new(0);
84static MOE_BATCHED_DECODE_DOWN_US: AtomicU64 = AtomicU64::new(0);
85static MOE_BATCHED_DECODE_WSUM_US: AtomicU64 = AtomicU64::new(0);
86
87// Coarse stage counters for `forward_layer_batched_decode` so we can
88// see where the time goes without per-op instrumentation. Summed
89// across all layers in one decode_batch_internal call.
90static BD_DENSE_US: AtomicU64 = AtomicU64::new(0); // rms_norm + qkv_proj + split_qkv + o_proj + fused_add_rms_norm
91static BD_ATTN_PERITEM_US: AtomicU64 = AtomicU64::new(0); // the for-i in 0..m attention loop (incl. plumbing)
92static BD_MOE_US: AtomicU64 = AtomicU64::new(0); // router + MoE FFN + residual add
93static BD_LAYER_CALLS: AtomicU64 = AtomicU64::new(0);
94
95/// Per-layer MoE state: router linear (small) + per-expert MLP stack.
96pub struct Qwen3MoeLayerState<B: Backend> {
97    /// Router projection `[hidden] → [num_experts]` — tiny, never sparse,
98    /// always runs the full GEMV.
99    pub router: Box<dyn ferrum_quantization::Linear<B>>,
100    /// Per-expert weight stack. Each entry's `gate_up` is the fused
101    /// `[gate; up]` projection; `down` is the post-SwiGLU output proj.
102    pub experts: ExpertStack<B>,
103}
104
105/// Reusable scratch buffers for the MoE forward path. All sized at
106/// allocation time and reused across layers / forward calls.
107pub struct Qwen3MoeScratch<B: Backend> {
108    /// See [`crate::models::llama_family::LlamaFamilyScratch`] for the
109    /// attention scratch — we re-use those names verbatim.
110    pub residual: Option<B::Buffer>,
111    pub norm_out: B::Buffer,
112    pub qkv_out: B::Buffer,
113    pub q_buf: B::Buffer,
114    pub k_buf: B::Buffer,
115    pub v_buf: B::Buffer,
116    pub q_head_major: B::Buffer,
117    pub k_head_major: B::Buffer,
118    pub v_head_major: B::Buffer,
119    pub attn_head_major_out: B::Buffer,
120    pub attn_flat: B::Buffer,
121    pub o_proj_out: B::Buffer,
122
123    // ── MoE-specific scratch ─────────────────────────────────────────
124    /// Router logits for the whole batch: `[max_tokens, num_experts]`.
125    pub router_logits: B::Buffer,
126    /// Per-(token, expert) gate||up projection output — `[2 * expert_inter]`.
127    pub gate_up_buf: B::Buffer,
128    /// SiLU(gate) * up scratch — `[expert_inter]`.
129    pub silu_buf: B::Buffer,
130    /// Per-(token, expert) down-projection output — `[hidden]`.
131    pub down_buf: B::Buffer,
132    /// Per-token input row scratch — `[hidden]`. Holds the post-RMSNorm
133    /// activation slice that the per-(expert) gate_up gemv reads, kept
134    /// stable across the entire top_k loop for one token.
135    pub x_single: B::Buffer,
136    /// Per-token output accumulator — `[hidden]`. Holds the running
137    /// `Σ_k weight_k · expert_k(x[b])` sum that grows across the top_k
138    /// loop and is flushed to `moe_out[b]` once per token.
139    pub acc_buf: B::Buffer,
140    /// MoE output `[max_tokens, hidden]`. Zeroed each forward.
141    pub moe_out: B::Buffer,
142    /// Pre-allocated `[hidden]` zero scratch — `acc_buf` is reset to
143    /// this each token without going through `B::from_slice` on the
144    /// hot path.
145    pub zero_hidden: B::Buffer,
146
147    // ── MoE batched-fast-path scratch (Metal `gemv_q*kw_moe_id_f32` /
148    //    `gemm_q*kw_moe_id_f32`) ─────────────────────────────────────
149    //
150    // Sized for `max_tokens * top_k * X` so the same buffers cover both
151    // decode (m=1, uses the first `top_k * X` slice) and prefill
152    // (m>1, uses the full `max_tokens * top_k * X`). Decode-only
153    // workloads pay no extra memory because `max_tokens` was 1 there.
154    /// `[max_tokens * top_k * expert_inter]` — gate gemm output per pair.
155    pub gate_out_stacked: B::Buffer,
156    /// `[max_tokens * top_k * expert_inter]` — up gemm output per pair.
157    pub up_out_stacked: B::Buffer,
158    /// `[max_tokens * top_k * expert_inter]` — SiLU(gate)·up per pair.
159    pub silu_stacked: B::Buffer,
160    /// `[max_tokens * top_k * hidden]` — down gemm output per pair.
161    pub down_out_stacked: B::Buffer,
162    /// `[top_k]` i32 expert IDs for the current token (decode reuses;
163    /// prefill writes per-pair indices into `ids_2d` instead).
164    pub ids_buf: B::Buffer,
165    /// `[top_k]` f32 router combine weights for the current decode
166    /// token. Decode hot-path uses `write_f32_into` to update.
167    pub weights_buf: B::Buffer,
168    /// `[max_tokens * top_k]` i32 — flat selected-expert IDs from the
169    /// GPU router for the prefill batch. Consumed by `compute_ids_tpe_gpu`
170    /// to bucket pairs by expert into `tpe_buf` / `ids_2d`.
171    pub selected_ids_buf: B::Buffer,
172    /// `[3]` u32 indirect-dispatch args (`grid_x, grid_y, grid_z`) for
173    /// the gate / up MoE GEMM. Written by `compute_ids_tpe_gpu` so the
174    /// consumer GEMM grid covers exactly `max(tpe[e])` columns instead
175    /// of the worst-case `tokens * top_k`.
176    pub gate_up_args_buf: B::Buffer,
177    /// Same shape as `gate_up_args_buf` but for the down MoE GEMM
178    /// (different `grid_y` because down's `M = hidden_size` vs gate/up's
179    /// `M = expert_intermediate_size`).
180    pub down_args_buf: B::Buffer,
181    /// `[num_experts * max_per_expert_max]` i32 — per-expert pair
182    /// index lists for prefill 2-D mul_mm_id. `max_per_expert_max`
183    /// is bounded by `max_tokens * top_k` (worst-case: one expert
184    /// gets every pair). Sized at scratch alloc time.
185    pub ids_2d: B::Buffer,
186    /// `[num_experts]` i32 — `tpe[e]` = number of pairs assigned to
187    /// expert `e`. Companion to `ids_2d`.
188    pub tpe_buf: B::Buffer,
189    /// `[max_tokens * top_k]` f32 — combine weights per pair, in
190    /// natural `[batch, top_k]` layout for `weighted_sum_batched`.
191    pub weights_2d: B::Buffer,
192
193    // ── Final-token / lm_head outputs ────────────────────────────────
194    pub last_hidden: B::Buffer,
195    pub last_normed: B::Buffer,
196    pub logits: B::Buffer,
197    pub batch_logits: B::Buffer,
198
199    // ── Per-item single-token buffers for decode_batch (Phase 4b) ────
200    //
201    // The batched-decode path runs M GEMMs at m=M (qkv_proj / o_proj /
202    // router / MoE expert mul_mm_id) but attention stays a per-item loop
203    // (each cache_id has its own contiguous K/V buffer — no way to fan
204    // M items into a single attention dispatch without paged KV). These
205    // 1-token-shaped scratches hold the per-item slice during the loop:
206    // `copy_slice` extracts q/k/v from the batched buffers, qk_norm_rope
207    // writes head-major into _single, kv_cache_append + flash_attention
208    // run on it, then copy_slice writes back into attn_flat[i*q_dim].
209    //
210    // None until `enable_batched_decode_scratch` is called from
211    // `ensure_kv` once we know we'll be doing multi-seq decode.
212    pub q_single: Option<B::Buffer>,
213    pub k_single: Option<B::Buffer>,
214    pub v_single: Option<B::Buffer>,
215    pub q_head_major_single: Option<B::Buffer>,
216    pub k_head_major_single: Option<B::Buffer>,
217    pub v_head_major_single: Option<B::Buffer>,
218    pub attn_head_major_single: Option<B::Buffer>,
219
220    // ── Paged batched dispatch scratch ──────────────────────────────────
221    //
222    // Mirrors the same fields on `LlamaFamilyScratch`. `Some` only when
223    // `FERRUM_METAL_PAGED_KV=1` and `enable_paged_batch` was called once
224    // we know the pool dimensions. Sized for `FERRUM_PAGED_MAX_SEQS ×
225    // q_dim` so the multi-seq decode path can fan in M items' Q into a
226    // single batched buffer for one `paged_decode_attention(num_seqs=M)`
227    // call instead of running M sequential m=1 attentions.
228    pub paged_batch_q: Option<B::Buffer>,
229    pub paged_batch_o: Option<B::Buffer>,
230    pub paged_batch_block_tables: Option<B::Buffer>,
231    pub paged_batch_context_lens: Option<B::Buffer>,
232    pub paged_max_blocks_per_seq: usize,
233
234    pub max_tokens: usize,
235}
236
237impl<B: Backend> Qwen3MoeScratch<B> {
238    fn alloc(cfg: &Qwen3MoeConfig, max_tokens: usize) -> Self {
239        let h = cfg.base.hidden_size;
240        let q_dim = cfg.base.num_heads * cfg.base.head_dim;
241        let kv_dim = cfg.base.num_kv_heads * cfg.base.head_dim;
242        let qkv_dim = q_dim + 2 * kv_dim;
243        let t = max_tokens;
244        let inter = cfg.expert_intermediate_size;
245        let n_exp = cfg.num_experts;
246        let vocab = cfg.base.vocab_size;
247        Self {
248            residual: Some(B::alloc(t * h)),
249            norm_out: B::alloc(t * h),
250            qkv_out: B::alloc(t * qkv_dim),
251            q_buf: B::alloc(t * q_dim),
252            k_buf: B::alloc(t * kv_dim),
253            v_buf: B::alloc(t * kv_dim),
254            q_head_major: B::alloc(cfg.base.num_heads * t * cfg.base.head_dim),
255            k_head_major: B::alloc(cfg.base.num_kv_heads * t * cfg.base.head_dim),
256            v_head_major: B::alloc(cfg.base.num_kv_heads * t * cfg.base.head_dim),
257            attn_head_major_out: B::alloc(cfg.base.num_heads * t * cfg.base.head_dim),
258            attn_flat: B::alloc(t * q_dim),
259            o_proj_out: B::alloc(t * h),
260            router_logits: B::alloc(t * n_exp),
261            gate_up_buf: B::alloc(2 * inter),
262            silu_buf: B::alloc(inter),
263            down_buf: B::alloc(h),
264            x_single: B::alloc(h),
265            acc_buf: B::alloc(h),
266            moe_out: B::alloc(t * h),
267            zero_hidden: B::from_slice(&vec![0.0f32; h]),
268            gate_out_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
269            up_out_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
270            silu_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
271            down_out_stacked: B::alloc(t * cfg.num_experts_per_tok * h),
272            ids_buf: B::from_slice_i32(&vec![0i32; cfg.num_experts_per_tok]),
273            weights_buf: B::from_slice(&vec![0.0f32; cfg.num_experts_per_tok]),
274            selected_ids_buf: B::from_slice_i32(&vec![0i32; t * cfg.num_experts_per_tok]),
275            // 3 u32s per indirect args buffer; allocated as 3 i32s so we
276            // can reuse `from_slice_i32`. The kernel writes them as
277            // `device uint *` and the bit pattern is consumed by
278            // `dispatch_thread_groups_indirect`.
279            gate_up_args_buf: B::from_slice_i32(&[0i32, 0, 0]),
280            down_args_buf: B::from_slice_i32(&[0i32, 0, 0]),
281            ids_2d: B::from_slice_i32(&vec![0i32; n_exp * t * cfg.num_experts_per_tok]),
282            tpe_buf: B::from_slice_i32(&vec![0i32; n_exp]),
283            weights_2d: B::from_slice(&vec![0.0f32; t * cfg.num_experts_per_tok]),
284            last_hidden: B::alloc(h),
285            last_normed: B::alloc(h),
286            logits: B::alloc(vocab),
287            batch_logits: B::alloc(t * vocab),
288            // Lazily-allocated; `enable_batched_decode_scratch` populates
289            // these the first time decode_batch is called with M > 1.
290            q_single: None,
291            k_single: None,
292            v_single: None,
293            q_head_major_single: None,
294            k_head_major_single: None,
295            v_head_major_single: None,
296            attn_head_major_single: None,
297            // Lazily-allocated; `enable_paged_batch` populates these when
298            // FERRUM_METAL_PAGED_KV=1 + we know the pool dimensions.
299            paged_batch_q: None,
300            paged_batch_o: None,
301            paged_batch_block_tables: None,
302            paged_batch_context_lens: None,
303            paged_max_blocks_per_seq: 0,
304            max_tokens: t,
305        }
306    }
307
308    /// Allocate scratch for paged batched dispatch. Mirrors
309    /// `LlamaFamilyScratch::enable_paged_batch`. Idempotent.
310    fn enable_paged_batch(
311        &mut self,
312        cfg: &Qwen3MoeConfig,
313        max_seqs: usize,
314        max_blocks_per_seq: usize,
315    ) {
316        if self.paged_batch_q.is_some() {
317            return;
318        }
319        let q_dim = cfg.base.num_heads * cfg.base.head_dim;
320        self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
321        self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
322        self.paged_batch_block_tables = Some(B::alloc_u32(max_seqs * max_blocks_per_seq));
323        self.paged_batch_context_lens = Some(B::alloc_u32(max_seqs));
324        self.paged_max_blocks_per_seq = max_blocks_per_seq;
325    }
326
327    /// Allocate the per-item single-token scratch buffers used by
328    /// `forward_layer_batched_decode`. Idempotent.
329    fn enable_batched_decode_scratch(&mut self, cfg: &Qwen3MoeConfig) {
330        if self.q_single.is_some() {
331            return;
332        }
333        let q_dim = cfg.base.num_heads * cfg.base.head_dim;
334        let kv_dim = cfg.base.num_kv_heads * cfg.base.head_dim;
335        self.q_single = Some(B::alloc(q_dim));
336        self.k_single = Some(B::alloc(kv_dim));
337        self.v_single = Some(B::alloc(kv_dim));
338        self.q_head_major_single = Some(B::alloc(q_dim));
339        self.k_head_major_single = Some(B::alloc(kv_dim));
340        self.v_head_major_single = Some(B::alloc(kv_dim));
341        self.attn_head_major_single = Some(B::alloc(q_dim));
342    }
343}
344
345/// Qwen3-MoE decoder model.
346///
347/// Holds the same per-layer attention weights as [`LlamaFamilyModel`]
348/// plus a [`Qwen3MoeLayerState`] per layer for the MoE FFN. Routing,
349/// expert dispatch, and weighted combine all happen inside
350/// [`moe_forward`]; this struct only owns the storage and orchestrates
351/// the per-layer call sequence.
352pub struct Qwen3MoeModel<B: Backend> {
353    pub cfg: Qwen3MoeConfig,
354    pub runtime_cfg: LlmRuntimeConfig,
355
356    pub embed: B::Buffer,
357    /// Per-layer attention weights (re-uses dense `LlamaFamilyLayer`).
358    pub attn_layers: Vec<LlamaFamilyLayer<B>>,
359    /// Per-layer MoE state (router + expert stack).
360    pub moe_layers: Vec<Qwen3MoeLayerState<B>>,
361    pub final_norm_w: B::Buffer,
362    pub lm_head: Box<dyn ferrum_quantization::Linear<B>>,
363
364    pub rope: RopeCache<B>,
365    pub scratch: Qwen3MoeScratch<B>,
366
367    pub kv_caches: HashMap<String, Vec<KvCache<B>>>,
368    kv_free_pool: Vec<Vec<KvCache<B>>>,
369
370    // ── Paged-KV multi-seq state ────────────────────────────────────────
371    //
372    // Mirrors `LlamaFamilyModel`. Only populated when
373    // `FERRUM_METAL_PAGED_KV=1`. Kv_caches entries become metadata-only
374    // views (block_table + context_lens) into the shared `paged_pools`.
375    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
376    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
377}
378
379impl<B: Backend> Qwen3MoeModel<B> {
380    /// Build a Qwen3-MoE model from a generic `WeightLoader<B>` plus a
381    /// GGUF reader for the experts (which `WeightLoader` doesn't model
382    /// directly — its API is rank-2 only).
383    ///
384    /// `loader` provides: token embedding, attention projections, layer
385    /// norms, lm_head — all the rank-2 weights.
386    /// `gguf` provides: the rank-3 expert tensors, sliced per-expert
387    /// inside [`ExpertStack::load_from_gguf`].
388    pub fn new(
389        cfg: Qwen3MoeConfig,
390        loader: &dyn WeightLoader<B>,
391        gguf: &ferrum_quantization::gguf::GgufFile,
392    ) -> Result<Self> {
393        {
394            let mut ctx = B::new_context();
395            B::reset_graph(&mut ctx);
396        }
397        let rope = build_rope_cache::<B>(&cfg.base);
398        let scratch = Qwen3MoeScratch::alloc(&cfg, 1);
399
400        let embed = loader.load_tensor("model.embed_tokens.weight")?;
401
402        let mut attn_layers = Vec::with_capacity(cfg.base.num_layers);
403        let mut moe_layers = Vec::with_capacity(cfg.base.num_layers);
404        for li in 0..cfg.base.num_layers {
405            let prefix = format!("model.layers.{li}");
406            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
407            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
408            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
409            let post_ln_w =
410                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
411
412            // Dense gate_up_proj / down_proj are absent in MoE GGUFs —
413            // we synthesise stub Linears so the LlamaFamilyLayer struct
414            // type-checks. They're never invoked because forward_layer
415            // calls the MoE path. Cheap: tiny zero-sized DenseLinears.
416            let gate_up_proj: Box<dyn ferrum_quantization::Linear<B>> =
417                stub_linear::<B>(2 * cfg.expert_intermediate_size, cfg.base.hidden_size);
418            let down_proj: Box<dyn ferrum_quantization::Linear<B>> =
419                stub_linear::<B>(cfg.base.hidden_size, cfg.expert_intermediate_size);
420
421            let (q_norm_w, k_norm_w) = if cfg.base.has_qk_norm {
422                let q = loader
423                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
424                    .ok();
425                let k = loader
426                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
427                    .ok();
428                (q, k)
429            } else {
430                (None, None)
431            };
432
433            attn_layers.push(LlamaFamilyLayer {
434                input_ln_w,
435                qkv_proj,
436                q_norm_w,
437                k_norm_w,
438                o_proj,
439                post_ln_w,
440                gate_up_proj,
441                down_proj,
442            });
443
444            // Router lives at `model.layers.{li}.mlp.router.weight` in
445            // ferrum-name space (see ferrum_to_gguf mapping). It's a
446            // plain rank-2 linear so the standard loader path covers
447            // it without going through the MoE-specific GGUF helper.
448            let router = loader.load_linear(&format!("{prefix}.mlp.router"))?;
449            if router.in_features() != cfg.base.hidden_size {
450                return Err(FerrumError::model(format!(
451                    "router layer {li}: in_features {} != hidden {}",
452                    router.in_features(),
453                    cfg.base.hidden_size
454                )));
455            }
456            if router.out_features() != cfg.num_experts {
457                return Err(FerrumError::model(format!(
458                    "router layer {li}: out_features {} != num_experts {}",
459                    router.out_features(),
460                    cfg.num_experts
461                )));
462            }
463
464            let experts = ExpertStack::<B>::load_from_gguf(
465                gguf,
466                li,
467                cfg.num_experts,
468                cfg.base.hidden_size,
469                cfg.expert_intermediate_size,
470            )?;
471
472            moe_layers.push(Qwen3MoeLayerState { router, experts });
473        }
474
475        let final_norm_w = loader.load_tensor("model.norm.weight")?;
476        let lm_head = if loader.has_tensor("lm_head.weight") {
477            loader.load_linear("lm_head")?
478        } else {
479            // Tied embeddings — same as dense path.
480            tracing::info!(
481                "Qwen3MoeModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
482            );
483            loader.load_linear("model.embed_tokens")?
484        };
485
486        let runtime_cfg = cfg.base.to_runtime();
487        Ok(Self {
488            cfg,
489            runtime_cfg,
490            embed,
491            attn_layers,
492            moe_layers,
493            final_norm_w,
494            lm_head,
495            rope,
496            scratch,
497            kv_caches: HashMap::new(),
498            kv_free_pool: Vec::new(),
499            paged_pools: None,
500            paged_block_alloc: None,
501        })
502    }
503
504    pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
505        if self.scratch.max_tokens < tokens {
506            {
507                let mut ctx = B::new_context();
508                B::reset_graph(&mut ctx);
509            }
510            self.scratch = Qwen3MoeScratch::alloc(&self.cfg, tokens);
511        }
512    }
513
514    pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
515        if self.kv_caches.contains_key(cache_id) {
516            return;
517        }
518        let nkv = self.cfg.base.num_kv_heads;
519        let hd = self.cfg.base.head_dim;
520        // See `LlamaFamilyModel::ensure_kv` for the rationale on the 4096
521        // chat-friendly default and how `FERRUM_KV_CAPACITY` overrides it.
522        let model_max = self.cfg.base.max_seq_len;
523        const DEFAULT_KV_CAPACITY: usize = 4096;
524        let max = std::env::var("FERRUM_KV_CAPACITY")
525            .ok()
526            .and_then(|s| s.parse::<usize>().ok())
527            .map(|cap| cap.min(model_max))
528            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY));
529
530        // Paged-KV mode: `FERRUM_METAL_PAGED_KV=1` switches caches into
531        // block-table-indirect layout. Mirrors LlamaFamilyModel's path so
532        // the existing `paged_decode_attention` Metal kernel can fire
533        // once at num_seqs=m for batched decode (replacing the per-item
534        // attention loop that currently dominates `attn_peritem` in the
535        // c=16 profile).
536        let paged = std::env::var("FERRUM_METAL_PAGED_KV")
537            .map(|v| v == "1")
538            .unwrap_or(false);
539        const PAGED_BLOCK_SIZE: usize = 16;
540
541        let max_seqs = std::env::var("FERRUM_PAGED_MAX_SEQS")
542            .ok()
543            .and_then(|s| s.parse::<usize>().ok())
544            .unwrap_or(16);
545        let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
546        let total_pool_blocks = max_seqs * max_blocks_per_seq;
547
548        // Lazy-allocate the shared paged pools on the first paged
549        // ensure_kv call.
550        if paged && self.paged_pools.is_none() {
551            let mut pools = Vec::with_capacity(self.cfg.base.num_layers);
552            for _ in 0..self.cfg.base.num_layers {
553                let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
554                pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
555            }
556            self.paged_pools = Some(pools);
557            self.paged_block_alloc = Some(std::sync::Mutex::new(
558                crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
559            ));
560        }
561        if paged {
562            self.scratch
563                .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
564        }
565
566        let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
567            (0..self.cfg.base.num_layers)
568                .map(|_| {
569                    if paged {
570                        // Paged mode: cache holds metadata only. K/V are
571                        // 1-element placeholders. Real data lives in
572                        // `self.paged_pools[li].{k,v}`.
573                        let mut block_table = B::alloc_u32(max_blocks_per_seq);
574                        let _ = &mut block_table; // suppress unused-mut on backends that no-op write_u32
575                        let mut context_lens = B::alloc_u32(1);
576                        let mut bt_ctx = B::new_context();
577                        B::write_u32(&mut bt_ctx, &mut context_lens, &[0u32]);
578                        B::sync(&mut bt_ctx);
579                        KvCache {
580                            k: B::alloc(1),
581                            v: B::alloc(1),
582                            len: 0,
583                            capacity: max_blocks_per_seq * PAGED_BLOCK_SIZE,
584                            num_kv_heads: nkv,
585                            head_dim: hd,
586                            block_size: PAGED_BLOCK_SIZE,
587                            block_table: Some(block_table),
588                            context_lens: Some(context_lens),
589                            paged_block_indices: Vec::new(),
590                        }
591                    } else {
592                        KvCache {
593                            k: B::alloc(nkv * max * hd),
594                            v: B::alloc(nkv * max * hd),
595                            len: 0,
596                            capacity: max,
597                            num_kv_heads: nkv,
598                            head_dim: hd,
599                            block_size: 0,
600                            block_table: None,
601                            context_lens: None,
602                            paged_block_indices: Vec::new(),
603                        }
604                    }
605                })
606                .collect()
607        });
608
609        // Allocate physical blocks for THIS cache_id from the shared pool.
610        if paged {
611            let alloc_arc = self
612                .paged_block_alloc
613                .as_ref()
614                .expect("paged_block_alloc must be initialised when paged=true");
615            let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
616            let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
617                Ok(idx) => idx,
618                Err(e) => {
619                    drop(alloc);
620                    self.kv_free_pool.push(caches);
621                    eprintln!(
622                        "[ferrum] paged KV pool exhausted on ensure_kv for \
623                         cache_id={cache_id:?}: {e}. Increase \
624                         FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
625                         throttle concurrent requests.",
626                    );
627                    return;
628                }
629            };
630            let mut padded = block_indices.clone();
631            padded.resize(max_blocks_per_seq, 0);
632            let mut ctx_tmp = B::new_context();
633            for c in caches.iter_mut() {
634                if let Some(bt) = c.block_table.as_mut() {
635                    B::write_u32(&mut ctx_tmp, bt, &padded);
636                }
637                c.paged_block_indices = block_indices.clone();
638            }
639            B::sync(&mut ctx_tmp);
640        }
641
642        for c in caches.iter_mut() {
643            c.len = 0;
644            if let Some(cl) = c.context_lens.as_mut() {
645                let mut ctx_tmp = B::new_context();
646                B::write_u32(&mut ctx_tmp, cl, &[0u32]);
647                B::sync(&mut ctx_tmp);
648            }
649        }
650        self.kv_caches.insert(cache_id.to_string(), caches);
651    }
652
653    /// Run one full transformer layer (attention + MoE FFN).
654    pub(crate) fn forward_layer(
655        &mut self,
656        ctx: &mut B::Context,
657        li: usize,
658        cache_id: &str,
659        residual: &mut B::Buffer,
660        pos_offset: usize,
661        tokens: usize,
662        // If `Some(idx)` and we land on the decode fast path, fold the
663        // next layer's leading rms_norm into this layer's MoE tail
664        // (cross-layer norm fusion). The next layer's caller must pass
665        // `prev_did_norm_fusion = true` so it skips its own rms_norm.
666        next_layer_idx: Option<usize>,
667        // If `true`, skip step 1's input rms_norm — the previous
668        // layer's tail already populated `scratch.norm_out`.
669        prev_did_norm_fusion: bool,
670    ) -> Result<bool> {
671        let cfg_base = &self.cfg.base;
672        let h = cfg_base.hidden_size;
673        let nh = cfg_base.num_heads;
674        let nkv = cfg_base.num_kv_heads;
675        let hd = cfg_base.head_dim;
676        let eps = cfg_base.rms_norm_eps;
677        let q_dim = nh * hd;
678        let kv_dim = nkv * hd;
679        let attn_layer = &self.attn_layers[li];
680        let moe_layer = &self.moe_layers[li];
681
682        let attn_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
683            B::sync(ctx);
684            Some(std::time::Instant::now())
685        } else {
686            None
687        };
688
689        // 1. Input RMSNorm — skipped when the previous layer's MoE tail
690        //    fused this norm via `weighted_sum_residual_norm_stacked`.
691        if !prev_did_norm_fusion {
692            B::rms_norm(
693                ctx,
694                residual,
695                &attn_layer.input_ln_w,
696                eps,
697                &mut self.scratch.norm_out,
698                tokens,
699                h,
700            );
701        }
702
703        // 2. Fused QKV
704        attn_layer.qkv_proj.forward(
705            ctx,
706            &self.scratch.norm_out,
707            &mut self.scratch.qkv_out,
708            tokens,
709        );
710
711        // 3-4. Fused split-QKV + QK-norm + RoPE + head-major transpose.
712        //
713        // One Metal dispatch replaces (split_qkv → 3× qk_norm_rope), the
714        // four-launch chain that used to dominate the attention prelude.
715        // Reads qkv_out once, writes head-major Q/K (norm+RoPE) and V
716        // (transpose only) directly into attention scratch. Saves 3
717        // dispatches per layer (×48 = 144 dispatches per decode token).
718        //
719        // CPU and other backends without the fused kernel return
720        // Unsupported and we fall through to the original four-launch
721        // path. q_buf / k_buf / v_buf stay in scratch because that path
722        // and the per-expert MoE fallback still want them.
723        let qk_mode: i32 = if cfg_base.has_qk_norm { 1 } else { 2 };
724        let dummy = &attn_layer.input_ln_w;
725        let q_norm_w = attn_layer.q_norm_w.as_ref().unwrap_or(dummy);
726        let k_norm_w = attn_layer.k_norm_w.as_ref().unwrap_or(dummy);
727
728        // 5. Grab the per-layer KV cache up front — the deepest fused
729        //    variant writes K/V straight into it, avoiding a trailing
730        //    `kv_cache_append_head_major` dispatch.
731        //
732        // Paged mode: extract a raw pointer to the layer's pool buffers
733        // BEFORE the &mut cache borrow, so we can pass &mut to the
734        // paged kernel below without holding two simultaneous mutable
735        // borrows on `self`. Safety: `paged_pools` is allocated once at
736        // first ensure_kv call and never resized; the only concurrent
737        // mutation is the pool's own kernel writes (sequenced via
738        // command buffers), so the raw pointer remains valid for the
739        // duration of this layer call.
740        let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
741            if let Some(pools) = self.paged_pools.as_mut() {
742                let pool = &mut pools[li];
743                Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
744            } else {
745                None
746            };
747        let caches = self
748            .kv_caches
749            .get_mut(cache_id)
750            .expect("ensure_kv must be called before forward_layer");
751        let cache = &mut caches[li];
752        let cache_len_before = cache.len;
753        let cache_capacity = cache.capacity;
754
755        // Defense in depth: refuse to write past the KV buffer. Silent
756        // overflow has visible failure modes (garbage output, stale token
757        // attention, slowdowns from reading uninitialised memory). The
758        // graceful path is the caller pre-checking via `kv_capacity()` and
759        // either compacting or refusing the request; this panic only
760        // fires when that contract is broken.
761        if cache_len_before + tokens > cache_capacity {
762            panic!(
763                "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
764                cache_len_before + tokens
765            );
766        }
767
768        // Try the deepest fusion: fused split-QKV-norm-rope that writes
769        // K/V directly into the cache slot. Paged mode writes into the
770        // shared pool via block_table indirection; contiguous mode
771        // writes into the per-cache_id k/v buffers directly.
772        let used_qkv_into_cache = if cache.block_size > 0 {
773            // Paged path.
774            let bt = cache
775                .block_table
776                .as_ref()
777                .expect("paged cache missing block_table");
778            let num_blocks_per_seq = cache.capacity / cache.block_size;
779            let (pool_k_ptr, pool_v_ptr) =
780                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
781            // SAFETY: pools allocated-once, see paged_pool_ptr setup above.
782            let pool_k = unsafe { &mut *pool_k_ptr };
783            let pool_v = unsafe { &mut *pool_v_ptr };
784            B::split_qkv_norm_rope_into_paged_cache(
785                ctx,
786                &self.scratch.qkv_out,
787                0,
788                q_norm_w,
789                k_norm_w,
790                &self.rope.cos,
791                &self.rope.sin,
792                &mut self.scratch.q_head_major,
793                0,
794                pool_k,
795                pool_v,
796                bt,
797                tokens,
798                nh,
799                nkv,
800                hd,
801                pos_offset,
802                eps,
803                qk_mode,
804                cache_len_before,
805                cache.block_size,
806                num_blocks_per_seq,
807            )
808            .is_ok()
809        } else {
810            B::split_qkv_norm_rope_into_cache(
811                ctx,
812                &self.scratch.qkv_out,
813                q_norm_w,
814                k_norm_w,
815                &self.rope.cos,
816                &self.rope.sin,
817                &mut self.scratch.q_head_major,
818                &mut cache.k,
819                &mut cache.v,
820                tokens,
821                nh,
822                nkv,
823                hd,
824                pos_offset,
825                eps,
826                qk_mode,
827                cache_len_before,
828                cache_capacity,
829            )
830            .is_ok()
831        };
832        if !used_qkv_into_cache {
833            // Fallback 1: fused split-QKV-norm-rope to head-major scratch
834            // (Metal pre-decode-fusion path), then explicit cache append.
835            let used_fused_qkv = B::split_qkv_norm_rope(
836                ctx,
837                &self.scratch.qkv_out,
838                q_norm_w,
839                k_norm_w,
840                &self.rope.cos,
841                &self.rope.sin,
842                &mut self.scratch.q_head_major,
843                &mut self.scratch.k_head_major,
844                &mut self.scratch.v_head_major,
845                tokens,
846                nh,
847                nkv,
848                hd,
849                pos_offset,
850                eps,
851                qk_mode,
852            )
853            .is_ok();
854            if !used_fused_qkv {
855                // Fallback 2: original four-launch chain.
856                B::split_qkv(
857                    ctx,
858                    &self.scratch.qkv_out,
859                    &mut self.scratch.q_buf,
860                    &mut self.scratch.k_buf,
861                    &mut self.scratch.v_buf,
862                    tokens,
863                    q_dim,
864                    kv_dim,
865                );
866                B::qk_norm_rope(
867                    ctx,
868                    &self.scratch.q_buf,
869                    q_norm_w,
870                    &self.rope.cos,
871                    &self.rope.sin,
872                    &mut self.scratch.q_head_major,
873                    tokens,
874                    nh,
875                    hd,
876                    pos_offset,
877                    eps,
878                    qk_mode,
879                );
880                B::qk_norm_rope(
881                    ctx,
882                    &self.scratch.k_buf,
883                    k_norm_w,
884                    &self.rope.cos,
885                    &self.rope.sin,
886                    &mut self.scratch.k_head_major,
887                    tokens,
888                    nkv,
889                    hd,
890                    pos_offset,
891                    eps,
892                    qk_mode,
893                );
894                B::qk_norm_rope(
895                    ctx,
896                    &self.scratch.v_buf,
897                    dummy,
898                    &self.rope.cos,
899                    &self.rope.sin,
900                    &mut self.scratch.v_head_major,
901                    tokens,
902                    nkv,
903                    hd,
904                    pos_offset,
905                    eps,
906                    0,
907                );
908            }
909            B::kv_cache_append_head_major(
910                ctx,
911                &mut cache.k,
912                &mut cache.v,
913                cache.len,
914                cache.capacity,
915                &self.scratch.k_head_major,
916                &self.scratch.v_head_major,
917                tokens,
918                nkv,
919                hd,
920            );
921        }
922        cache.len += tokens;
923        let kv_len = cache.len;
924        let kv_stride = cache.capacity;
925
926        if cache.block_size > 0 {
927            // Paged decode: read from the shared pool via block_table.
928            let bt = cache
929                .block_table
930                .as_ref()
931                .expect("paged cache missing block_table");
932            let cl_buf = cache
933                .context_lens
934                .as_mut()
935                .expect("paged cache missing context_lens");
936            let num_blocks_per_seq = cache.capacity / cache.block_size;
937            let (pool_k_ptr, pool_v_ptr) =
938                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
939            // SAFETY: see paged_pool_ptr setup above.
940            let pool_k = unsafe { &*pool_k_ptr };
941            let pool_v = unsafe { &*pool_v_ptr };
942            let final_kv_len = cache.len as u32;
943            B::write_u32(ctx, cl_buf, &[final_kv_len]);
944            B::paged_decode_attention(
945                ctx,
946                &self.scratch.q_head_major,
947                pool_k,
948                pool_v,
949                &mut self.scratch.attn_head_major_out,
950                bt,
951                cl_buf,
952                1, // num_seqs (single-seq m=1 path)
953                nh,
954                nkv,
955                hd,
956                cache.block_size,
957                num_blocks_per_seq,
958                tokens,
959            )
960            .expect("paged_decode_attention");
961            let _ = kv_stride; // consumed by contig path only
962        } else {
963            let attn_cfg = ferrum_kernels::backend::AttnConfig {
964                num_heads: nh,
965                num_kv_heads: nkv,
966                head_dim: hd,
967                causal: true,
968                scale: 1.0 / (hd as f32).sqrt(),
969                kv_seq_stride: kv_stride,
970                sliding_window: cfg_base.sliding_window,
971            };
972            B::flash_attention(
973                ctx,
974                &self.scratch.q_head_major,
975                &cache.k,
976                &cache.v,
977                &mut self.scratch.attn_head_major_out,
978                1,
979                tokens,
980                kv_len,
981                pos_offset,
982                &attn_cfg,
983            );
984        }
985
986        if let Some(t0) = attn_t0 {
987            B::sync(ctx);
988            ATTN_TIME_US.fetch_add(
989                t0.elapsed().as_micros() as u64,
990                std::sync::atomic::Ordering::Relaxed,
991            );
992            ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
993        }
994
995        // 7. transpose head-major → token-major.
996        //
997        // For tokens=1 the two layouts are byte-identical: both
998        // collapse to the flat [heads * head_dim] vector at offset
999        // `head*hd + d`. Skip the dispatch and point o_proj at
1000        // attn_head_major_out directly. Saves 1 dispatch per layer
1001        // (×48 = 48 dispatches per decode token) on Qwen3-30B-A3B.
1002        let attn_token_major = if tokens == 1 {
1003            &self.scratch.attn_head_major_out
1004        } else {
1005            B::transpose_head_to_token(
1006                ctx,
1007                &self.scratch.attn_head_major_out,
1008                &mut self.scratch.attn_flat,
1009                tokens,
1010                nh,
1011                hd,
1012            );
1013            &self.scratch.attn_flat
1014        };
1015
1016        // 8. O-proj.
1017        attn_layer
1018            .o_proj
1019            .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1020
1021        // 9. fused residual-add + post-attention RMSNorm.
1022        B::fused_add_rms_norm(
1023            ctx,
1024            residual,
1025            &self.scratch.o_proj_out,
1026            &attn_layer.post_ln_w,
1027            eps,
1028            &mut self.scratch.norm_out,
1029            tokens,
1030            h,
1031        );
1032
1033        // ── MoE FFN block ────────────────────────────────────────────
1034        let moe_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1035            B::sync(ctx);
1036            Some(std::time::Instant::now())
1037        } else {
1038            None
1039        };
1040
1041        // 10. Router gemv: norm_out [tokens, hidden] → router_logits [tokens, num_experts]
1042        moe_layer.router.forward(
1043            ctx,
1044            &self.scratch.norm_out,
1045            &mut self.scratch.router_logits,
1046            tokens,
1047        );
1048
1049        // 11. Per-(token, expert) MLP dispatch + weighted combine.
1050        //
1051        // Two paths:
1052        //   - **Batched fast path** (decode m=1, all stacked variants
1053        //     present): single `gemv_quant_moe_id` dispatch covers all
1054        //     8 selected expert × 1 token gate gemvs in parallel; same
1055        //     for up and down. Cuts per-layer expert dispatches from
1056        //     ~32 (8 × 4 ops/pair) to 4 (gate + up + silu + down + 1 acc).
1057        //     Routes Qwen3-30B-A3B decode close to llama.cpp's
1058        //     `kernel_mul_mm_id`.
1059        //   - **Per-(token, expert) fallback** via `moe_forward` —
1060        //     used for prefill (m > 1), or when the backend doesn't
1061        //     populate stacked variants (CPU, synthetic-MoE tests).
1062        let stacked_path_available = moe_layer.experts.gate_stacked.is_some()
1063            && moe_layer.experts.up_stacked.is_some()
1064            && moe_layer.experts.down_stacked.is_some();
1065
1066        // Fast path for decode (tokens=1): the stacked decode impl
1067        // writes the weighted-sum result *directly* into `residual` via
1068        // `weighted_sum_residual_stacked`, skipping the moe_out scratch
1069        // and the trailing `add_inplace`. Saves 1 dispatch per layer.
1070        // Prefill (m>1) and the per-expert fallback still go through
1071        // moe_out + add_inplace.
1072        let decode_fast_path = stacked_path_available && tokens == 1;
1073        // Cross-layer fusion: when on the decode fast path AND there is
1074        // a next layer, fold its leading rms_norm into this layer's
1075        // tail (`weighted_sum_residual_norm_stacked`). Returns whether
1076        // the fusion ran so the caller can signal the next layer to
1077        // skip its standalone rms_norm.
1078        let did_norm_fusion = decode_fast_path && next_layer_idx.is_some();
1079
1080        if stacked_path_available {
1081            if tokens > 1 {
1082                // Prefill: one batched 2-D mul_mm_id covers all
1083                // (token, expert) pairs in parallel.
1084                self.moe_forward_batched_prefill(ctx, li, tokens)?;
1085            } else {
1086                // Decode m=1: dedicated per-token path that fuses
1087                // residual-add into the final weighted-sum, and
1088                // optionally folds the next layer's rms_norm in too.
1089                self.moe_forward_stacked(ctx, li, tokens, residual, next_layer_idx)?;
1090            }
1091        } else {
1092            moe_forward::<B>(
1093                ctx,
1094                &self.scratch.norm_out,
1095                &self.scratch.router_logits,
1096                &mut self.scratch.moe_out,
1097                tokens,
1098                h,
1099                self.cfg.expert_intermediate_size,
1100                self.cfg.num_experts,
1101                self.cfg.num_experts_per_tok,
1102                self.cfg.norm_topk_prob,
1103                &moe_layer.experts,
1104                &mut self.scratch.x_single,
1105                &mut self.scratch.acc_buf,
1106                &mut self.scratch.gate_up_buf,
1107                &mut self.scratch.silu_buf,
1108                &mut self.scratch.down_buf,
1109                &self.scratch.zero_hidden,
1110            )?;
1111        }
1112
1113        // 12. residual += moe_out (skipped on decode fast path — already
1114        //     accumulated by `weighted_sum_residual_stacked`).
1115        if !decode_fast_path {
1116            B::add_inplace(ctx, residual, &self.scratch.moe_out, tokens * h);
1117        }
1118
1119        if let Some(t0) = moe_t0 {
1120            B::sync(ctx);
1121            MOE_TIME_US.fetch_add(
1122                t0.elapsed().as_micros() as u64,
1123                std::sync::atomic::Ordering::Relaxed,
1124            );
1125            MOE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1126        }
1127
1128        Ok(did_norm_fusion)
1129    }
1130
1131    fn moe_forward_stacked(
1132        &mut self,
1133        ctx: &mut B::Context,
1134        li: usize,
1135        tokens: usize,
1136        residual: &mut B::Buffer,
1137        next_layer_idx: Option<usize>,
1138    ) -> Result<()> {
1139        let cfg = &self.cfg;
1140        // `next_norm_w` is the next layer's `attn_layer.input_ln_w`.
1141        // We can't borrow `self.attn_layers[idx]` and pass &mut
1142        // self.scratch to the impl simultaneously, so collect the raw
1143        // pointer here. Safety: forward_layer holds &mut self for the
1144        // call; the borrow scopes are fully sequential.
1145        let next_norm_w_ptr: Option<*const B::Buffer> =
1146            next_layer_idx.map(|idx| &self.attn_layers[idx].input_ln_w as *const _);
1147        // SAFETY: pointer dereference is valid because:
1148        //   * The buffer lives in `self.attn_layers[idx]` which we
1149        //     borrowed immutably to take the pointer. We do not mutate
1150        //     `self.attn_layers` while `next_norm_w_ptr` is in use.
1151        //   * `&mut self.scratch` and `&self.moe_layers[li]` are disjoint
1152        //     fields from `self.attn_layers` so this is safe.
1153        let next_norm_w: Option<&B::Buffer> = next_norm_w_ptr.map(|p| unsafe { &*p });
1154        moe_forward_stacked_decode_impl::<B>(
1155            ctx,
1156            &self.moe_layers[li],
1157            &mut self.scratch,
1158            cfg.base.hidden_size,
1159            cfg.expert_intermediate_size,
1160            cfg.num_experts_per_tok,
1161            cfg.num_experts,
1162            cfg.norm_topk_prob,
1163            tokens,
1164            residual,
1165            next_norm_w,
1166            cfg.base.rms_norm_eps,
1167        )
1168    }
1169
1170    fn moe_forward_batched_prefill(
1171        &mut self,
1172        ctx: &mut B::Context,
1173        li: usize,
1174        tokens: usize,
1175    ) -> Result<()> {
1176        let cfg = &self.cfg;
1177        moe_forward_batched_prefill_impl::<B>(
1178            ctx,
1179            &self.moe_layers[li],
1180            &mut self.scratch,
1181            cfg.base.hidden_size,
1182            cfg.expert_intermediate_size,
1183            cfg.num_experts_per_tok,
1184            cfg.num_experts,
1185            cfg.norm_topk_prob,
1186            tokens,
1187        )
1188    }
1189
1190    /// Prefill: process `tokens` prompt tokens, return last-token logits.
1191    pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1192        let seq_len = tokens.len();
1193        assert!(seq_len > 0);
1194        self.ensure_scratch(seq_len);
1195        self.ensure_kv(cache_id);
1196
1197        let pos_offset = self
1198            .kv_caches
1199            .get(cache_id)
1200            .and_then(|layers| layers.first())
1201            .map(|c| c.len)
1202            .unwrap_or(0);
1203
1204        let h = self.cfg.base.hidden_size;
1205        let vocab = self.cfg.base.vocab_size;
1206        let mut ctx = B::new_context();
1207
1208        // FERRUM_DECODE_OP_PROFILE doubles as the prefill-profile gate
1209        // for Qwen3-MoE: when set, dump (attn-us, moe-us, total-us) at
1210        // the end of prefill so we can attribute the prefill bottleneck
1211        // between attention and MoE.
1212        let prefill_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1213            B::sync(&mut ctx);
1214            for c in [
1215                &ATTN_TIME_US,
1216                &ATTN_CALLS,
1217                &MOE_TIME_US,
1218                &MOE_CALLS,
1219                &MOE_PREFILL_HOST_TOPK_US,
1220                &MOE_PREFILL_HOST_TOPK_CALLS,
1221                &MOE_PREFILL_GATE_US,
1222                &MOE_PREFILL_GATE_CALLS,
1223                &MOE_PREFILL_UP_US,
1224                &MOE_PREFILL_UP_CALLS,
1225                &MOE_PREFILL_SILU_US,
1226                &MOE_PREFILL_SILU_CALLS,
1227                &MOE_PREFILL_DOWN_US,
1228                &MOE_PREFILL_DOWN_CALLS,
1229                &MOE_PREFILL_WSUM_US,
1230                &MOE_PREFILL_WSUM_CALLS,
1231            ] {
1232                c.store(0, std::sync::atomic::Ordering::Relaxed);
1233            }
1234            Some(std::time::Instant::now())
1235        } else {
1236            None
1237        };
1238
1239        let mut residual = self
1240            .scratch
1241            .residual
1242            .take()
1243            .expect("scratch residual missing (previous call didn't restore)");
1244        B::embedding_lookup(&mut ctx, &self.embed, tokens, &mut residual, h);
1245
1246        // For prefill (seq_len > 1) the cross-layer norm fusion does
1247        // not apply (it lives on the decode fast path). We still pass
1248        // `next_layer_idx = None` so forward_layer emits the regular
1249        // tail.
1250        let mut prev_did_norm_fusion = false;
1251        let num_layers = self.cfg.base.num_layers;
1252        for li in 0..num_layers {
1253            let next_layer_idx = if li + 1 < num_layers {
1254                Some(li + 1)
1255            } else {
1256                None
1257            };
1258            prev_did_norm_fusion = self
1259                .forward_layer(
1260                    &mut ctx,
1261                    li,
1262                    cache_id,
1263                    &mut residual,
1264                    pos_offset,
1265                    seq_len,
1266                    next_layer_idx,
1267                    prev_did_norm_fusion,
1268                )
1269                .expect("forward_layer");
1270        }
1271
1272        // Last-token slice → final RMSNorm → lm_head.
1273        B::copy_slice(
1274            &mut ctx,
1275            &residual,
1276            (seq_len - 1) * h,
1277            &mut self.scratch.last_hidden,
1278            0,
1279            h,
1280        );
1281        B::rms_norm(
1282            &mut ctx,
1283            &self.scratch.last_hidden,
1284            &self.final_norm_w,
1285            self.cfg.base.rms_norm_eps,
1286            &mut self.scratch.last_normed,
1287            1,
1288            h,
1289        );
1290        self.lm_head.forward(
1291            &mut ctx,
1292            &self.scratch.last_normed,
1293            &mut self.scratch.logits,
1294            1,
1295        );
1296
1297        B::sync(&mut ctx);
1298        if let Some(t0) = prefill_t0 {
1299            let total_us = t0.elapsed().as_micros() as u64;
1300            let attn_us = ATTN_TIME_US.load(std::sync::atomic::Ordering::Relaxed);
1301            let attn_n = ATTN_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1302            let moe_us = MOE_TIME_US.load(std::sync::atomic::Ordering::Relaxed);
1303            let moe_n = MOE_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1304            let other_us = total_us.saturating_sub(attn_us).saturating_sub(moe_us);
1305            eprintln!(
1306                "[prefill-profile] tokens={seq_len} total={} ms ({:.0} t/s)",
1307                total_us / 1000,
1308                seq_len as f64 * 1e6 / total_us as f64
1309            );
1310            let bucket = |label: &str, n: u64, us: u64| {
1311                if n > 0 {
1312                    eprintln!(
1313                        "  {label:>6}: {:7} ms ({:5.1}%) over {n:4} calls",
1314                        us / 1000,
1315                        us as f64 * 100.0 / total_us as f64
1316                    );
1317                }
1318            };
1319            bucket("attn", attn_n, attn_us);
1320            bucket("moe", moe_n, moe_us);
1321            bucket("other", 1, other_us);
1322            // MoE sub-stages — show as % of total prefill time so they
1323            // reconcile against the `moe` bucket above.
1324            let host_us = MOE_PREFILL_HOST_TOPK_US.load(std::sync::atomic::Ordering::Relaxed);
1325            let gate_us = MOE_PREFILL_GATE_US.load(std::sync::atomic::Ordering::Relaxed);
1326            let up_us = MOE_PREFILL_UP_US.load(std::sync::atomic::Ordering::Relaxed);
1327            let silu_us = MOE_PREFILL_SILU_US.load(std::sync::atomic::Ordering::Relaxed);
1328            let down_us = MOE_PREFILL_DOWN_US.load(std::sync::atomic::Ordering::Relaxed);
1329            let wsum_us = MOE_PREFILL_WSUM_US.load(std::sync::atomic::Ordering::Relaxed);
1330            let host_n = MOE_PREFILL_HOST_TOPK_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1331            let gate_n = MOE_PREFILL_GATE_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1332            let up_n = MOE_PREFILL_UP_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1333            let silu_n = MOE_PREFILL_SILU_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1334            let down_n = MOE_PREFILL_DOWN_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1335            let wsum_n = MOE_PREFILL_WSUM_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1336            bucket("  host", host_n, host_us);
1337            bucket("  gate", gate_n, gate_us);
1338            bucket("  up", up_n, up_us);
1339            bucket("  silu", silu_n, silu_us);
1340            bucket("  down", down_n, down_us);
1341            bucket("  wsum", wsum_n, wsum_us);
1342        }
1343        self.scratch.residual = Some(residual);
1344        B::to_vec(&self.scratch.logits, vocab)
1345    }
1346
1347    /// Decode: 1 token at position `pos`, return next-step logits.
1348    pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1349        self.ensure_scratch(1);
1350        self.ensure_kv(cache_id);
1351
1352        let h = self.cfg.base.hidden_size;
1353        let vocab = self.cfg.base.vocab_size;
1354        let mut ctx = B::new_context();
1355
1356        let decode_t0 = if std::env::var("FERRUM_MOE_PROFILE").is_ok() {
1357            Some(std::time::Instant::now())
1358        } else {
1359            None
1360        };
1361
1362        // FERRUM_DECODE_OP_PROFILE gates the per-stage breakdown emitted
1363        // at the bottom of every decode token. Reuses the same atomic
1364        // counters that `forward_layer` already populates (ATTN_TIME_US,
1365        // MOE_TIME_US — drained here per-token instead of per-prefill).
1366        let stage_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1367            B::sync(&mut ctx);
1368            for c in [
1369                &ATTN_TIME_US,
1370                &ATTN_CALLS,
1371                &MOE_TIME_US,
1372                &MOE_CALLS,
1373                &DEC_ROUTE_US,
1374                &DEC_GATE_US,
1375                &DEC_UP_US,
1376                &DEC_SILU_US,
1377                &DEC_DOWN_US,
1378                &DEC_WSUM_US,
1379                &DEC_EMBED_US,
1380                &DEC_FINAL_NORM_US,
1381                &DEC_LM_HEAD_US,
1382            ] {
1383                c.store(0, std::sync::atomic::Ordering::Relaxed);
1384            }
1385            Some(std::time::Instant::now())
1386        } else {
1387            None
1388        };
1389        let prof = stage_t0.is_some();
1390        let mark = |ctx: &mut B::Context, c: &AtomicU64, t0: std::time::Instant| {
1391            if prof {
1392                B::sync(ctx);
1393                c.fetch_add(
1394                    t0.elapsed().as_micros() as u64,
1395                    std::sync::atomic::Ordering::Relaxed,
1396                );
1397            }
1398        };
1399        let mt0 = std::time::Instant::now();
1400
1401        let mut residual = self
1402            .scratch
1403            .residual
1404            .take()
1405            .expect("scratch residual missing (previous call didn't restore)");
1406        let t0 = std::time::Instant::now();
1407        B::embedding_lookup(&mut ctx, &self.embed, &[token], &mut residual, h);
1408        mark(&mut ctx, &DEC_EMBED_US, t0);
1409        let _ = mt0; // silence if unused on non-profile builds
1410
1411        // Cross-layer rms_norm fusion: layer L's MoE tail folds the
1412        // next layer's leading rms_norm into its weighted-sum-residual
1413        // when the decode fast path applies. The flag carries forward.
1414        let mut prev_did_norm_fusion = false;
1415        let num_layers = self.cfg.base.num_layers;
1416        for li in 0..num_layers {
1417            let next_layer_idx = if li + 1 < num_layers {
1418                Some(li + 1)
1419            } else {
1420                None
1421            };
1422            prev_did_norm_fusion = self
1423                .forward_layer(
1424                    &mut ctx,
1425                    li,
1426                    cache_id,
1427                    &mut residual,
1428                    pos as usize,
1429                    1,
1430                    next_layer_idx,
1431                    prev_did_norm_fusion,
1432                )
1433                .expect("forward_layer");
1434        }
1435
1436        let t0 = std::time::Instant::now();
1437        B::rms_norm(
1438            &mut ctx,
1439            &residual,
1440            &self.final_norm_w,
1441            self.cfg.base.rms_norm_eps,
1442            &mut self.scratch.last_normed,
1443            1,
1444            h,
1445        );
1446        mark(&mut ctx, &DEC_FINAL_NORM_US, t0);
1447
1448        let t0 = std::time::Instant::now();
1449        self.lm_head.forward(
1450            &mut ctx,
1451            &self.scratch.last_normed,
1452            &mut self.scratch.logits,
1453            1,
1454        );
1455        mark(&mut ctx, &DEC_LM_HEAD_US, t0);
1456
1457        B::sync(&mut ctx);
1458        self.scratch.residual = Some(residual);
1459
1460        // FERRUM_DECODE_OP_PROFILE: per-token decode breakdown.
1461        if let Some(t0) = stage_t0 {
1462            use std::sync::atomic::Ordering;
1463            let total_us = t0.elapsed().as_micros() as u64;
1464            let attn_us = ATTN_TIME_US.swap(0, Ordering::Relaxed);
1465            let moe_us = MOE_TIME_US.swap(0, Ordering::Relaxed);
1466            let route = DEC_ROUTE_US.swap(0, Ordering::Relaxed);
1467            let gate = DEC_GATE_US.swap(0, Ordering::Relaxed);
1468            let up = DEC_UP_US.swap(0, Ordering::Relaxed);
1469            let silu = DEC_SILU_US.swap(0, Ordering::Relaxed);
1470            let down = DEC_DOWN_US.swap(0, Ordering::Relaxed);
1471            let wsum = DEC_WSUM_US.swap(0, Ordering::Relaxed);
1472            let embed = DEC_EMBED_US.swap(0, Ordering::Relaxed);
1473            let fnorm = DEC_FINAL_NORM_US.swap(0, Ordering::Relaxed);
1474            let lmhead = DEC_LM_HEAD_US.swap(0, Ordering::Relaxed);
1475            let other = total_us.saturating_sub(attn_us + moe_us + embed + fnorm + lmhead);
1476            let pct = |us: u64| -> f64 {
1477                if total_us == 0 {
1478                    0.0
1479                } else {
1480                    100.0 * us as f64 / total_us as f64
1481                }
1482            };
1483            eprintln!(
1484                "[decode-prof] total={} ms | attn={} ({:.1}%) | moe={} ({:.1}%) [route={} gate={} up={} silu={} down={} wsum={}] | embed={} fnorm={} lmhead={} other={} ({:.1}%)",
1485                total_us / 1000,
1486                attn_us / 1000, pct(attn_us),
1487                moe_us / 1000, pct(moe_us),
1488                route / 1000, gate / 1000, up / 1000, silu / 1000, down / 1000, wsum / 1000,
1489                embed / 1000, fnorm / 1000, lmhead / 1000,
1490                other / 1000, pct(other),
1491            );
1492        }
1493
1494        // Drain MoE per-op counters every decode step. The counters
1495        // accumulate across all 48 layers; printing per-step gives a
1496        // per-token breakdown.
1497        if let Some(t0) = decode_t0 {
1498            use crate::moe::dispatch::*;
1499            use std::sync::atomic::Ordering;
1500            let total_us = t0.elapsed().as_micros() as u64;
1501            let sync_us = MOE_SYNC_US.swap(0, Ordering::Relaxed);
1502            let sync_n = MOE_SYNC_CALLS.swap(0, Ordering::Relaxed);
1503            let topk_us = MOE_HOST_TOPK_US.swap(0, Ordering::Relaxed);
1504            let topk_n = MOE_HOST_TOPK_CALLS.swap(0, Ordering::Relaxed);
1505            let gu_us = MOE_GEMV_GATE_UP_US.swap(0, Ordering::Relaxed);
1506            let gu_n = MOE_GEMV_GATE_UP_CALLS.swap(0, Ordering::Relaxed);
1507            let silu_us = MOE_SILU_US.swap(0, Ordering::Relaxed);
1508            let silu_n = MOE_SILU_CALLS.swap(0, Ordering::Relaxed);
1509            let dn_us = MOE_GEMV_DOWN_US.swap(0, Ordering::Relaxed);
1510            let dn_n = MOE_GEMV_DOWN_CALLS.swap(0, Ordering::Relaxed);
1511            let sa_us = MOE_SCALED_ADD_US.swap(0, Ordering::Relaxed);
1512            let sa_n = MOE_SCALED_ADD_CALLS.swap(0, Ordering::Relaxed);
1513            let cp_us = MOE_COPY_US.swap(0, Ordering::Relaxed);
1514            let cp_n = MOE_COPY_CALLS.swap(0, Ordering::Relaxed);
1515            eprintln!(
1516                "[moe-prof] decode total={} ms | sync={} ms ({}x) | host_topk={} ms ({}x) | gate_up={} ms ({}x) | silu={} ms ({}x) | down={} ms ({}x) | scaled_add={} ms ({}x) | copy={} ms ({}x)",
1517                total_us / 1000,
1518                sync_us / 1000, sync_n,
1519                topk_us / 1000, topk_n,
1520                gu_us / 1000, gu_n,
1521                silu_us / 1000, silu_n,
1522                dn_us / 1000, dn_n,
1523                sa_us / 1000, sa_n,
1524                cp_us / 1000, cp_n,
1525            );
1526        }
1527
1528        B::to_vec(&self.scratch.logits, vocab)
1529    }
1530
1531    /// Multi-sequence batched decode (Phase 4b for MoE).
1532    ///
1533    /// Mirrors `LlamaFamilyModel::decode_batch_internal` but adapted to
1534    /// the MoE forward. The wins come from running the GEMM-heavy ops
1535    /// (qkv_proj, o_proj, router, MoE expert mul_mm_id, lm_head) at
1536    /// m=M, even though attention stays a per-item loop because
1537    /// Qwen3-MoE uses contiguous KV — no paged path here.
1538    ///
1539    /// Cross-layer rms_norm fusion (the `weighted_sum_residual_norm_stacked`
1540    /// fast path) is disabled in batched mode: the prefill MoE path
1541    /// (`moe_forward_batched_prefill_impl`) writes to `moe_out` and we
1542    /// add to residual explicitly. Each layer's leading rms_norm runs
1543    /// at m=M, which is one fused dispatch on M rows — cheap.
1544    pub fn decode_batch_internal(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1545        let m = batch.len();
1546        if m == 0 {
1547            return Vec::new();
1548        }
1549        if m == 1 {
1550            let (cid, tok, pos) = &batch[0];
1551            return vec![self.decode_internal(cid, *tok, *pos)];
1552        }
1553
1554        let prof_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1555            Some(std::time::Instant::now())
1556        } else {
1557            None
1558        };
1559
1560        for (cid, _, _) in batch {
1561            self.ensure_kv(cid);
1562        }
1563        self.ensure_scratch(m);
1564        self.scratch.enable_batched_decode_scratch(&self.cfg);
1565
1566        let h = self.cfg.base.hidden_size;
1567        let vocab = self.cfg.base.vocab_size;
1568        let mut ctx = B::new_context();
1569
1570        // 0. Embed all M tokens into residual [M, H]
1571        let tokens: Vec<u32> = batch.iter().map(|(_, t, _)| *t).collect();
1572        let mut residual = self
1573            .scratch
1574            .residual
1575            .take()
1576            .expect("scratch residual missing (previous call didn't restore)");
1577        B::embedding_lookup(&mut ctx, &self.embed, &tokens, &mut residual, h);
1578
1579        // 1..num_layers: batched forward for each layer
1580        for li in 0..self.cfg.base.num_layers {
1581            self.forward_layer_batched_decode(&mut ctx, li, batch, &mut residual, m)
1582                .expect("forward_layer_batched_decode");
1583        }
1584
1585        // Final RMSNorm on [M, H] → norm_out [M, H]
1586        B::rms_norm(
1587            &mut ctx,
1588            &residual,
1589            &self.final_norm_w,
1590            self.cfg.base.rms_norm_eps,
1591            &mut self.scratch.norm_out,
1592            m,
1593            h,
1594        );
1595
1596        // LM head with m=M → batch_logits [M, vocab]
1597        self.lm_head.forward(
1598            &mut ctx,
1599            &self.scratch.norm_out,
1600            &mut self.scratch.batch_logits,
1601            m,
1602        );
1603
1604        B::sync(&mut ctx);
1605        self.scratch.residual = Some(residual);
1606
1607        let all = B::to_vec(&self.scratch.batch_logits, m * vocab);
1608
1609        // Profile dump (one decode_batch_internal call = one decode step
1610        // covering all m tokens).
1611        if let Some(t0) = prof_t0 {
1612            use std::sync::atomic::Ordering;
1613            let total_us = t0.elapsed().as_micros() as u64;
1614            let dense = BD_DENSE_US.swap(0, Ordering::Relaxed);
1615            let attn = BD_ATTN_PERITEM_US.swap(0, Ordering::Relaxed);
1616            let moe = BD_MOE_US.swap(0, Ordering::Relaxed);
1617            let layers = BD_LAYER_CALLS.swap(0, Ordering::Relaxed);
1618            let other = total_us.saturating_sub(dense + attn + moe);
1619            let pct = |us: u64| -> f64 {
1620                if total_us == 0 {
1621                    0.0
1622                } else {
1623                    100.0 * us as f64 / total_us as f64
1624                }
1625            };
1626            // MoE sub-stage breakdown — meaningful when
1627            // moe_forward_batched_decode_impl was used.
1628            let moe_route = MOE_BATCHED_DECODE_ROUTE_US.swap(0, Ordering::Relaxed);
1629            let moe_gate = MOE_BATCHED_DECODE_GATE_US.swap(0, Ordering::Relaxed);
1630            let moe_up = MOE_BATCHED_DECODE_UP_US.swap(0, Ordering::Relaxed);
1631            let moe_silu = MOE_BATCHED_DECODE_SILU_US.swap(0, Ordering::Relaxed);
1632            let moe_down = MOE_BATCHED_DECODE_DOWN_US.swap(0, Ordering::Relaxed);
1633            let moe_wsum = MOE_BATCHED_DECODE_WSUM_US.swap(0, Ordering::Relaxed);
1634            eprintln!(
1635                "[batched-decode-prof] m={} layers={} total={} ms | dense={} ({:.1}%) | attn_peritem={} ({:.1}%) | moe={} ({:.1}%) [route={} gate={} up={} silu={} down={} wsum={}] | other={} ({:.1}%)",
1636                m, layers, total_us / 1000,
1637                dense / 1000, pct(dense),
1638                attn / 1000, pct(attn),
1639                moe / 1000, pct(moe),
1640                moe_route / 1000, moe_gate / 1000, moe_up / 1000,
1641                moe_silu / 1000, moe_down / 1000, moe_wsum / 1000,
1642                other / 1000, pct(other),
1643            );
1644        }
1645
1646        (0..m)
1647            .map(|i| all[i * vocab..(i + 1) * vocab].to_vec())
1648            .collect()
1649    }
1650
1651    /// One transformer layer over M items: GEMMs at m=M, per-item
1652    /// attention loop, MoE FFN at m=M via the prefill batched path.
1653    /// Mirrors `LlamaFamilyModel::forward_layer_batched_decode` minus
1654    /// the paged branch.
1655    fn forward_layer_batched_decode(
1656        &mut self,
1657        ctx: &mut B::Context,
1658        li: usize,
1659        batch: &[(String, u32, u32)],
1660        residual: &mut B::Buffer,
1661        m: usize,
1662    ) -> Result<()> {
1663        let cfg_base = &self.cfg.base;
1664        let h = cfg_base.hidden_size;
1665        let nh = cfg_base.num_heads;
1666        let nkv = cfg_base.num_kv_heads;
1667        let hd = cfg_base.head_dim;
1668        let eps = cfg_base.rms_norm_eps;
1669        let q_dim = nh * hd;
1670        let kv_dim = nkv * hd;
1671
1672        let attn_layer = &self.attn_layers[li];
1673        let qk_mode: i32 = if cfg_base.has_qk_norm { 1 } else { 2 };
1674        let dummy_w = &attn_layer.input_ln_w;
1675        let q_norm_w = attn_layer.q_norm_w.as_ref().unwrap_or(dummy_w);
1676        let k_norm_w = attn_layer.k_norm_w.as_ref().unwrap_or(dummy_w);
1677
1678        let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
1679        let stage_t0 = || -> Option<std::time::Instant> {
1680            if prof {
1681                Some(std::time::Instant::now())
1682            } else {
1683                None
1684            }
1685        };
1686        let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
1687            if let Some(t) = t0 {
1688                B::sync(ctx);
1689                c.fetch_add(
1690                    t.elapsed().as_micros() as u64,
1691                    std::sync::atomic::Ordering::Relaxed,
1692                );
1693            }
1694        };
1695        if prof {
1696            BD_LAYER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1697        }
1698
1699        let dense_t0 = stage_t0();
1700
1701        // 1. rms_norm [M, H] → norm_out
1702        B::rms_norm(
1703            ctx,
1704            residual,
1705            &attn_layer.input_ln_w,
1706            eps,
1707            &mut self.scratch.norm_out,
1708            m,
1709            h,
1710        );
1711
1712        // 2. qkv_proj GEMM at m=M: norm_out [M, H] → qkv_out [M, QKV]
1713        attn_layer
1714            .qkv_proj
1715            .forward(ctx, &self.scratch.norm_out, &mut self.scratch.qkv_out, m);
1716
1717        // ── Paged batched attention path ───────────────────────────────
1718        //
1719        // Mirrors LlamaFamilyModel's Phase 4b paged batched-decode. When
1720        // `FERRUM_METAL_PAGED_KV=1` was set at ensure_kv time, each
1721        // cache_id has paged metadata (block_table + context_lens) and
1722        // K/V live in the shared `paged_pools[layer]` pool. This path:
1723        //   1. m × `split_qkv_norm_rope_into_paged_cache` writes K/V into
1724        //      the pool at each item's allocated blocks AND fills
1725        //      `paged_batch_q[i*q_dim ..]` with that item's head-major Q.
1726        //   2. Build `paged_batch_block_tables [m, max_blocks_per_seq]`
1727        //      and `paged_batch_context_lens [m]` host-side, upload.
1728        //   3. ONE `paged_decode_attention(num_seqs=m)` call reads all m
1729        //      sequences' K/V from the pool via per-seq block_tables,
1730        //      writes outputs to `paged_batch_o [m, q_dim]`.
1731        //   4. Per-item copy_slice paged_batch_o[i] → attn_flat[i*q_dim].
1732        //
1733        // This is the structural fix for the c=16 attn_peritem cliff
1734        // (~55 ms / round of 16 sequential m=1 flash_attn + plumbing).
1735        let is_paged = self.paged_pools.is_some();
1736        if is_paged {
1737            stage_end(dense_t0, ctx, &BD_DENSE_US);
1738            let attn_t0 = stage_t0();
1739
1740            let max_blocks_per_seq = self.scratch.paged_max_blocks_per_seq;
1741            let block_size = 16; // matches PAGED_BLOCK_SIZE in ensure_kv
1742            let qkv_stride = q_dim + 2 * kv_dim;
1743
1744            // Step 1: per-item paged write. Read each item's qkv out of
1745            // the batched qkv_out buffer; write its head-major Q into
1746            // paged_batch_q[i*q_dim ..]; write its K/V into the pool at
1747            // its allocated blocks via block_table.
1748            let q_head_major_size_bytes = (q_dim * std::mem::size_of::<f32>()) as u64;
1749            let qkv_stride_bytes = (qkv_stride * std::mem::size_of::<f32>()) as u64;
1750            let pool_ptr = {
1751                let pools = self.paged_pools.as_mut().unwrap();
1752                (
1753                    &mut pools[li].0 as *mut B::Buffer,
1754                    &mut pools[li].1 as *mut B::Buffer,
1755                )
1756            };
1757            // SAFETY: pools allocated-once, see paged_pools field comment.
1758            let (pool_k, pool_v) = unsafe { (&mut *pool_ptr.0, &mut *pool_ptr.1) };
1759            for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
1760                let pos_i = *pos as usize;
1761                let caches = self
1762                    .kv_caches
1763                    .get(cache_id)
1764                    .expect("paged batched: cache not present");
1765                let cache = &caches[li];
1766                let bt = cache
1767                    .block_table
1768                    .as_ref()
1769                    .expect("paged batched: block_table missing");
1770                let cache_len_before = cache.len;
1771                let bt_ptr = bt as *const B::Buffer;
1772                // SAFETY: bt is read-only during the dispatch; we don't
1773                // touch self.kv_caches between this raw deref and the
1774                // call below.
1775                let bt_safe: &B::Buffer = unsafe { &*bt_ptr };
1776                B::split_qkv_norm_rope_into_paged_cache(
1777                    ctx,
1778                    &self.scratch.qkv_out,
1779                    (i as u64) * qkv_stride_bytes,
1780                    q_norm_w,
1781                    k_norm_w,
1782                    &self.rope.cos,
1783                    &self.rope.sin,
1784                    self.scratch
1785                        .paged_batch_q
1786                        .as_mut()
1787                        .expect("paged_batch_q missing"),
1788                    (i as u64) * q_head_major_size_bytes,
1789                    pool_k,
1790                    pool_v,
1791                    bt_safe,
1792                    1,
1793                    nh,
1794                    nkv,
1795                    hd,
1796                    pos_i,
1797                    eps,
1798                    qk_mode,
1799                    cache_len_before,
1800                    block_size,
1801                    max_blocks_per_seq,
1802                )
1803                .expect("split_qkv_norm_rope_into_paged_cache (batched)");
1804            }
1805
1806            // Step 2: bump cache.len and stack block_tables + context_lens
1807            // host-side, then upload to device scratch.
1808            let mut stacked_bt: Vec<u32> = vec![0u32; m * max_blocks_per_seq];
1809            let mut stacked_cl: Vec<u32> = vec![0u32; m];
1810            for (i, (cache_id, _, _)) in batch.iter().enumerate() {
1811                let caches = self
1812                    .kv_caches
1813                    .get_mut(cache_id)
1814                    .expect("paged batched: cache not present");
1815                let cache = &mut caches[li];
1816                cache.len += 1;
1817                let len = cache.len as u32;
1818                stacked_cl[i] = len;
1819                let blocks = &cache.paged_block_indices;
1820                let n_to_copy = blocks.len().min(max_blocks_per_seq);
1821                stacked_bt[i * max_blocks_per_seq..i * max_blocks_per_seq + n_to_copy]
1822                    .copy_from_slice(&blocks[..n_to_copy]);
1823            }
1824            let bt_buf = self
1825                .scratch
1826                .paged_batch_block_tables
1827                .as_mut()
1828                .expect("paged_batch_block_tables missing");
1829            B::write_u32(ctx, bt_buf, &stacked_bt);
1830            let cl_buf = self
1831                .scratch
1832                .paged_batch_context_lens
1833                .as_mut()
1834                .expect("paged_batch_context_lens missing");
1835            B::write_u32(ctx, cl_buf, &stacked_cl);
1836
1837            // Step 3: one batched paged_decode_attention(num_seqs=m).
1838            let bt_ptr =
1839                self.scratch.paged_batch_block_tables.as_ref().unwrap() as *const B::Buffer;
1840            let cl_ptr =
1841                self.scratch.paged_batch_context_lens.as_ref().unwrap() as *const B::Buffer;
1842            let q_ptr = self.scratch.paged_batch_q.as_ref().unwrap() as *const B::Buffer;
1843            let o_ptr = self.scratch.paged_batch_o.as_mut().unwrap() as *mut B::Buffer;
1844            // SAFETY: scratch buffers are not aliased; we hold &mut self
1845            // through this entire block.
1846            let bt_safe = unsafe { &*bt_ptr };
1847            let cl_safe = unsafe { &*cl_ptr };
1848            let q_safe = unsafe { &*q_ptr };
1849            let o_safe = unsafe { &mut *o_ptr };
1850            B::paged_decode_attention(
1851                ctx,
1852                q_safe,
1853                pool_k,
1854                pool_v,
1855                o_safe,
1856                bt_safe,
1857                cl_safe,
1858                m,
1859                nh,
1860                nkv,
1861                hd,
1862                block_size,
1863                max_blocks_per_seq,
1864                1, // q_len
1865            )
1866            .expect("paged batched decode");
1867
1868            // Step 4: per-item copy paged_batch_o[i] → attn_flat[i*q_dim].
1869            for i in 0..m {
1870                B::copy_slice(
1871                    ctx,
1872                    self.scratch.paged_batch_o.as_ref().unwrap(),
1873                    i * q_dim,
1874                    &mut self.scratch.attn_flat,
1875                    i * q_dim,
1876                    q_dim,
1877                );
1878            }
1879
1880            stage_end(attn_t0, ctx, &BD_ATTN_PERITEM_US);
1881        } else {
1882            // 3. split_qkv [M, QKV] → q_buf [M, Q], k_buf [M, KV], v_buf [M, KV]
1883            B::split_qkv(
1884                ctx,
1885                &self.scratch.qkv_out,
1886                &mut self.scratch.q_buf,
1887                &mut self.scratch.k_buf,
1888                &mut self.scratch.v_buf,
1889                m,
1890                q_dim,
1891                kv_dim,
1892            );
1893
1894            // 4-6. Per-item loop: rope + kv_append + attention.
1895            //      Each item has its own cache_id + pos + kv_len.
1896            let q_single = self
1897                .scratch
1898                .q_single
1899                .as_ref()
1900                .expect("q_single missing — enable_batched_decode_scratch not called")
1901                as *const B::Buffer;
1902            let k_single =
1903                self.scratch.k_single.as_ref().expect("k_single missing") as *const B::Buffer;
1904            let v_single =
1905                self.scratch.v_single.as_ref().expect("v_single missing") as *const B::Buffer;
1906            let q_hm_single =
1907                self.scratch
1908                    .q_head_major_single
1909                    .as_mut()
1910                    .expect("q_head_major_single missing") as *mut B::Buffer;
1911            let k_hm_single =
1912                self.scratch
1913                    .k_head_major_single
1914                    .as_mut()
1915                    .expect("k_head_major_single missing") as *mut B::Buffer;
1916            let v_hm_single =
1917                self.scratch
1918                    .v_head_major_single
1919                    .as_mut()
1920                    .expect("v_head_major_single missing") as *mut B::Buffer;
1921            let attn_hm_single =
1922                self.scratch
1923                    .attn_head_major_single
1924                    .as_mut()
1925                    .expect("attn_head_major_single missing") as *mut B::Buffer;
1926            // SAFETY: each Option holds a stable B::Buffer; we don't mutate
1927            // self.scratch in a way that would invalidate them inside the loop
1928            // (the kv_caches mutation is on a disjoint field).
1929
1930            // End of dense block (rms_norm + qkv_proj + split_qkv); start
1931            // per-item attention loop instrumentation.
1932            stage_end(dense_t0, ctx, &BD_DENSE_US);
1933            let attn_t0 = stage_t0();
1934
1935            for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
1936                let pos_i = *pos as usize;
1937
1938                // SAFETY: borrows of disjoint scratch fields, see above.
1939                let q_single_ref = unsafe { &*q_single };
1940                let k_single_ref = unsafe { &*k_single };
1941                let v_single_ref = unsafe { &*v_single };
1942                let q_hm_single_mut = unsafe { &mut *q_hm_single };
1943                let k_hm_single_mut = unsafe { &mut *k_hm_single };
1944                let v_hm_single_mut = unsafe { &mut *v_hm_single };
1945                let attn_hm_single_mut = unsafe { &mut *attn_hm_single };
1946
1947                // Extract item i's Q/K/V slice from the batched buffers.
1948                B::copy_slice(
1949                    ctx,
1950                    &self.scratch.q_buf,
1951                    i * q_dim,
1952                    // copy_slice signature wants &mut for dst, but q_single
1953                    // is shared; we need a *mut variant — since enable_*
1954                    // gives us Option, we can do as_mut() here.
1955                    self.scratch.q_single.as_mut().unwrap(),
1956                    0,
1957                    q_dim,
1958                );
1959                B::copy_slice(
1960                    ctx,
1961                    &self.scratch.k_buf,
1962                    i * kv_dim,
1963                    self.scratch.k_single.as_mut().unwrap(),
1964                    0,
1965                    kv_dim,
1966                );
1967                B::copy_slice(
1968                    ctx,
1969                    &self.scratch.v_buf,
1970                    i * kv_dim,
1971                    self.scratch.v_single.as_mut().unwrap(),
1972                    0,
1973                    kv_dim,
1974                );
1975
1976                // qk_norm_rope with tokens=1, per-item pos.
1977                B::qk_norm_rope(
1978                    ctx,
1979                    q_single_ref,
1980                    q_norm_w,
1981                    &self.rope.cos,
1982                    &self.rope.sin,
1983                    q_hm_single_mut,
1984                    1,
1985                    nh,
1986                    hd,
1987                    pos_i,
1988                    eps,
1989                    qk_mode,
1990                );
1991                B::qk_norm_rope(
1992                    ctx,
1993                    k_single_ref,
1994                    k_norm_w,
1995                    &self.rope.cos,
1996                    &self.rope.sin,
1997                    k_hm_single_mut,
1998                    1,
1999                    nkv,
2000                    hd,
2001                    pos_i,
2002                    eps,
2003                    qk_mode,
2004                );
2005                B::qk_norm_rope(
2006                    ctx,
2007                    v_single_ref,
2008                    dummy_w,
2009                    &self.rope.cos,
2010                    &self.rope.sin,
2011                    v_hm_single_mut,
2012                    1,
2013                    nkv,
2014                    hd,
2015                    pos_i,
2016                    eps,
2017                    0,
2018                );
2019
2020                // KV append + attention for item i's cache.
2021                let caches = self
2022                    .kv_caches
2023                    .get_mut(cache_id)
2024                    .expect("ensure_kv must be called before forward_layer_batched");
2025                let cache = &mut caches[li];
2026                B::kv_cache_append_head_major(
2027                    ctx,
2028                    &mut cache.k,
2029                    &mut cache.v,
2030                    cache.len,
2031                    cache.capacity,
2032                    k_hm_single_mut,
2033                    v_hm_single_mut,
2034                    1,
2035                    nkv,
2036                    hd,
2037                );
2038                cache.len += 1;
2039                let kv_len = cache.len;
2040                let kv_stride = cache.capacity;
2041
2042                let attn_cfg = ferrum_kernels::backend::AttnConfig {
2043                    num_heads: nh,
2044                    num_kv_heads: nkv,
2045                    head_dim: hd,
2046                    causal: true,
2047                    scale: 1.0 / (hd as f32).sqrt(),
2048                    kv_seq_stride: kv_stride,
2049                    sliding_window: cfg_base.sliding_window,
2050                };
2051                B::flash_attention(
2052                    ctx,
2053                    q_hm_single_mut,
2054                    &cache.k,
2055                    &cache.v,
2056                    attn_hm_single_mut,
2057                    1,
2058                    1,
2059                    kv_len,
2060                    pos_i,
2061                    &attn_cfg,
2062                );
2063
2064                // Untranspose head-major → token-major: for tokens=1 the
2065                // layouts are byte-identical, so copy_slice straight into
2066                // attn_flat at the per-item offset (saves a transpose).
2067                B::copy_slice(
2068                    ctx,
2069                    attn_hm_single_mut,
2070                    0,
2071                    &mut self.scratch.attn_flat,
2072                    i * q_dim,
2073                    q_dim,
2074                );
2075            }
2076
2077            // End of per-item attention loop.
2078            stage_end(attn_t0, ctx, &BD_ATTN_PERITEM_US);
2079        } // end of `else` for non-paged path
2080
2081        let post_attn_t0 = stage_t0();
2082
2083        // 7. o_proj GEMM at m=M: attn_flat [M, Q] → o_proj_out [M, H]
2084        attn_layer.o_proj.forward(
2085            ctx,
2086            &self.scratch.attn_flat,
2087            &mut self.scratch.o_proj_out,
2088            m,
2089        );
2090
2091        // 8. fused residual_add + post_attention_layernorm
2092        B::fused_add_rms_norm(
2093            ctx,
2094            residual,
2095            &self.scratch.o_proj_out,
2096            &attn_layer.post_ln_w,
2097            eps,
2098            &mut self.scratch.norm_out,
2099            m,
2100            h,
2101        );
2102
2103        // o_proj + post-norm count under DENSE.
2104        stage_end(post_attn_t0, ctx, &BD_DENSE_US);
2105        let moe_t0 = stage_t0();
2106
2107        // 9. Router gemv: norm_out [M, H] → router_logits [M, n_exp]
2108        let moe_layer = &self.moe_layers[li];
2109        moe_layer.router.forward(
2110            ctx,
2111            &self.scratch.norm_out,
2112            &mut self.scratch.router_logits,
2113            m,
2114        );
2115
2116        // 10. MoE expert dispatch — per-item loop using the cheap
2117        //     stacked decode kernels (gemv_quant_moe_id + silu_mul_stacked
2118        //     + weighted_sum_batched). NOT the batched prefill path:
2119        //     `moe_forward_batched_prefill_impl` is tuned for large M
2120        //     (prefill) and the GPU bucketing overhead
2121        //     (`compute_ids_tpe_gpu` + indirect-dispatch arg-buffer
2122        //     setup) costs more than M sequential gemv calls at small M.
2123        //
2124        // Strategy: route ALL M tokens once via batched
2125        // `route_topk_softmax`, then loop M iterations of the stacked
2126        // decode kernels. Each iteration:
2127        //   - extract item i's selected ids + weights from the M-batch
2128        //     buffers via copy_slice
2129        //   - copy norm_out[i*h..(i+1)*h] → x_single
2130        //   - 3× gemv_quant_moe_id (gate/up/down) reading from x_single
2131        //   - silu_mul_stacked
2132        //   - weighted_sum_batched(batch=1) → acc_buf  (fresh write,
2133        //     no residual fusion)
2134        //   - copy_slice acc_buf → moe_out[i*h..(i+1)*h]
2135        // After the loop, single add_inplace residual += moe_out [M, H].
2136        let stacked_path_available = moe_layer.experts.gate_stacked.is_some()
2137            && moe_layer.experts.up_stacked.is_some()
2138            && moe_layer.experts.down_stacked.is_some();
2139        // MoE FFN dispatch tiers (m = batch size of this layer call):
2140        //
2141        //   m = 1          : `moe_forward_stacked_decode_impl`
2142        //                    (decode m=1 fast path, fused gate+up+silu)
2143        //   m ≥ 8 (default): `moe_forward_batched_prefill_impl`
2144        //                    (GEMM with simdgroup_matmul + GPU bucketing)
2145        //   else (m=2..7)  : per-item stacked decode loop
2146        //
2147        // EXPERIMENTAL — opt-in `FERRUM_MOE_BATCHED_DECODE=1` engages the
2148        // new `moe_forward_batched_decode_impl` for 2 ≤ m < 32. The
2149        // kernel itself is bitwise correct and ports llama.cpp's
2150        // `kernel_mul_mv_id` strategy to ferrum (one indirect-dispatch
2151        // GEMV per linear covering all m*top_k pairs). Empirically OFF
2152        // by default because the existing `forward_layer_batched_decode`
2153        // attention plumbing (per-item copy_slice × m × 6 dispatches)
2154        // scales linearly with m and overshadows the FFN savings —
2155        // regression measured at -19% (c=4) and -36% (c=16) on
2156        // Qwen3-30B-A3B Q4_K_M / M1 Max. Closing that gap requires a
2157        // batched attention path with offset-aware QKV slicing, which
2158        // is the next PR's job. Until then the kernel sits as
2159        // infrastructure.
2160        // Two independent thresholds:
2161        //   * `FERRUM_MOE_BATCH_THRESHOLD` (default 8) — m above which
2162        //     the LEGACY non-experimental path uses the prefill GEMM.
2163        //     Shared with `decode_batch`'s engine-level gate, so users
2164        //     who set it to a small value to engage batched decode
2165        //     don't accidentally also push the inner FFN to GEMM.
2166        //   * `FERRUM_MOE_PREFILL_THRESHOLD` (default 32) — m above
2167        //     which the EXPERIMENTAL batched-decode path defers to the
2168        //     prefill GEMM path. Mirrors llama.cpp's `ne21_mm_id_min=32`
2169        //     GEMV→GEMM boundary.
2170        let legacy_prefill_threshold: usize = std::env::var("FERRUM_MOE_BATCH_THRESHOLD")
2171            .ok()
2172            .and_then(|s| s.parse().ok())
2173            .unwrap_or(8);
2174        let new_prefill_threshold: usize = std::env::var("FERRUM_MOE_PREFILL_THRESHOLD")
2175            .ok()
2176            .and_then(|s| s.parse().ok())
2177            .unwrap_or(32);
2178        let new_batched_enabled = stacked_path_available
2179            && B::supports_batched_moe_gemv()
2180            && std::env::var("FERRUM_MOE_BATCHED_DECODE").as_deref() == Ok("1");
2181
2182        // When the new path is opted in, it owns the m=2..new_prefill_threshold
2183        // range; the legacy threshold is overridden upward.
2184        let use_prefill_batched = if new_batched_enabled {
2185            stacked_path_available && m >= new_prefill_threshold
2186        } else {
2187            stacked_path_available && m >= legacy_prefill_threshold
2188        };
2189        let use_batched_decode = new_batched_enabled && !use_prefill_batched && m >= 2;
2190
2191        if use_prefill_batched {
2192            moe_forward_batched_prefill_impl::<B>(
2193                ctx,
2194                moe_layer,
2195                &mut self.scratch,
2196                h,
2197                self.cfg.expert_intermediate_size,
2198                self.cfg.num_experts_per_tok,
2199                self.cfg.num_experts,
2200                self.cfg.norm_topk_prob,
2201                m,
2202            )?;
2203        } else if use_batched_decode {
2204            moe_forward_batched_decode_impl::<B>(
2205                ctx,
2206                moe_layer,
2207                &mut self.scratch,
2208                h,
2209                self.cfg.expert_intermediate_size,
2210                self.cfg.num_experts_per_tok,
2211                self.cfg.num_experts,
2212                self.cfg.norm_topk_prob,
2213                m,
2214            )?;
2215        } else if stacked_path_available {
2216            let inter = self.cfg.expert_intermediate_size;
2217            let top_k = self.cfg.num_experts_per_tok;
2218            let n_exp = self.cfg.num_experts;
2219            let norm_topk_prob = self.cfg.norm_topk_prob;
2220            let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2221            let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2222            let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2223
2224            // Single batched router pass: writes selected_ids_buf [M, top_k]
2225            // and weights_2d [M, top_k]. Replaces M individual route calls.
2226            B::route_topk_softmax(
2227                ctx,
2228                &self.scratch.router_logits,
2229                &mut self.scratch.selected_ids_buf,
2230                &mut self.scratch.weights_2d,
2231                m,
2232                n_exp,
2233                top_k,
2234                norm_topk_prob,
2235            )?;
2236
2237            // Per-item loop using offset-aware kernel APIs — eliminates
2238            // the 4 copy_slice round-trips per iteration that the
2239            // earlier implementation needed (ids, weights, x_single,
2240            // moe_out). At c=16 / 48 layers that's ~3,072 dispatches
2241            // saved per token. Uses `gemv_quant_moe_id_offset` to read
2242            // `selected_ids_buf` at the i-th `top_k` block and
2243            // `norm_out` at the i-th hidden row directly. Falls back
2244            // to copy_slice path if backend doesn't support offsets.
2245            for i in 0..m {
2246                let ids_offset = i * top_k;
2247                let activation_offset = i * h;
2248                let weights_offset = i * top_k;
2249                let moe_out_offset = i * h;
2250
2251                // Stacked gate / up gemvs — broadcast item i's row of
2252                // norm_out across top_k slots, read item i's ids.
2253                let gate_res = B::gemv_quant_moe_id_offset(
2254                    ctx,
2255                    &self.scratch.norm_out,
2256                    activation_offset,
2257                    gate_stacked,
2258                    &self.scratch.selected_ids_buf,
2259                    ids_offset,
2260                    &mut self.scratch.gate_out_stacked,
2261                    top_k,
2262                    0,
2263                );
2264                if gate_res.is_err() {
2265                    // Backend doesn't support offset variants — fall back
2266                    // to the legacy copy_slice path. Same as before.
2267                    B::copy_slice(
2268                        ctx,
2269                        &self.scratch.selected_ids_buf,
2270                        ids_offset,
2271                        &mut self.scratch.ids_buf,
2272                        0,
2273                        top_k,
2274                    );
2275                    B::copy_slice(
2276                        ctx,
2277                        &self.scratch.weights_2d,
2278                        weights_offset,
2279                        &mut self.scratch.weights_buf,
2280                        0,
2281                        top_k,
2282                    );
2283                    B::copy_slice(
2284                        ctx,
2285                        &self.scratch.norm_out,
2286                        activation_offset,
2287                        &mut self.scratch.x_single,
2288                        0,
2289                        h,
2290                    );
2291                    B::gemv_quant_moe_id(
2292                        ctx,
2293                        &self.scratch.x_single,
2294                        gate_stacked,
2295                        &self.scratch.ids_buf,
2296                        &mut self.scratch.gate_out_stacked,
2297                        top_k,
2298                        0,
2299                    )?;
2300                    B::gemv_quant_moe_id(
2301                        ctx,
2302                        &self.scratch.x_single,
2303                        up_stacked,
2304                        &self.scratch.ids_buf,
2305                        &mut self.scratch.up_out_stacked,
2306                        top_k,
2307                        0,
2308                    )?;
2309                    B::silu_mul_stacked(
2310                        ctx,
2311                        &self.scratch.gate_out_stacked,
2312                        &self.scratch.up_out_stacked,
2313                        &mut self.scratch.silu_stacked,
2314                        top_k,
2315                        inter,
2316                    )?;
2317                    B::gemv_quant_moe_id(
2318                        ctx,
2319                        &self.scratch.silu_stacked,
2320                        down_stacked,
2321                        &self.scratch.ids_buf,
2322                        &mut self.scratch.down_out_stacked,
2323                        top_k,
2324                        inter,
2325                    )?;
2326                    B::weighted_sum_batched(
2327                        ctx,
2328                        &self.scratch.down_out_stacked,
2329                        &self.scratch.weights_buf,
2330                        &mut self.scratch.acc_buf,
2331                        1,
2332                        top_k,
2333                        h,
2334                    )?;
2335                    B::copy_slice(
2336                        ctx,
2337                        &self.scratch.acc_buf,
2338                        0,
2339                        &mut self.scratch.moe_out,
2340                        moe_out_offset,
2341                        h,
2342                    );
2343                    continue;
2344                }
2345                // Fast path: offset-aware all the way through.
2346                B::gemv_quant_moe_id_offset(
2347                    ctx,
2348                    &self.scratch.norm_out,
2349                    activation_offset,
2350                    up_stacked,
2351                    &self.scratch.selected_ids_buf,
2352                    ids_offset,
2353                    &mut self.scratch.up_out_stacked,
2354                    top_k,
2355                    0,
2356                )?;
2357                B::silu_mul_stacked(
2358                    ctx,
2359                    &self.scratch.gate_out_stacked,
2360                    &self.scratch.up_out_stacked,
2361                    &mut self.scratch.silu_stacked,
2362                    top_k,
2363                    inter,
2364                )?;
2365                B::gemv_quant_moe_id_offset(
2366                    ctx,
2367                    &self.scratch.silu_stacked,
2368                    0, // silu_stacked itself stays at offset 0 each iter
2369                    down_stacked,
2370                    &self.scratch.selected_ids_buf,
2371                    ids_offset,
2372                    &mut self.scratch.down_out_stacked,
2373                    top_k,
2374                    inter,
2375                )?;
2376                // Write directly into moe_out at the per-item offset —
2377                // skips the copy_slice from acc_buf.
2378                B::weighted_sum_batched_offset(
2379                    ctx,
2380                    &self.scratch.down_out_stacked,
2381                    &self.scratch.weights_2d,
2382                    weights_offset,
2383                    &mut self.scratch.moe_out,
2384                    moe_out_offset,
2385                    1,
2386                    top_k,
2387                    h,
2388                )?;
2389            }
2390        } else {
2391            // Backend without stacked variants — fall back to the legacy
2392            // per-(token, expert) host-routed path.
2393            moe_forward::<B>(
2394                ctx,
2395                &self.scratch.norm_out,
2396                &self.scratch.router_logits,
2397                &mut self.scratch.moe_out,
2398                m,
2399                h,
2400                self.cfg.expert_intermediate_size,
2401                self.cfg.num_experts,
2402                self.cfg.num_experts_per_tok,
2403                self.cfg.norm_topk_prob,
2404                &moe_layer.experts,
2405                &mut self.scratch.x_single,
2406                &mut self.scratch.acc_buf,
2407                &mut self.scratch.gate_up_buf,
2408                &mut self.scratch.silu_buf,
2409                &mut self.scratch.down_buf,
2410                &self.scratch.zero_hidden,
2411            )?;
2412        }
2413
2414        // 11. residual += moe_out [M, H]
2415        B::add_inplace(ctx, residual, &self.scratch.moe_out, m * h);
2416
2417        // Close MoE-block instrumentation (router + FFN + residual add).
2418        stage_end(moe_t0, ctx, &BD_MOE_US);
2419
2420        Ok(())
2421    }
2422}
2423
2424impl<B: Backend> DecoderOnlyLLM for Qwen3MoeModel<B> {
2425    fn config(&self) -> &LlmRuntimeConfig {
2426        &self.runtime_cfg
2427    }
2428
2429    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2430        // Eager scratch + KV cache grow + a 1-token forward warmup so
2431        // the first real prefill / decode doesn't pay the cold-start
2432        // ~25-MTLBuffer scratch alloc + ~96-MTLBuffer KV alloc + Metal
2433        // pipeline-state first-bind costs (~265 ms total on Qwen3-MoE
2434        // 30B-A3B / M1 Max). Mirrors what llama-bench's --warmup does
2435        // (which runs a same-shape forward before the timer).
2436        self.ensure_scratch(max_tokens);
2437        self.ensure_kv(cache_id);
2438
2439        // Warmup forward through all 48 layers under a scratch cache_id
2440        // so the real `cache_id` starts at pos_offset=0. Token 0 is
2441        // valid for any tokenizer (BOS or pad).
2442        const WARMUP_CACHE: &str = "__ferrum_warmup__";
2443        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2444        // Drop the warmup KV cache slot — real cache_id is unaffected.
2445        if let Some(caches) = self.kv_caches.remove(WARMUP_CACHE) {
2446            self.kv_free_pool.push(caches);
2447        }
2448    }
2449
2450    fn kv_capacity(&self) -> usize {
2451        // Mirror the bound `ensure_kv` will use when allocating the cache.
2452        let model_max = self.cfg.base.max_seq_len;
2453        const DEFAULT_KV_CAPACITY: usize = 4096;
2454        std::env::var("FERRUM_KV_CAPACITY")
2455            .ok()
2456            .and_then(|s| s.parse::<usize>().ok())
2457            .map(|cap| cap.min(model_max))
2458            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
2459    }
2460
2461    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2462        self.prefill_internal(cache_id, tokens)
2463    }
2464
2465    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2466        self.decode_internal(cache_id, token, pos)
2467    }
2468
2469    // decode_batch is gated to use the batched path only when it's a
2470    // measurable win. The crossover depends on M:
2471    //
2472    //   - At low M (≤ ~8) the per-item `decode_internal` loop wins
2473    //     because: (a) it stays at scratch offset 0 (no copy_slice
2474    //     overhead), (b) it preserves the cross-layer rms_norm fusion
2475    //     fast path (`weighted_sum_residual_norm_stacked`).
2476    //   - At high M (≥ ~12) the batched path wins because the dense
2477    //     GEMM batching (qkv_proj, o_proj, router, lm_head at m=M) and
2478    //     the prefill-batched MoE dispatch (one `gemm_quant_moe_id` for
2479    //     all tokens) amortise the ~48-dispatch lost-fusion penalty.
2480    //
2481    // Default opted out of FERRUM_MOE_BATCHED. When opted in, the
2482    // batched path engages only at M ≥ FERRUM_MOE_BATCH_THRESHOLD
2483    // (default 12). Below that we still go per-item.
2484    //
2485    // Empirical note 2026-05-02: a follow-up PR added a batched MoE
2486    // GEMV kernel (`gemv_quant_moe_id_batched`) that holds MoE
2487    // dispatch count flat as concurrency scales. Wiring it through
2488    // `decode_batch_internal` regressed throughput by 19% (c=4) /
2489    // 36% (c=16) — `forward_layer_batched_decode`'s per-item
2490    // attention plumbing (copy_slice × m × 6 dispatches) costs more
2491    // than the MoE save. The batched MoE kernel is shipped as opt-in
2492    // infrastructure (`FERRUM_MOE_BATCHED_DECODE=1`); flipping it on
2493    // by default has to wait until the attention plumbing is fixed.
2494    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2495        let m = batch.len();
2496        let opted_in = std::env::var("FERRUM_MOE_BATCHED").as_deref() == Ok("1");
2497        let threshold = std::env::var("FERRUM_MOE_BATCH_THRESHOLD")
2498            .ok()
2499            .and_then(|s| s.parse::<usize>().ok())
2500            .unwrap_or(12);
2501        if opted_in && m >= threshold {
2502            self.decode_batch_internal(batch)
2503        } else {
2504            batch
2505                .iter()
2506                .map(|(cid, tok, p)| self.decode(cid, *tok, *p))
2507                .collect()
2508        }
2509    }
2510
2511    fn release(&mut self, cache_id: &str) {
2512        let mut ctx = B::new_context();
2513        B::sync(&mut ctx);
2514        B::reset_graph(&mut ctx);
2515        B::sync(&mut ctx);
2516        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2517            // Paged mode: return the cache_id's blocks to the shared
2518            // allocator so other sequences can reuse them. Without this,
2519            // every request consumes max_blocks_per_seq blocks
2520            // permanently — pool exhausts after FERRUM_PAGED_MAX_SEQS
2521            // requests and subsequent ensure_kv panics with
2522            // "scratch residual missing" (the cascade panic from a
2523            // failed ensure_kv path leaving scratch poisoned).
2524            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2525                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2526                if let Some(c0) = caches.first() {
2527                    if !c0.paged_block_indices.is_empty() {
2528                        alloc.free(&c0.paged_block_indices);
2529                    }
2530                }
2531                for c in caches.iter_mut() {
2532                    c.paged_block_indices.clear();
2533                }
2534            }
2535            self.kv_free_pool.push(caches);
2536        }
2537    }
2538
2539    fn reset(&mut self) {
2540        let mut ctx = B::new_context();
2541        B::sync(&mut ctx);
2542        B::reset_graph(&mut ctx);
2543        B::sync(&mut ctx);
2544        self.kv_caches.clear();
2545        self.kv_free_pool.clear();
2546    }
2547}
2548
2549/// Batched MoE FFN — decode (m=1) and per-token-prefill (m>1 looped).
2550///
2551/// Three batched `gemv_quant_moe_id` dispatches per token: gate (broadcast
2552/// activation), up (broadcast activation), down (per-slot activation —
2553/// each expert sees its own silu·up). The per-(token, expert) outer loop
2554/// shrinks from `top_k * 4` dispatches per layer to **3 batched + 1
2555/// silu_mul_split + 1 weighted_sum_dispatch_loop**.
2556///
2557/// For prefill (m > 1) we loop over tokens externally — each token's
2558/// router output drives a single batched call. Still much faster than
2559/// the per-(token, expert) per-Linear path because the gemvs are batched.
2560///
2561/// Free function (not a method) so the caller can split the borrow on
2562/// `self` between `moe_layers[li]` (immutable) and `scratch` (mutable).
2563#[allow(clippy::too_many_arguments)]
2564fn moe_forward_stacked_decode_impl<B: Backend>(
2565    ctx: &mut B::Context,
2566    moe_layer: &Qwen3MoeLayerState<B>,
2567    scratch: &mut Qwen3MoeScratch<B>,
2568    h: usize,
2569    inter: usize,
2570    top_k: usize,
2571    n_exp: usize,
2572    norm_topk_prob: bool,
2573    tokens: usize,
2574    residual: &mut B::Buffer,
2575    // If `Some`, fold the NEXT layer's leading rms_norm into the
2576    // weighted-sum-residual tail using `weighted_sum_residual_norm_stacked`.
2577    next_norm_w: Option<&B::Buffer>,
2578    eps: f32,
2579) -> Result<()> {
2580    // GPU-side routing: one Metal launch reads router_logits and writes
2581    // selected ids + combine weights directly into device-side scratch
2582    // buffers. Eliminates the per-layer `B::sync + B::to_vec(router_logits)
2583    // + host route()` round trip — the dominant remaining cost in the
2584    // decode hot path (~10% of total decode latency).
2585    let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
2586    let stage_t0 = || -> Option<std::time::Instant> {
2587        if prof {
2588            Some(std::time::Instant::now())
2589        } else {
2590            None
2591        }
2592    };
2593    let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
2594        if let Some(t) = t0 {
2595            B::sync(ctx);
2596            c.fetch_add(
2597                t.elapsed().as_micros() as u64,
2598                std::sync::atomic::Ordering::Relaxed,
2599            );
2600        }
2601    };
2602
2603    let t0 = stage_t0();
2604    B::route_topk_softmax(
2605        ctx,
2606        &scratch.router_logits,
2607        &mut scratch.ids_buf,
2608        &mut scratch.weights_buf,
2609        tokens,
2610        n_exp,
2611        top_k,
2612        norm_topk_prob,
2613    )?;
2614    stage_end(t0, ctx, &DEC_ROUTE_US);
2615
2616    let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2617    let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2618    let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2619
2620    // moe_forward_stacked_decode_impl is only called when `tokens == 1`
2621    // (the branch in `forward_layer` routes prefill m>1 through
2622    // `moe_forward_batched_prefill_impl` instead). The for-b loop and
2623    // the copy norm_out[b*h] → x_single were vestigial scaffolding;
2624    // for tokens=1 norm_out[0..h] IS the activation row, and we can
2625    // pass it straight to the gemv kernel via src1_stride=0 broadcast.
2626    debug_assert_eq!(
2627        tokens, 1,
2628        "moe_forward_stacked_decode_impl expects tokens=1 (prefill goes through moe_forward_batched_prefill_impl)"
2629    );
2630    let _ = tokens; // silence unused-warning when assertion is compiled out
2631
2632    {
2633        // ids_buf and weights_buf populated by the GPU router above —
2634        // no host writes needed here in the decode path.
2635
2636        // Fused-vs-unfused gate+up+silu selection.
2637        //
2638        // Default: when the backend advertises support (Metal Q4KExperts),
2639        // run the single fused dispatch — saves 2 dispatches and the
2640        // entire round-trip through gate_out_stacked / up_out_stacked
2641        // scratch (≈4× [top_k, ffn] of intermediate bandwidth).
2642        //
2643        // Opt-out: `FERRUM_MOE_FUSED_GATE_UP_SILU=0` forces the legacy
2644        // 3-dispatch path. Used for A/B benchmarking and as a kill switch
2645        // if the fused kernel ever produces divergent outputs.
2646        // Cache the env-flag read once per process — the decode hot
2647        // path calls this fn ~48 layers × ~steps_per_run times.
2648        static FUSED_DISABLED: OnceLock<bool> = OnceLock::new();
2649        let fused_disabled = *FUSED_DISABLED
2650            .get_or_init(|| std::env::var("FERRUM_MOE_FUSED_GATE_UP_SILU").as_deref() == Ok("0"));
2651        let use_fused = B::supports_fused_moe_gate_up_silu() && !fused_disabled;
2652
2653        if use_fused {
2654            // 1+2+3 fused: silu_stacked = SiLU(gate · norm_out) * (up · norm_out)
2655            let t0 = stage_t0();
2656            B::gemv_quant_moe_id_gate_up_silu(
2657                ctx,
2658                &scratch.norm_out,
2659                gate_stacked,
2660                up_stacked,
2661                &scratch.ids_buf,
2662                &mut scratch.silu_stacked,
2663                top_k,
2664            )?;
2665            stage_end(t0, ctx, &DEC_SILU_US);
2666        } else {
2667            // 1. Batched gate gemv — broadcast input across top_k slots.
2668            //    src1 = norm_out (which has hidden floats at offset 0),
2669            //    src1_stride=0 → all slots read the same row.
2670            let t0 = stage_t0();
2671            B::gemv_quant_moe_id(
2672                ctx,
2673                &scratch.norm_out,
2674                gate_stacked,
2675                &scratch.ids_buf,
2676                &mut scratch.gate_out_stacked,
2677                top_k,
2678                0, // broadcast
2679            )?;
2680            stage_end(t0, ctx, &DEC_GATE_US);
2681
2682            // 2. Batched up gemv — also broadcast.
2683            let t0 = stage_t0();
2684            B::gemv_quant_moe_id(
2685                ctx,
2686                &scratch.norm_out,
2687                up_stacked,
2688                &scratch.ids_buf,
2689                &mut scratch.up_out_stacked,
2690                top_k,
2691                0,
2692            )?;
2693            stage_end(t0, ctx, &DEC_UP_US);
2694
2695            // 3. Stacked SiLU·gate → silu_stacked. Single dispatch covers
2696            //    all top_k slots — replaces the per-slot loop's
2697            //    (3 copy_slice + 1 silu_mul) × 8 = 32 dispatches.
2698            let t0 = stage_t0();
2699            B::silu_mul_stacked(
2700                ctx,
2701                &scratch.gate_out_stacked,
2702                &scratch.up_out_stacked,
2703                &mut scratch.silu_stacked,
2704                top_k,
2705                inter,
2706            )?;
2707            stage_end(t0, ctx, &DEC_SILU_US);
2708        }
2709
2710        // 4. Batched down gemv — per-slot input via src1_stride = inter.
2711        //    silu_stacked[k * inter ..] is the activation row for slot k.
2712        let t0 = stage_t0();
2713        B::gemv_quant_moe_id(
2714            ctx,
2715            &scratch.silu_stacked,
2716            down_stacked,
2717            &scratch.ids_buf,
2718            &mut scratch.down_out_stacked,
2719            top_k,
2720            inter,
2721        )?;
2722        stage_end(t0, ctx, &DEC_DOWN_US);
2723
2724        // 5. Fused weighted-sum + residual-add (+ optional next-layer
2725        //    rms_norm). Two paths:
2726        //
2727        //    * `next_norm_w = Some(_)` (cross-layer fusion): one kernel
2728        //      computes residual[i] += Σ_k w[k] · down[k, i] AND
2729        //      norm_out[i] = residual[i] · scale · next_norm_w[i].
2730        //      The next layer's leading rms_norm is skipped. Saves an
2731        //      additional dispatch per layer transition.
2732        //    * `next_norm_w = None` (last layer): just residual-add.
2733        let t0 = stage_t0();
2734        if let Some(nnw) = next_norm_w {
2735            B::weighted_sum_residual_norm_stacked(
2736                ctx,
2737                &scratch.down_out_stacked,
2738                &scratch.weights_buf,
2739                residual,
2740                nnw,
2741                &mut scratch.norm_out,
2742                top_k,
2743                h,
2744                eps,
2745            )?;
2746        } else {
2747            B::weighted_sum_residual_stacked(
2748                ctx,
2749                &scratch.down_out_stacked,
2750                &scratch.weights_buf,
2751                residual,
2752                top_k,
2753                h,
2754            )?;
2755        }
2756        stage_end(t0, ctx, &DEC_WSUM_US);
2757    }
2758
2759    Ok(())
2760}
2761
2762/// Batched MoE FFN for prefill (m > 1).
2763///
2764/// One pass through the expert dispatch — replaces the per-token loop
2765/// with three batched 2-D mul_mm_id dispatches (gate, up, down) where
2766/// each expert's slab of (token, slot) pairs runs as one gemm tile.
2767/// Per-layer dispatch count: ~6 (router + 3 mul_mm_id + silu + wsum)
2768/// independent of `tokens`. Compare to the decode-style stacked path
2769/// that emits ~10 per token.
2770///
2771/// Free function so the caller can split the borrow on `self` between
2772/// `moe_layers[li]` (immutable) and `scratch` (mutable).
2773#[allow(clippy::too_many_arguments)]
2774fn moe_forward_batched_prefill_impl<B: Backend>(
2775    ctx: &mut B::Context,
2776    moe_layer: &Qwen3MoeLayerState<B>,
2777    scratch: &mut Qwen3MoeScratch<B>,
2778    h: usize,
2779    inter: usize,
2780    top_k: usize,
2781    n_exp: usize,
2782    norm_topk_prob: bool,
2783    tokens: usize,
2784) -> Result<()> {
2785    let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
2786    let stage_t0 = || -> Option<std::time::Instant> {
2787        if prof {
2788            Some(std::time::Instant::now())
2789        } else {
2790            None
2791        }
2792    };
2793    let stage_end =
2794        |t0: Option<std::time::Instant>, ctx: &mut B::Context, us: &AtomicU64, n: &AtomicU64| {
2795            if let Some(t) = t0 {
2796                B::sync(ctx);
2797                us.fetch_add(
2798                    t.elapsed().as_micros() as u64,
2799                    std::sync::atomic::Ordering::Relaxed,
2800                );
2801                n.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2802            }
2803        };
2804
2805    // GPU-side routing: keep the whole pipeline device-resident. Two
2806    // dispatches replace the per-layer `B::sync + to_vec(router_logits)
2807    // + host route() + host compute_ids_tpe + write_back` round trip.
2808    //
2809    //   1. `route_topk_softmax` writes selected expert IDs (flat
2810    //      `[batch, top_k]`) into `selected_ids_buf` and the post-renorm
2811    //      combine weights directly into `weights_2d`.
2812    //   2. `compute_ids_tpe_gpu` buckets those pairs into `tpe_buf` and
2813    //      `ids_2d` using device-side atomic_fetch_add slot claims. The
2814    //      `ids_2d` row stride is the worst-case `tokens * top_k`; the
2815    //      consumer GEMM stops at `tpe[e]` so the over-strided columns
2816    //      cost only launch overhead, not real compute.
2817    //
2818    // `FERRUM_MOE_HOST_TOPK=1`        → legacy CPU softmax+topk+bucket
2819    // `FERRUM_MOE_DIRECT_DISPATCH=1`  → GPU topk but worst-case GEMM grid
2820    // (default)                       → GPU topk + indirect-dispatched GEMM
2821    //                                    (grid sized from max(tpe[e]))
2822    let use_gpu_topk = std::env::var("FERRUM_MOE_HOST_TOPK").as_deref() != Ok("1");
2823    let use_indirect_dispatch =
2824        use_gpu_topk && std::env::var("FERRUM_MOE_DIRECT_DISPATCH").as_deref() != Ok("1");
2825    let max_per_expert = if use_gpu_topk {
2826        let t0 = stage_t0();
2827        B::route_topk_softmax(
2828            ctx,
2829            &scratch.router_logits,
2830            &mut scratch.selected_ids_buf,
2831            &mut scratch.weights_2d,
2832            tokens,
2833            n_exp,
2834            top_k,
2835            norm_topk_prob,
2836        )?;
2837        B::compute_ids_tpe_gpu(
2838            ctx,
2839            &scratch.selected_ids_buf,
2840            &mut scratch.tpe_buf,
2841            &mut scratch.ids_2d,
2842            &mut scratch.gate_up_args_buf,
2843            &mut scratch.down_args_buf,
2844            tokens,
2845            n_exp,
2846            top_k,
2847            inter,
2848            h,
2849        )?;
2850        stage_end(
2851            t0,
2852            ctx,
2853            &MOE_PREFILL_HOST_TOPK_US,
2854            &MOE_PREFILL_HOST_TOPK_CALLS,
2855        );
2856        // Worst-case ids row stride; matches `dispatch_compute_ids_tpe`.
2857        tokens * top_k
2858    } else {
2859        use ferrum_kernels::moe_host::compute_ids_tpe;
2860        let t0 = stage_t0();
2861        B::sync(ctx);
2862        let logits_host = B::to_vec(&scratch.router_logits, tokens * n_exp);
2863        let route = crate::moe::router::route(&logits_host, tokens, n_exp, top_k, norm_topk_prob);
2864        let (tpe_host, ids_host, max_per_expert) =
2865            compute_ids_tpe(&route.expert_ids, n_exp, tokens, top_k);
2866        B::write_i32_into(&mut scratch.tpe_buf, &tpe_host);
2867        B::write_i32_into(&mut scratch.ids_2d, &ids_host);
2868        B::write_f32_into(&mut scratch.weights_2d, &route.expert_weights);
2869        stage_end(
2870            t0,
2871            ctx,
2872            &MOE_PREFILL_HOST_TOPK_US,
2873            &MOE_PREFILL_HOST_TOPK_CALLS,
2874        );
2875        max_per_expert
2876    };
2877
2878    let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2879    let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2880    let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2881
2882    // 1. Batched gate gemm — one launch covers all (token, expert) pairs.
2883    //    src1 layout: [batch, ne11=1, K] (broadcast: each pair reads its
2884    //    token's row, slot index ignored).
2885    //    dst layout:  [batch, top_k, expert_inter] — natural.
2886    let t0 = stage_t0();
2887    if use_indirect_dispatch {
2888        B::gemm_quant_moe_id_indirect(
2889            ctx,
2890            &scratch.norm_out,
2891            gate_stacked,
2892            &scratch.ids_2d,
2893            &scratch.tpe_buf,
2894            &mut scratch.gate_out_stacked,
2895            &scratch.gate_up_args_buf,
2896            1, // ne11 = 1: broadcast
2897            top_k,
2898            max_per_expert,
2899            tokens,
2900        )?;
2901    } else {
2902        B::gemm_quant_moe_id(
2903            ctx,
2904            &scratch.norm_out,
2905            gate_stacked,
2906            &scratch.ids_2d,
2907            &scratch.tpe_buf,
2908            &mut scratch.gate_out_stacked,
2909            1,
2910            top_k,
2911            max_per_expert,
2912            tokens,
2913        )?;
2914    }
2915    stage_end(t0, ctx, &MOE_PREFILL_GATE_US, &MOE_PREFILL_GATE_CALLS);
2916
2917    // 2. Batched up gemm — same shape as gate.
2918    let t0 = stage_t0();
2919    if use_indirect_dispatch {
2920        B::gemm_quant_moe_id_indirect(
2921            ctx,
2922            &scratch.norm_out,
2923            up_stacked,
2924            &scratch.ids_2d,
2925            &scratch.tpe_buf,
2926            &mut scratch.up_out_stacked,
2927            &scratch.gate_up_args_buf,
2928            1,
2929            top_k,
2930            max_per_expert,
2931            tokens,
2932        )?;
2933    } else {
2934        B::gemm_quant_moe_id(
2935            ctx,
2936            &scratch.norm_out,
2937            up_stacked,
2938            &scratch.ids_2d,
2939            &scratch.tpe_buf,
2940            &mut scratch.up_out_stacked,
2941            1,
2942            top_k,
2943            max_per_expert,
2944            tokens,
2945        )?;
2946    }
2947    stage_end(t0, ctx, &MOE_PREFILL_UP_US, &MOE_PREFILL_UP_CALLS);
2948
2949    // 3. SiLU·gate over [tokens * top_k, expert_inter] flat layout.
2950    let total_pairs = tokens * top_k;
2951    let t0 = stage_t0();
2952    B::silu_mul_batched(
2953        ctx,
2954        &scratch.gate_out_stacked,
2955        &scratch.up_out_stacked,
2956        &mut scratch.silu_stacked,
2957        total_pairs,
2958        inter,
2959    )?;
2960    stage_end(t0, ctx, &MOE_PREFILL_SILU_US, &MOE_PREFILL_SILU_CALLS);
2961
2962    // 4. Batched down gemm — src1 is [batch, top_k, expert_inter] from
2963    //    silu_stacked. ne11 = top_k → each pair reads its own row.
2964    let t0 = stage_t0();
2965    if use_indirect_dispatch {
2966        B::gemm_quant_moe_id_indirect(
2967            ctx,
2968            &scratch.silu_stacked,
2969            down_stacked,
2970            &scratch.ids_2d,
2971            &scratch.tpe_buf,
2972            &mut scratch.down_out_stacked,
2973            &scratch.down_args_buf,
2974            top_k, // ne11 = top_k: per-slot
2975            top_k,
2976            max_per_expert,
2977            tokens,
2978        )?;
2979    } else {
2980        B::gemm_quant_moe_id(
2981            ctx,
2982            &scratch.silu_stacked,
2983            down_stacked,
2984            &scratch.ids_2d,
2985            &scratch.tpe_buf,
2986            &mut scratch.down_out_stacked,
2987            top_k,
2988            top_k,
2989            max_per_expert,
2990            tokens,
2991        )?;
2992    }
2993    stage_end(t0, ctx, &MOE_PREFILL_DOWN_US, &MOE_PREFILL_DOWN_CALLS);
2994
2995    // 5. Per-batch weighted sum: moe_out[b, h] = Σ_k w[b,k] · down[b,k,h]
2996    let t0 = stage_t0();
2997    B::weighted_sum_batched(
2998        ctx,
2999        &scratch.down_out_stacked,
3000        &scratch.weights_2d,
3001        &mut scratch.moe_out,
3002        tokens,
3003        top_k,
3004        h,
3005    )?;
3006    stage_end(t0, ctx, &MOE_PREFILL_WSUM_US, &MOE_PREFILL_WSUM_CALLS);
3007
3008    Ok(())
3009}
3010
3011/// Batched MoE FFN for the **small-m decode** range (typically c=2..32).
3012///
3013/// Mirrors llama.cpp's `kernel_mul_mv_id` strategy: hold the dispatch
3014/// count flat as concurrency scales by emitting **one** batched GEMV
3015/// per linear (gate / up / down) that covers all `m * top_k`
3016/// (token, expert) pairs in a single Metal launch. Replaces the
3017/// per-token outer loop in `forward_layer` (which emitted ~5
3018/// dispatches × m tokens per layer) with a fixed-shape pipeline.
3019///
3020/// Compared to [`moe_forward_batched_prefill_impl`]:
3021///   * no `compute_ids_tpe_gpu` bucketing kernel (the new pair-indexed
3022///     GEMV reads `selected_ids_buf` directly)
3023///   * uses GEMV not GEMM (better tile utilisation when tokens-per-expert
3024///     is small — at c=16 with top_k=8 each expert sees ~1-3 token rows,
3025///     well below the simdgroup_matmul tile width)
3026///   * fewer Metal dispatches per layer (5: route + 3 gemv + silu + wsum)
3027///
3028/// Per-layer dispatch budget: 5 (independent of m). At c=16 / 48 layers
3029/// that's 240 dispatches per decode step vs the per-token loop's ~3,840.
3030#[allow(clippy::too_many_arguments)]
3031fn moe_forward_batched_decode_impl<B: Backend>(
3032    ctx: &mut B::Context,
3033    moe_layer: &Qwen3MoeLayerState<B>,
3034    scratch: &mut Qwen3MoeScratch<B>,
3035    h: usize,
3036    inter: usize,
3037    top_k: usize,
3038    n_exp: usize,
3039    norm_topk_prob: bool,
3040    tokens: usize,
3041) -> Result<()> {
3042    let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
3043    let stage_t0 = || -> Option<std::time::Instant> {
3044        if prof {
3045            Some(std::time::Instant::now())
3046        } else {
3047            None
3048        }
3049    };
3050    let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
3051        if let Some(t) = t0 {
3052            B::sync(ctx);
3053            c.fetch_add(
3054                t.elapsed().as_micros() as u64,
3055                std::sync::atomic::Ordering::Relaxed,
3056            );
3057        }
3058    };
3059
3060    let total_pairs = tokens * top_k;
3061
3062    // 1. Single batched router pass — fills selected_ids_buf [m * top_k]
3063    //    and weights_2d [m * top_k] in one Metal dispatch.
3064    let t0 = stage_t0();
3065    B::route_topk_softmax(
3066        ctx,
3067        &scratch.router_logits,
3068        &mut scratch.selected_ids_buf,
3069        &mut scratch.weights_2d,
3070        tokens,
3071        n_exp,
3072        top_k,
3073        norm_topk_prob,
3074    )?;
3075    stage_end(t0, ctx, &MOE_BATCHED_DECODE_ROUTE_US);
3076
3077    let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
3078    let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
3079    let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
3080
3081    // 2+3+4. Fused gate+up+silu — single Metal dispatch covers all
3082    // m*top_k pairs. Falls back to the 3-dispatch sequence on backends
3083    // that don't have the fused-batched kernel.
3084    if B::supports_batched_moe_gate_up_silu() {
3085        let t0 = stage_t0();
3086        B::gemv_quant_moe_id_gate_up_silu_batched(
3087            ctx,
3088            &scratch.norm_out,
3089            gate_stacked,
3090            up_stacked,
3091            &scratch.selected_ids_buf,
3092            &mut scratch.silu_stacked,
3093            tokens,
3094            top_k,
3095            h, // outer stride: K floats per token
3096            0, // inner stride: 0 (slots within a token broadcast)
3097        )?;
3098        // Charge the whole fused step to the SiLU bucket — keeps the
3099        // profile counter additive with the unfused path's silu line.
3100        stage_end(t0, ctx, &MOE_BATCHED_DECODE_SILU_US);
3101    } else {
3102        // 2. Batched gate gemv — one launch covers all m*top_k pairs.
3103        let t0 = stage_t0();
3104        B::gemv_quant_moe_id_batched(
3105            ctx,
3106            &scratch.norm_out,
3107            gate_stacked,
3108            &scratch.selected_ids_buf,
3109            &mut scratch.gate_out_stacked,
3110            tokens,
3111            top_k,
3112            h,
3113            0,
3114        )?;
3115        stage_end(t0, ctx, &MOE_BATCHED_DECODE_GATE_US);
3116
3117        // 3. Batched up gemv.
3118        let t0 = stage_t0();
3119        B::gemv_quant_moe_id_batched(
3120            ctx,
3121            &scratch.norm_out,
3122            up_stacked,
3123            &scratch.selected_ids_buf,
3124            &mut scratch.up_out_stacked,
3125            tokens,
3126            top_k,
3127            h,
3128            0,
3129        )?;
3130        stage_end(t0, ctx, &MOE_BATCHED_DECODE_UP_US);
3131
3132        // 4. SiLU·gate.
3133        let t0 = stage_t0();
3134        B::silu_mul_batched(
3135            ctx,
3136            &scratch.gate_out_stacked,
3137            &scratch.up_out_stacked,
3138            &mut scratch.silu_stacked,
3139            total_pairs,
3140            inter,
3141        )?;
3142        stage_end(t0, ctx, &MOE_BATCHED_DECODE_SILU_US);
3143    }
3144
3145    // 5. Batched down gemv — src1 = silu_stacked [m, top_k, ffn]: each
3146    //    pair has its own row, outer = top_k * ffn, inner = ffn.
3147    let t0 = stage_t0();
3148    B::gemv_quant_moe_id_batched(
3149        ctx,
3150        &scratch.silu_stacked,
3151        down_stacked,
3152        &scratch.selected_ids_buf,
3153        &mut scratch.down_out_stacked,
3154        tokens,
3155        top_k,
3156        top_k * inter, // outer: top_k * ffn floats per token
3157        inter,         // inner: ffn floats per slot
3158    )?;
3159    stage_end(t0, ctx, &MOE_BATCHED_DECODE_DOWN_US);
3160
3161    // 6. Per-token weighted sum across slots → moe_out [m, h]. Caller
3162    //    does residual += moe_out at the end of forward_layer.
3163    let t0 = stage_t0();
3164    B::weighted_sum_batched(
3165        ctx,
3166        &scratch.down_out_stacked,
3167        &scratch.weights_2d,
3168        &mut scratch.moe_out,
3169        tokens,
3170        top_k,
3171        h,
3172    )?;
3173    stage_end(t0, ctx, &MOE_BATCHED_DECODE_WSUM_US);
3174
3175    Ok(())
3176}
3177
3178/// Build a stub Linear<B> with the given shape but zero weights. Used to
3179/// fill the dense `gate_up_proj` / `down_proj` slots in `LlamaFamilyLayer`
3180/// for MoE models — those slots are never invoked because the MoE FFN
3181/// path runs through `moe_layer.experts` instead. The stub's only purpose
3182/// is to satisfy the struct's type signature with minimal memory cost.
3183fn stub_linear<B: Backend>(
3184    out_features: usize,
3185    in_features: usize,
3186) -> Box<dyn ferrum_quantization::Linear<B>> {
3187    // Zero-init: out_features * in_features f32. For a 30B-A3B layer
3188    // this is 2*768*2048 = 3.1M f32 = 12 MB → fine; per-layer overhead
3189    // ≈ 12 MB × 48 = 576 MB. Marginal vs the experts (~16 GB).
3190    let zeros = vec![0.0f32; out_features * in_features];
3191    Box::new(ferrum_quantization::DenseLinear::<B>::from_rows(
3192        &zeros,
3193        out_features,
3194        in_features,
3195    ))
3196}
3197
3198fn build_rope_cache<B: Backend>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
3199    let hd = cfg.head_dim;
3200    let half = hd / 2;
3201    let max = cfg.max_seq_len;
3202    let mut cos = vec![0.0f32; max * half];
3203    let mut sin = vec![0.0f32; max * half];
3204    for pos in 0..max {
3205        for i in 0..half {
3206            let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
3207            let angle = pos as f64 * freq;
3208            cos[pos * half + i] = angle.cos() as f32;
3209            sin[pos * half + i] = angle.sin() as f32;
3210        }
3211    }
3212    RopeCache {
3213        cos: B::from_slice(&cos),
3214        sin: B::from_slice(&sin),
3215    }
3216}