Skip to main content

ferrum_models/models/
qwen3_moe.rs

1//! `Qwen3MoeModel<B>` — Qwen3-MoE family decoder (Qwen3-30B-A3B and friends).
2//!
3//! Architectural delta vs [`LlamaFamilyModel`]:
4//!   * Each transformer layer's FFN is a top-K MoE block instead of a
5//!     fused `gate_up_proj → silu → down_proj` MLP.
6//!     - One small router linear (`[hidden] → [num_experts]`) picks
7//!       top-K experts per token.
8//!     - Each expert is itself a fused `gate_up + down` MLP with the
9//!       same SwiGLU + RMSNorm structure as the dense path, just with
10//!       `expert_intermediate_size` (typically much smaller than the
11//!       dense `intermediate_size`).
12//!     - Output is the weight-summed combination of the K selected
13//!       expert outputs.
14//!   * Attention path is unchanged from dense Qwen3 (GQA + QK-norm + RoPE).
15//!
16//! Implementation re-uses the dense layer's attention machinery
17//! verbatim — RMSNorm, fused QKV, QK-norm + RoPE, KV cache append,
18//! flash attention, O-projection, residual + post-norm. The only new
19//! code is the MoE FFN block at the tail of each layer's forward.
20//!
21//! Memory model: experts are loaded as `QuantLinear<B>` per expert,
22//! slicing the on-disk 3-D `ffn_{gate,up,down}_exps.weight` tensors
23//! byte-wise so weights stay compressed (Q4_K / Q6_K). For a 32 GB
24//! Mac to run Qwen3-30B-A3B at all, this is non-negotiable: an
25//! eager-fp32 expert stack would weigh ~110 GB.
26
27use std::collections::HashMap;
28use std::sync::atomic::AtomicU64;
29use std::sync::OnceLock;
30
31use ferrum_kernels::backend::{Backend, KvCache};
32use ferrum_quantization::WeightLoader;
33use ferrum_types::{FerrumError, Result};
34
35use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
36use crate::models::llama_family::{LlamaFamilyConfig, LlamaFamilyLayer, RopeCache};
37use crate::moe::{moe_forward, ExpertStack};
38use crate::moe_config::Qwen3MoeConfig;
39
// Declares one zero-initialised `AtomicU64` profiling counter per name.
macro_rules! profile_counters {
    ($($name:ident),+ $(,)?) => {
        $(static $name: AtomicU64 = AtomicU64::new(0);)+
    };
}

// Whole-layer decode timings. The names deliberately match the dense
// path so log scrapers keyed on `FERRUM_DECODE_OP_PROFILE=1` output
// keep working without a separate MoE switch.
profile_counters!(ATTN_TIME_US, ATTN_CALLS, MOE_TIME_US, MOE_CALLS);

// Fine-grained decode-only timings, filled in by
// `moe_forward_stacked_decode_impl` when FERRUM_DECODE_OP_PROFILE is
// set. Each one accumulates across the layers of a single decode token
// and is drained at the bottom of `decode_internal`.
profile_counters!(
    DEC_ROUTE_US,
    DEC_GATE_US,
    DEC_UP_US,
    DEC_SILU_US,
    DEC_DOWN_US,
    DEC_WSUM_US,
);

// Recorded once per decode token rather than once per layer.
profile_counters!(DEC_EMBED_US, DEC_FINAL_NORM_US, DEC_LM_HEAD_US);

// Batched-prefill MoE sub-stages (host top-k, gate / up / down
// mul_mm_id, silu, weighted sum), each with a matching call count.
// Gated by the same FERRUM_DECODE_OP_PROFILE variable.
profile_counters!(
    MOE_PREFILL_HOST_TOPK_US,
    MOE_PREFILL_HOST_TOPK_CALLS,
    MOE_PREFILL_GATE_US,
    MOE_PREFILL_GATE_CALLS,
    MOE_PREFILL_UP_US,
    MOE_PREFILL_UP_CALLS,
    MOE_PREFILL_SILU_US,
    MOE_PREFILL_SILU_CALLS,
    MOE_PREFILL_DOWN_US,
    MOE_PREFILL_DOWN_CALLS,
    MOE_PREFILL_WSUM_US,
    MOE_PREFILL_WSUM_CALLS,
);

// Batched-DECODE MoE sub-stages: the small-m path that swaps the
// per-token loop for the batched-pair GEMV.
profile_counters!(
    MOE_BATCHED_DECODE_ROUTE_US,
    MOE_BATCHED_DECODE_GATE_US,
    MOE_BATCHED_DECODE_UP_US,
    MOE_BATCHED_DECODE_SILU_US,
    MOE_BATCHED_DECODE_DOWN_US,
    MOE_BATCHED_DECODE_WSUM_US,
);

// Coarse stage timings for `forward_layer_batched_decode`, summed over
// every layer of one `decode_batch_internal` call, so the split is
// visible without per-op instrumentation:
//   BD_DENSE_US        — rms_norm + qkv_proj + split_qkv + o_proj + fused_add_rms_norm
//   BD_ATTN_PERITEM_US — the `for i in 0..m` attention loop (incl. plumbing)
//   BD_MOE_US          — router + MoE FFN + residual add
profile_counters!(BD_DENSE_US, BD_ATTN_PERITEM_US, BD_MOE_US, BD_LAYER_CALLS);
94
95/// Per-layer MoE state: router linear (small) + per-expert MLP stack.
96pub struct Qwen3MoeLayerState<B: Backend> {
97    /// Router projection `[hidden] → [num_experts]` — tiny, never sparse,
98    /// always runs the full GEMV.
99    pub router: Box<dyn ferrum_quantization::Linear<B>>,
100    /// Per-expert weight stack. Each entry's `gate_up` is the fused
101    /// `[gate; up]` projection; `down` is the post-SwiGLU output proj.
102    pub experts: ExpertStack<B>,
103}
104
105/// Reusable scratch buffers for the MoE forward path. All sized at
106/// allocation time and reused across layers / forward calls.
107pub struct Qwen3MoeScratch<B: Backend> {
108    /// See [`crate::models::llama_family::LlamaFamilyScratch`] for the
109    /// attention scratch — we re-use those names verbatim.
110    pub residual: Option<B::Buffer>,
111    pub norm_out: B::Buffer,
112    pub qkv_out: B::Buffer,
113    pub q_buf: B::Buffer,
114    pub k_buf: B::Buffer,
115    pub v_buf: B::Buffer,
116    pub q_head_major: B::Buffer,
117    pub k_head_major: B::Buffer,
118    pub v_head_major: B::Buffer,
119    pub attn_head_major_out: B::Buffer,
120    pub attn_flat: B::Buffer,
121    pub o_proj_out: B::Buffer,
122
123    // ── MoE-specific scratch ─────────────────────────────────────────
124    /// Router logits for the whole batch: `[max_tokens, num_experts]`.
125    pub router_logits: B::Buffer,
126    /// Per-(token, expert) gate||up projection output — `[2 * expert_inter]`.
127    pub gate_up_buf: B::Buffer,
128    /// SiLU(gate) * up scratch — `[expert_inter]`.
129    pub silu_buf: B::Buffer,
130    /// Per-(token, expert) down-projection output — `[hidden]`.
131    pub down_buf: B::Buffer,
132    /// Per-token input row scratch — `[hidden]`. Holds the post-RMSNorm
133    /// activation slice that the per-(expert) gate_up gemv reads, kept
134    /// stable across the entire top_k loop for one token.
135    pub x_single: B::Buffer,
136    /// Per-token output accumulator — `[hidden]`. Holds the running
137    /// `Σ_k weight_k · expert_k(x[b])` sum that grows across the top_k
138    /// loop and is flushed to `moe_out[b]` once per token.
139    pub acc_buf: B::Buffer,
140    /// MoE output `[max_tokens, hidden]`. Zeroed each forward.
141    pub moe_out: B::Buffer,
142    /// Pre-allocated `[hidden]` zero scratch — `acc_buf` is reset to
143    /// this each token without going through `B::from_slice` on the
144    /// hot path.
145    pub zero_hidden: B::Buffer,
146
147    // ── MoE batched-fast-path scratch (Metal `gemv_q*kw_moe_id_f32` /
148    //    `gemm_q*kw_moe_id_f32`) ─────────────────────────────────────
149    //
150    // Sized for `max_tokens * top_k * X` so the same buffers cover both
151    // decode (m=1, uses the first `top_k * X` slice) and prefill
152    // (m>1, uses the full `max_tokens * top_k * X`). Decode-only
153    // workloads pay no extra memory because `max_tokens` was 1 there.
154    /// `[max_tokens * top_k * expert_inter]` — gate gemm output per pair.
155    pub gate_out_stacked: B::Buffer,
156    /// `[max_tokens * top_k * expert_inter]` — up gemm output per pair.
157    pub up_out_stacked: B::Buffer,
158    /// `[max_tokens * top_k * expert_inter]` — SiLU(gate)·up per pair.
159    pub silu_stacked: B::Buffer,
160    /// `[max_tokens * top_k * hidden]` — down gemm output per pair.
161    pub down_out_stacked: B::Buffer,
162    /// `[top_k]` i32 expert IDs for the current token (decode reuses;
163    /// prefill writes per-pair indices into `ids_2d` instead).
164    pub ids_buf: B::Buffer,
165    /// `[top_k]` f32 router combine weights for the current decode
166    /// token. Decode hot-path uses `write_f32_into` to update.
167    pub weights_buf: B::Buffer,
168    /// `[max_tokens * top_k]` i32 — flat selected-expert IDs from the
169    /// GPU router for the prefill batch. Consumed by `compute_ids_tpe_gpu`
170    /// to bucket pairs by expert into `tpe_buf` / `ids_2d`.
171    pub selected_ids_buf: B::Buffer,
172    /// `[3]` u32 indirect-dispatch args (`grid_x, grid_y, grid_z`) for
173    /// the gate / up MoE GEMM. Written by `compute_ids_tpe_gpu` so the
174    /// consumer GEMM grid covers exactly `max(tpe[e])` columns instead
175    /// of the worst-case `tokens * top_k`.
176    pub gate_up_args_buf: B::Buffer,
177    /// Same shape as `gate_up_args_buf` but for the down MoE GEMM
178    /// (different `grid_y` because down's `M = hidden_size` vs gate/up's
179    /// `M = expert_intermediate_size`).
180    pub down_args_buf: B::Buffer,
181    /// `[num_experts * max_per_expert_max]` i32 — per-expert pair
182    /// index lists for prefill 2-D mul_mm_id. `max_per_expert_max`
183    /// is bounded by `max_tokens * top_k` (worst-case: one expert
184    /// gets every pair). Sized at scratch alloc time.
185    pub ids_2d: B::Buffer,
186    /// `[num_experts]` i32 — `tpe[e]` = number of pairs assigned to
187    /// expert `e`. Companion to `ids_2d`.
188    pub tpe_buf: B::Buffer,
189    /// `[max_tokens * top_k]` f32 — combine weights per pair, in
190    /// natural `[batch, top_k]` layout for `weighted_sum_batched`.
191    pub weights_2d: B::Buffer,
192
193    // ── Final-token / lm_head outputs ────────────────────────────────
194    pub last_hidden: B::Buffer,
195    pub last_normed: B::Buffer,
196    pub logits: B::Buffer,
197    pub batch_logits: B::Buffer,
198
199    // ── Per-item single-token buffers for decode_batch (Phase 4b) ────
200    //
201    // The batched-decode path runs M GEMMs at m=M (qkv_proj / o_proj /
202    // router / MoE expert mul_mm_id) but attention stays a per-item loop
203    // (each cache_id has its own contiguous K/V buffer — no way to fan
204    // M items into a single attention dispatch without paged KV). These
205    // 1-token-shaped scratches hold the per-item slice during the loop:
206    // `copy_slice` extracts q/k/v from the batched buffers, qk_norm_rope
207    // writes head-major into _single, kv_cache_append + flash_attention
208    // run on it, then copy_slice writes back into attn_flat[i*q_dim].
209    //
210    // None until `enable_batched_decode_scratch` is called from
211    // `ensure_kv` once we know we'll be doing multi-seq decode.
212    pub q_single: Option<B::Buffer>,
213    pub k_single: Option<B::Buffer>,
214    pub v_single: Option<B::Buffer>,
215    pub q_head_major_single: Option<B::Buffer>,
216    pub k_head_major_single: Option<B::Buffer>,
217    pub v_head_major_single: Option<B::Buffer>,
218    pub attn_head_major_single: Option<B::Buffer>,
219
220    // ── Paged batched dispatch scratch ──────────────────────────────────
221    //
222    // Mirrors the same fields on `LlamaFamilyScratch`. `Some` only when
223    // `FERRUM_METAL_PAGED_KV=1` and `enable_paged_batch` was called once
224    // we know the pool dimensions. Sized for `FERRUM_PAGED_MAX_SEQS ×
225    // q_dim` so the multi-seq decode path can fan in M items' Q into a
226    // single batched buffer for one `paged_decode_attention(num_seqs=M)`
227    // call instead of running M sequential m=1 attentions.
228    pub paged_batch_q: Option<B::Buffer>,
229    pub paged_batch_o: Option<B::Buffer>,
230    pub paged_batch_block_tables: Option<B::Buffer>,
231    pub paged_batch_context_lens: Option<B::Buffer>,
232    pub paged_max_blocks_per_seq: usize,
233
234    pub max_tokens: usize,
235}
236
237impl<B: Backend> Qwen3MoeScratch<B> {
238    fn alloc(cfg: &Qwen3MoeConfig, max_tokens: usize) -> Self {
239        let h = cfg.base.hidden_size;
240        let q_dim = cfg.base.num_heads * cfg.base.head_dim;
241        let kv_dim = cfg.base.num_kv_heads * cfg.base.head_dim;
242        let qkv_dim = q_dim + 2 * kv_dim;
243        let t = max_tokens;
244        let inter = cfg.expert_intermediate_size;
245        let n_exp = cfg.num_experts;
246        let vocab = cfg.base.vocab_size;
247        Self {
248            residual: Some(B::alloc(t * h)),
249            norm_out: B::alloc(t * h),
250            qkv_out: B::alloc(t * qkv_dim),
251            q_buf: B::alloc(t * q_dim),
252            k_buf: B::alloc(t * kv_dim),
253            v_buf: B::alloc(t * kv_dim),
254            q_head_major: B::alloc(cfg.base.num_heads * t * cfg.base.head_dim),
255            k_head_major: B::alloc(cfg.base.num_kv_heads * t * cfg.base.head_dim),
256            v_head_major: B::alloc(cfg.base.num_kv_heads * t * cfg.base.head_dim),
257            attn_head_major_out: B::alloc(cfg.base.num_heads * t * cfg.base.head_dim),
258            attn_flat: B::alloc(t * q_dim),
259            o_proj_out: B::alloc(t * h),
260            router_logits: B::alloc(t * n_exp),
261            gate_up_buf: B::alloc(2 * inter),
262            silu_buf: B::alloc(inter),
263            down_buf: B::alloc(h),
264            x_single: B::alloc(h),
265            acc_buf: B::alloc(h),
266            moe_out: B::alloc(t * h),
267            zero_hidden: B::from_slice(&vec![0.0f32; h]),
268            gate_out_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
269            up_out_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
270            silu_stacked: B::alloc(t * cfg.num_experts_per_tok * inter),
271            down_out_stacked: B::alloc(t * cfg.num_experts_per_tok * h),
272            ids_buf: B::from_slice_i32(&vec![0i32; cfg.num_experts_per_tok]),
273            weights_buf: B::from_slice(&vec![0.0f32; cfg.num_experts_per_tok]),
274            selected_ids_buf: B::from_slice_i32(&vec![0i32; t * cfg.num_experts_per_tok]),
275            // 3 u32s per indirect args buffer; allocated as 3 i32s so we
276            // can reuse `from_slice_i32`. The kernel writes them as
277            // `device uint *` and the bit pattern is consumed by
278            // `dispatch_thread_groups_indirect`.
279            gate_up_args_buf: B::from_slice_i32(&[0i32, 0, 0]),
280            down_args_buf: B::from_slice_i32(&[0i32, 0, 0]),
281            ids_2d: B::from_slice_i32(&vec![0i32; n_exp * t * cfg.num_experts_per_tok]),
282            tpe_buf: B::from_slice_i32(&vec![0i32; n_exp]),
283            weights_2d: B::from_slice(&vec![0.0f32; t * cfg.num_experts_per_tok]),
284            last_hidden: B::alloc(h),
285            last_normed: B::alloc(h),
286            logits: B::alloc(vocab),
287            batch_logits: B::alloc(t * vocab),
288            // Lazily-allocated; `enable_batched_decode_scratch` populates
289            // these the first time decode_batch is called with M > 1.
290            q_single: None,
291            k_single: None,
292            v_single: None,
293            q_head_major_single: None,
294            k_head_major_single: None,
295            v_head_major_single: None,
296            attn_head_major_single: None,
297            // Lazily-allocated; `enable_paged_batch` populates these when
298            // FERRUM_METAL_PAGED_KV=1 + we know the pool dimensions.
299            paged_batch_q: None,
300            paged_batch_o: None,
301            paged_batch_block_tables: None,
302            paged_batch_context_lens: None,
303            paged_max_blocks_per_seq: 0,
304            max_tokens: t,
305        }
306    }
307
308    /// Allocate scratch for paged batched dispatch. Mirrors
309    /// `LlamaFamilyScratch::enable_paged_batch`. Idempotent.
310    fn enable_paged_batch(
311        &mut self,
312        cfg: &Qwen3MoeConfig,
313        max_seqs: usize,
314        max_blocks_per_seq: usize,
315    ) {
316        if self.paged_batch_q.is_some() {
317            return;
318        }
319        let q_dim = cfg.base.num_heads * cfg.base.head_dim;
320        self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
321        self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
322        self.paged_batch_block_tables = Some(B::alloc_u32(max_seqs * max_blocks_per_seq));
323        self.paged_batch_context_lens = Some(B::alloc_u32(max_seqs));
324        self.paged_max_blocks_per_seq = max_blocks_per_seq;
325    }
326
327    /// Allocate the per-item single-token scratch buffers used by
328    /// `forward_layer_batched_decode`. Idempotent.
329    fn enable_batched_decode_scratch(&mut self, cfg: &Qwen3MoeConfig) {
330        if self.q_single.is_some() {
331            return;
332        }
333        let q_dim = cfg.base.num_heads * cfg.base.head_dim;
334        let kv_dim = cfg.base.num_kv_heads * cfg.base.head_dim;
335        self.q_single = Some(B::alloc(q_dim));
336        self.k_single = Some(B::alloc(kv_dim));
337        self.v_single = Some(B::alloc(kv_dim));
338        self.q_head_major_single = Some(B::alloc(q_dim));
339        self.k_head_major_single = Some(B::alloc(kv_dim));
340        self.v_head_major_single = Some(B::alloc(kv_dim));
341        self.attn_head_major_single = Some(B::alloc(q_dim));
342    }
343}
344
345/// Qwen3-MoE decoder model.
346///
347/// Holds the same per-layer attention weights as [`LlamaFamilyModel`]
348/// plus a [`Qwen3MoeLayerState`] per layer for the MoE FFN. Routing,
349/// expert dispatch, and weighted combine all happen inside
350/// [`moe_forward`]; this struct only owns the storage and orchestrates
351/// the per-layer call sequence.
352pub struct Qwen3MoeModel<B: Backend> {
353    pub cfg: Qwen3MoeConfig,
354    pub runtime_cfg: LlmRuntimeConfig,
355
356    pub embed: B::Buffer,
357    /// Per-layer attention weights (re-uses dense `LlamaFamilyLayer`).
358    pub attn_layers: Vec<LlamaFamilyLayer<B>>,
359    /// Per-layer MoE state (router + expert stack).
360    pub moe_layers: Vec<Qwen3MoeLayerState<B>>,
361    pub final_norm_w: B::Buffer,
362    pub lm_head: Box<dyn ferrum_quantization::Linear<B>>,
363
364    pub rope: RopeCache<B>,
365    pub scratch: Qwen3MoeScratch<B>,
366
367    pub kv_caches: HashMap<String, Vec<KvCache<B>>>,
368    kv_free_pool: Vec<Vec<KvCache<B>>>,
369
370    // ── Paged-KV multi-seq state ────────────────────────────────────────
371    //
372    // Mirrors `LlamaFamilyModel`. Only populated when
373    // `FERRUM_METAL_PAGED_KV=1`. Kv_caches entries become metadata-only
374    // views (block_table + context_lens) into the shared `paged_pools`.
375    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
376    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
377}
378
379impl<B: Backend> Qwen3MoeModel<B> {
380    /// Build a Qwen3-MoE model from a generic `WeightLoader<B>` plus a
381    /// GGUF reader for the experts (which `WeightLoader` doesn't model
382    /// directly — its API is rank-2 only).
383    ///
384    /// `loader` provides: token embedding, attention projections, layer
385    /// norms, lm_head — all the rank-2 weights.
386    /// `gguf` provides: the rank-3 expert tensors, sliced per-expert
387    /// inside [`ExpertStack::load_from_gguf`].
388    pub fn new(
389        cfg: Qwen3MoeConfig,
390        loader: &dyn WeightLoader<B>,
391        gguf: &ferrum_quantization::gguf::GgufFile,
392    ) -> Result<Self> {
393        {
394            let mut ctx = B::new_context();
395            B::reset_graph(&mut ctx);
396        }
397        let rope = build_rope_cache::<B>(&cfg.base);
398        let scratch = Qwen3MoeScratch::alloc(&cfg, 1);
399
400        let embed = loader.load_tensor("model.embed_tokens.weight")?;
401
402        let mut attn_layers = Vec::with_capacity(cfg.base.num_layers);
403        let mut moe_layers = Vec::with_capacity(cfg.base.num_layers);
404        for li in 0..cfg.base.num_layers {
405            let prefix = format!("model.layers.{li}");
406            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
407            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
408            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
409            let post_ln_w =
410                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
411
412            // Dense gate_up_proj / down_proj are absent in MoE GGUFs —
413            // we synthesise stub Linears so the LlamaFamilyLayer struct
414            // type-checks. They're never invoked because forward_layer
415            // calls the MoE path. Cheap: tiny zero-sized DenseLinears.
416            let gate_up_proj: Box<dyn ferrum_quantization::Linear<B>> =
417                stub_linear::<B>(2 * cfg.expert_intermediate_size, cfg.base.hidden_size);
418            let down_proj: Box<dyn ferrum_quantization::Linear<B>> =
419                stub_linear::<B>(cfg.base.hidden_size, cfg.expert_intermediate_size);
420
421            let (q_norm_w, k_norm_w) = if cfg.base.has_qk_norm {
422                let q = loader
423                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
424                    .ok();
425                let k = loader
426                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
427                    .ok();
428                (q, k)
429            } else {
430                (None, None)
431            };
432
433            attn_layers.push(LlamaFamilyLayer {
434                input_ln_w,
435                qkv_proj,
436                q_norm_w,
437                k_norm_w,
438                o_proj,
439                post_ln_w,
440                gate_up_proj,
441                down_proj,
442            });
443
444            // Router lives at `model.layers.{li}.mlp.router.weight` in
445            // ferrum-name space (see ferrum_to_gguf mapping). It's a
446            // plain rank-2 linear so the standard loader path covers
447            // it without going through the MoE-specific GGUF helper.
448            let router = loader.load_linear(&format!("{prefix}.mlp.router"))?;
449            if router.in_features() != cfg.base.hidden_size {
450                return Err(FerrumError::model(format!(
451                    "router layer {li}: in_features {} != hidden {}",
452                    router.in_features(),
453                    cfg.base.hidden_size
454                )));
455            }
456            if router.out_features() != cfg.num_experts {
457                return Err(FerrumError::model(format!(
458                    "router layer {li}: out_features {} != num_experts {}",
459                    router.out_features(),
460                    cfg.num_experts
461                )));
462            }
463
464            let experts = ExpertStack::<B>::load_from_gguf(
465                gguf,
466                li,
467                cfg.num_experts,
468                cfg.base.hidden_size,
469                cfg.expert_intermediate_size,
470            )?;
471
472            moe_layers.push(Qwen3MoeLayerState { router, experts });
473        }
474
475        let final_norm_w = loader.load_tensor("model.norm.weight")?;
476        let lm_head = if loader.has_tensor("lm_head.weight") {
477            loader.load_linear("lm_head")?
478        } else {
479            // Tied embeddings — same as dense path.
480            tracing::info!(
481                "Qwen3MoeModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
482            );
483            loader.load_linear("model.embed_tokens")?
484        };
485
486        let runtime_cfg = cfg.base.to_runtime();
487        Ok(Self {
488            cfg,
489            runtime_cfg,
490            embed,
491            attn_layers,
492            moe_layers,
493            final_norm_w,
494            lm_head,
495            rope,
496            scratch,
497            kv_caches: HashMap::new(),
498            kv_free_pool: Vec::new(),
499            paged_pools: None,
500            paged_block_alloc: None,
501        })
502    }
503
504    pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
505        if self.scratch.max_tokens < tokens {
506            {
507                let mut ctx = B::new_context();
508                B::reset_graph(&mut ctx);
509            }
510            self.scratch = Qwen3MoeScratch::alloc(&self.cfg, tokens);
511        }
512    }
513
514    pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
515        if self.kv_caches.contains_key(cache_id) {
516            return;
517        }
518        let nkv = self.cfg.base.num_kv_heads;
519        let hd = self.cfg.base.head_dim;
520        // 512 in 0.7.2 — same value the published bench used to hit 79
521        // tok/s at c=16 on this exact MoE model. See
522        // `LlamaFamilyModel::ensure_kv` for the full rationale.
523        let model_max = self.cfg.base.max_seq_len;
524        const DEFAULT_KV_CAPACITY: usize = 512;
525        let max = std::env::var("FERRUM_KV_CAPACITY")
526            .ok()
527            .and_then(|s| s.parse::<usize>().ok())
528            .map(|cap| cap.min(model_max))
529            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY));
530
531        // Paged-KV mode: `FERRUM_METAL_PAGED_KV=1` switches caches into
532        // block-table-indirect layout. Mirrors LlamaFamilyModel's path so
533        // the existing `paged_decode_attention` Metal kernel can fire
534        // once at num_seqs=m for batched decode (replacing the per-item
535        // attention loop that currently dominates `attn_peritem` in the
536        // c=16 profile).
537        // Default ON when the backend supports paged-KV (Metal). Users
538        // can force off with `FERRUM_METAL_PAGED_KV=0`. The flag was
539        // opt-in pre-0.7.2; flipping the default so default `ferrum
540        // serve` matches the bench-quality numbers without requiring
541        // env-var knowledge.
542        let paged = std::env::var("FERRUM_METAL_PAGED_KV")
543            .map(|v| v != "0")
544            .unwrap_or_else(|_| B::supports_paged_kv());
545        const PAGED_BLOCK_SIZE: usize = 16;
546
547        // Default 32: covers c=16 burst with 2× headroom for the
548        // fresh-cache-id-per-request pattern that bench/server harnesses
549        // use. Pool memory unchanged from pre-0.7.2 default because
550        // DEFAULT_KV_CAPACITY dropped 4096 → 2048 in lockstep.
551        let max_seqs = std::env::var("FERRUM_PAGED_MAX_SEQS")
552            .ok()
553            .and_then(|s| s.parse::<usize>().ok())
554            .unwrap_or(32);
555        let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
556        let total_pool_blocks = max_seqs * max_blocks_per_seq;
557
558        // Lazy-allocate the shared paged pools on the first paged
559        // ensure_kv call.
560        if paged && self.paged_pools.is_none() {
561            let mut pools = Vec::with_capacity(self.cfg.base.num_layers);
562            for _ in 0..self.cfg.base.num_layers {
563                let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
564                pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
565            }
566            self.paged_pools = Some(pools);
567            self.paged_block_alloc = Some(std::sync::Mutex::new(
568                crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
569            ));
570        }
571        if paged {
572            self.scratch
573                .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
574        }
575
576        let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
577            (0..self.cfg.base.num_layers)
578                .map(|_| {
579                    if paged {
580                        // Paged mode: cache holds metadata only. K/V are
581                        // 1-element placeholders. Real data lives in
582                        // `self.paged_pools[li].{k,v}`.
583                        let mut block_table = B::alloc_u32(max_blocks_per_seq);
584                        let _ = &mut block_table; // suppress unused-mut on backends that no-op write_u32
585                        let mut context_lens = B::alloc_u32(1);
586                        let mut bt_ctx = B::new_context();
587                        B::write_u32(&mut bt_ctx, &mut context_lens, &[0u32]);
588                        B::sync(&mut bt_ctx);
589                        KvCache {
590                            k: B::alloc(1),
591                            v: B::alloc(1),
592                            len: 0,
593                            capacity: max_blocks_per_seq * PAGED_BLOCK_SIZE,
594                            num_kv_heads: nkv,
595                            head_dim: hd,
596                            block_size: PAGED_BLOCK_SIZE,
597                            block_table: Some(block_table),
598                            context_lens: Some(context_lens),
599                            paged_block_indices: Vec::new(),
600                        }
601                    } else {
602                        KvCache {
603                            k: B::alloc(nkv * max * hd),
604                            v: B::alloc(nkv * max * hd),
605                            len: 0,
606                            capacity: max,
607                            num_kv_heads: nkv,
608                            head_dim: hd,
609                            block_size: 0,
610                            block_table: None,
611                            context_lens: None,
612                            paged_block_indices: Vec::new(),
613                        }
614                    }
615                })
616                .collect()
617        });
618
619        // Allocate physical blocks for THIS cache_id from the shared pool.
620        if paged {
621            let alloc_arc = self
622                .paged_block_alloc
623                .as_ref()
624                .expect("paged_block_alloc must be initialised when paged=true");
625            let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
626            let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
627                Ok(idx) => idx,
628                Err(e) => {
629                    drop(alloc);
630                    self.kv_free_pool.push(caches);
631                    eprintln!(
632                        "[ferrum] paged KV pool exhausted on ensure_kv for \
633                         cache_id={cache_id:?}: {e}. Increase \
634                         FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
635                         throttle concurrent requests.",
636                    );
637                    return;
638                }
639            };
640            let mut padded = block_indices.clone();
641            padded.resize(max_blocks_per_seq, 0);
642            let mut ctx_tmp = B::new_context();
643            for c in caches.iter_mut() {
644                if let Some(bt) = c.block_table.as_mut() {
645                    B::write_u32(&mut ctx_tmp, bt, &padded);
646                }
647                c.paged_block_indices = block_indices.clone();
648            }
649            B::sync(&mut ctx_tmp);
650        }
651
652        for c in caches.iter_mut() {
653            c.len = 0;
654            if let Some(cl) = c.context_lens.as_mut() {
655                let mut ctx_tmp = B::new_context();
656                B::write_u32(&mut ctx_tmp, cl, &[0u32]);
657                B::sync(&mut ctx_tmp);
658            }
659        }
660        self.kv_caches.insert(cache_id.to_string(), caches);
661    }
662
663    /// Run one full transformer layer (attention + MoE FFN).
664    pub(crate) fn forward_layer(
665        &mut self,
666        ctx: &mut B::Context,
667        li: usize,
668        cache_id: &str,
669        residual: &mut B::Buffer,
670        pos_offset: usize,
671        tokens: usize,
672        // If `Some(idx)` and we land on the decode fast path, fold the
673        // next layer's leading rms_norm into this layer's MoE tail
674        // (cross-layer norm fusion). The next layer's caller must pass
675        // `prev_did_norm_fusion = true` so it skips its own rms_norm.
676        next_layer_idx: Option<usize>,
677        // If `true`, skip step 1's input rms_norm — the previous
678        // layer's tail already populated `scratch.norm_out`.
679        prev_did_norm_fusion: bool,
680    ) -> Result<bool> {
681        let cfg_base = &self.cfg.base;
682        let h = cfg_base.hidden_size;
683        let nh = cfg_base.num_heads;
684        let nkv = cfg_base.num_kv_heads;
685        let hd = cfg_base.head_dim;
686        let eps = cfg_base.rms_norm_eps;
687        let q_dim = nh * hd;
688        let kv_dim = nkv * hd;
689        let attn_layer = &self.attn_layers[li];
690        let moe_layer = &self.moe_layers[li];
691
692        let attn_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
693            B::sync(ctx);
694            Some(std::time::Instant::now())
695        } else {
696            None
697        };
698
699        // 1. Input RMSNorm — skipped when the previous layer's MoE tail
700        //    fused this norm via `weighted_sum_residual_norm_stacked`.
701        if !prev_did_norm_fusion {
702            B::rms_norm(
703                ctx,
704                residual,
705                &attn_layer.input_ln_w,
706                eps,
707                &mut self.scratch.norm_out,
708                tokens,
709                h,
710            );
711        }
712
713        // 2. Fused QKV
714        attn_layer.qkv_proj.forward(
715            ctx,
716            &self.scratch.norm_out,
717            &mut self.scratch.qkv_out,
718            tokens,
719        );
720
721        // 3-4. Fused split-QKV + QK-norm + RoPE + head-major transpose.
722        //
723        // One Metal dispatch replaces (split_qkv → 3× qk_norm_rope), the
724        // four-launch chain that used to dominate the attention prelude.
725        // Reads qkv_out once, writes head-major Q/K (norm+RoPE) and V
726        // (transpose only) directly into attention scratch. Saves 3
727        // dispatches per layer (×48 = 144 dispatches per decode token).
728        //
729        // CPU and other backends without the fused kernel return
730        // Unsupported and we fall through to the original four-launch
731        // path. q_buf / k_buf / v_buf stay in scratch because that path
732        // and the per-expert MoE fallback still want them.
733        let qk_mode: i32 = if cfg_base.has_qk_norm { 1 } else { 2 };
734        let dummy = &attn_layer.input_ln_w;
735        let q_norm_w = attn_layer.q_norm_w.as_ref().unwrap_or(dummy);
736        let k_norm_w = attn_layer.k_norm_w.as_ref().unwrap_or(dummy);
737
738        // 5. Grab the per-layer KV cache up front — the deepest fused
739        //    variant writes K/V straight into it, avoiding a trailing
740        //    `kv_cache_append_head_major` dispatch.
741        //
742        // Paged mode: extract a raw pointer to the layer's pool buffers
743        // BEFORE the &mut cache borrow, so we can pass &mut to the
744        // paged kernel below without holding two simultaneous mutable
745        // borrows on `self`. Safety: `paged_pools` is allocated once at
746        // first ensure_kv call and never resized; the only concurrent
747        // mutation is the pool's own kernel writes (sequenced via
748        // command buffers), so the raw pointer remains valid for the
749        // duration of this layer call.
750        let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
751            if let Some(pools) = self.paged_pools.as_mut() {
752                let pool = &mut pools[li];
753                Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
754            } else {
755                None
756            };
757        let caches = self
758            .kv_caches
759            .get_mut(cache_id)
760            .expect("ensure_kv must be called before forward_layer");
761        let cache = &mut caches[li];
762        let cache_len_before = cache.len;
763        let cache_capacity = cache.capacity;
764
765        // Defense in depth: refuse to write past the KV buffer. Silent
766        // overflow has visible failure modes (garbage output, stale token
767        // attention, slowdowns from reading uninitialised memory). The
768        // graceful path is the caller pre-checking via `kv_capacity()` and
769        // either compacting or refusing the request; this panic only
770        // fires when that contract is broken.
771        if cache_len_before + tokens > cache_capacity {
772            panic!(
773                "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
774                cache_len_before + tokens
775            );
776        }
777
778        // Try the deepest fusion: fused split-QKV-norm-rope that writes
779        // K/V directly into the cache slot. Paged mode writes into the
780        // shared pool via block_table indirection; contiguous mode
781        // writes into the per-cache_id k/v buffers directly.
782        let used_qkv_into_cache = if cache.block_size > 0 {
783            // Paged path.
784            let bt = cache
785                .block_table
786                .as_ref()
787                .expect("paged cache missing block_table");
788            let num_blocks_per_seq = cache.capacity / cache.block_size;
789            let (pool_k_ptr, pool_v_ptr) =
790                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
791            // SAFETY: pools allocated-once, see paged_pool_ptr setup above.
792            let pool_k = unsafe { &mut *pool_k_ptr };
793            let pool_v = unsafe { &mut *pool_v_ptr };
794            B::split_qkv_norm_rope_into_paged_cache(
795                ctx,
796                &self.scratch.qkv_out,
797                0,
798                q_norm_w,
799                k_norm_w,
800                &self.rope.cos,
801                &self.rope.sin,
802                &mut self.scratch.q_head_major,
803                0,
804                pool_k,
805                pool_v,
806                bt,
807                tokens,
808                nh,
809                nkv,
810                hd,
811                pos_offset,
812                eps,
813                qk_mode,
814                cache_len_before,
815                cache.block_size,
816                num_blocks_per_seq,
817            )
818            .is_ok()
819        } else {
820            B::split_qkv_norm_rope_into_cache(
821                ctx,
822                &self.scratch.qkv_out,
823                q_norm_w,
824                k_norm_w,
825                &self.rope.cos,
826                &self.rope.sin,
827                &mut self.scratch.q_head_major,
828                &mut cache.k,
829                &mut cache.v,
830                tokens,
831                nh,
832                nkv,
833                hd,
834                pos_offset,
835                eps,
836                qk_mode,
837                cache_len_before,
838                cache_capacity,
839            )
840            .is_ok()
841        };
842        if !used_qkv_into_cache {
843            // Fallback 1: fused split-QKV-norm-rope to head-major scratch
844            // (Metal pre-decode-fusion path), then explicit cache append.
845            let used_fused_qkv = B::split_qkv_norm_rope(
846                ctx,
847                &self.scratch.qkv_out,
848                q_norm_w,
849                k_norm_w,
850                &self.rope.cos,
851                &self.rope.sin,
852                &mut self.scratch.q_head_major,
853                &mut self.scratch.k_head_major,
854                &mut self.scratch.v_head_major,
855                tokens,
856                nh,
857                nkv,
858                hd,
859                pos_offset,
860                eps,
861                qk_mode,
862            )
863            .is_ok();
864            if !used_fused_qkv {
865                // Fallback 2: original four-launch chain.
866                B::split_qkv(
867                    ctx,
868                    &self.scratch.qkv_out,
869                    &mut self.scratch.q_buf,
870                    &mut self.scratch.k_buf,
871                    &mut self.scratch.v_buf,
872                    tokens,
873                    q_dim,
874                    kv_dim,
875                );
876                B::qk_norm_rope(
877                    ctx,
878                    &self.scratch.q_buf,
879                    q_norm_w,
880                    &self.rope.cos,
881                    &self.rope.sin,
882                    &mut self.scratch.q_head_major,
883                    tokens,
884                    nh,
885                    hd,
886                    pos_offset,
887                    eps,
888                    qk_mode,
889                );
890                B::qk_norm_rope(
891                    ctx,
892                    &self.scratch.k_buf,
893                    k_norm_w,
894                    &self.rope.cos,
895                    &self.rope.sin,
896                    &mut self.scratch.k_head_major,
897                    tokens,
898                    nkv,
899                    hd,
900                    pos_offset,
901                    eps,
902                    qk_mode,
903                );
904                B::qk_norm_rope(
905                    ctx,
906                    &self.scratch.v_buf,
907                    dummy,
908                    &self.rope.cos,
909                    &self.rope.sin,
910                    &mut self.scratch.v_head_major,
911                    tokens,
912                    nkv,
913                    hd,
914                    pos_offset,
915                    eps,
916                    0,
917                );
918            }
919            B::kv_cache_append_head_major(
920                ctx,
921                &mut cache.k,
922                &mut cache.v,
923                cache.len,
924                cache.capacity,
925                &self.scratch.k_head_major,
926                &self.scratch.v_head_major,
927                tokens,
928                nkv,
929                hd,
930            );
931        }
932        cache.len += tokens;
933        let kv_len = cache.len;
934        let kv_stride = cache.capacity;
935
936        if cache.block_size > 0 {
937            // Paged decode: read from the shared pool via block_table.
938            let bt = cache
939                .block_table
940                .as_ref()
941                .expect("paged cache missing block_table");
942            let cl_buf = cache
943                .context_lens
944                .as_mut()
945                .expect("paged cache missing context_lens");
946            let num_blocks_per_seq = cache.capacity / cache.block_size;
947            let (pool_k_ptr, pool_v_ptr) =
948                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
949            // SAFETY: see paged_pool_ptr setup above.
950            let pool_k = unsafe { &*pool_k_ptr };
951            let pool_v = unsafe { &*pool_v_ptr };
952            let final_kv_len = cache.len as u32;
953            B::write_u32(ctx, cl_buf, &[final_kv_len]);
954            B::paged_decode_attention(
955                ctx,
956                &self.scratch.q_head_major,
957                pool_k,
958                pool_v,
959                &mut self.scratch.attn_head_major_out,
960                bt,
961                cl_buf,
962                1, // num_seqs (single-seq m=1 path)
963                nh,
964                nkv,
965                hd,
966                cache.block_size,
967                num_blocks_per_seq,
968                tokens,
969            )
970            .expect("paged_decode_attention");
971            let _ = kv_stride; // consumed by contig path only
972        } else {
973            let attn_cfg = ferrum_kernels::backend::AttnConfig {
974                num_heads: nh,
975                num_kv_heads: nkv,
976                head_dim: hd,
977                causal: true,
978                scale: 1.0 / (hd as f32).sqrt(),
979                kv_seq_stride: kv_stride,
980                sliding_window: cfg_base.sliding_window,
981            };
982            B::flash_attention(
983                ctx,
984                &self.scratch.q_head_major,
985                &cache.k,
986                &cache.v,
987                &mut self.scratch.attn_head_major_out,
988                1,
989                tokens,
990                kv_len,
991                pos_offset,
992                &attn_cfg,
993            );
994        }
995
996        if let Some(t0) = attn_t0 {
997            B::sync(ctx);
998            ATTN_TIME_US.fetch_add(
999                t0.elapsed().as_micros() as u64,
1000                std::sync::atomic::Ordering::Relaxed,
1001            );
1002            ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1003        }
1004
1005        // 7. transpose head-major → token-major.
1006        //
1007        // For tokens=1 the two layouts are byte-identical: both
1008        // collapse to the flat [heads * head_dim] vector at offset
1009        // `head*hd + d`. Skip the dispatch and point o_proj at
1010        // attn_head_major_out directly. Saves 1 dispatch per layer
1011        // (×48 = 48 dispatches per decode token) on Qwen3-30B-A3B.
1012        let attn_token_major = if tokens == 1 {
1013            &self.scratch.attn_head_major_out
1014        } else {
1015            B::transpose_head_to_token(
1016                ctx,
1017                &self.scratch.attn_head_major_out,
1018                &mut self.scratch.attn_flat,
1019                tokens,
1020                nh,
1021                hd,
1022            );
1023            &self.scratch.attn_flat
1024        };
1025
1026        // 8. O-proj.
1027        attn_layer
1028            .o_proj
1029            .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1030
1031        // 9. fused residual-add + post-attention RMSNorm.
1032        B::fused_add_rms_norm(
1033            ctx,
1034            residual,
1035            &self.scratch.o_proj_out,
1036            &attn_layer.post_ln_w,
1037            eps,
1038            &mut self.scratch.norm_out,
1039            tokens,
1040            h,
1041        );
1042
1043        // ── MoE FFN block ────────────────────────────────────────────
1044        let moe_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1045            B::sync(ctx);
1046            Some(std::time::Instant::now())
1047        } else {
1048            None
1049        };
1050
1051        // 10. Router gemv: norm_out [tokens, hidden] → router_logits [tokens, num_experts]
1052        moe_layer.router.forward(
1053            ctx,
1054            &self.scratch.norm_out,
1055            &mut self.scratch.router_logits,
1056            tokens,
1057        );
1058
1059        // 11. Per-(token, expert) MLP dispatch + weighted combine.
1060        //
1061        // Two paths:
1062        //   - **Batched fast path** (decode m=1, all stacked variants
1063        //     present): single `gemv_quant_moe_id` dispatch covers all
1064        //     8 selected expert × 1 token gate gemvs in parallel; same
1065        //     for up and down. Cuts per-layer expert dispatches from
1066        //     ~32 (8 × 4 ops/pair) to 4 (gate + up + silu + down + 1 acc).
1067        //     Routes Qwen3-30B-A3B decode close to llama.cpp's
1068        //     `kernel_mul_mm_id`.
1069        //   - **Per-(token, expert) fallback** via `moe_forward` —
1070        //     used for prefill (m > 1), or when the backend doesn't
1071        //     populate stacked variants (CPU, synthetic-MoE tests).
1072        let stacked_path_available = moe_layer.experts.gate_stacked.is_some()
1073            && moe_layer.experts.up_stacked.is_some()
1074            && moe_layer.experts.down_stacked.is_some();
1075
1076        // Fast path for decode (tokens=1): the stacked decode impl
1077        // writes the weighted-sum result *directly* into `residual` via
1078        // `weighted_sum_residual_stacked`, skipping the moe_out scratch
1079        // and the trailing `add_inplace`. Saves 1 dispatch per layer.
1080        // Prefill (m>1) and the per-expert fallback still go through
1081        // moe_out + add_inplace.
1082        let decode_fast_path = stacked_path_available && tokens == 1;
1083        // Cross-layer fusion: when on the decode fast path AND there is
1084        // a next layer, fold its leading rms_norm into this layer's
1085        // tail (`weighted_sum_residual_norm_stacked`). Returns whether
1086        // the fusion ran so the caller can signal the next layer to
1087        // skip its standalone rms_norm.
1088        let did_norm_fusion = decode_fast_path && next_layer_idx.is_some();
1089
1090        if stacked_path_available {
1091            if tokens > 1 {
1092                // Prefill: one batched 2-D mul_mm_id covers all
1093                // (token, expert) pairs in parallel.
1094                self.moe_forward_batched_prefill(ctx, li, tokens)?;
1095            } else {
1096                // Decode m=1: dedicated per-token path that fuses
1097                // residual-add into the final weighted-sum, and
1098                // optionally folds the next layer's rms_norm in too.
1099                self.moe_forward_stacked(ctx, li, tokens, residual, next_layer_idx)?;
1100            }
1101        } else {
1102            moe_forward::<B>(
1103                ctx,
1104                &self.scratch.norm_out,
1105                &self.scratch.router_logits,
1106                &mut self.scratch.moe_out,
1107                tokens,
1108                h,
1109                self.cfg.expert_intermediate_size,
1110                self.cfg.num_experts,
1111                self.cfg.num_experts_per_tok,
1112                self.cfg.norm_topk_prob,
1113                &moe_layer.experts,
1114                &mut self.scratch.x_single,
1115                &mut self.scratch.acc_buf,
1116                &mut self.scratch.gate_up_buf,
1117                &mut self.scratch.silu_buf,
1118                &mut self.scratch.down_buf,
1119                &self.scratch.zero_hidden,
1120            )?;
1121        }
1122
1123        // 12. residual += moe_out (skipped on decode fast path — already
1124        //     accumulated by `weighted_sum_residual_stacked`).
1125        if !decode_fast_path {
1126            B::add_inplace(ctx, residual, &self.scratch.moe_out, tokens * h);
1127        }
1128
1129        if let Some(t0) = moe_t0 {
1130            B::sync(ctx);
1131            MOE_TIME_US.fetch_add(
1132                t0.elapsed().as_micros() as u64,
1133                std::sync::atomic::Ordering::Relaxed,
1134            );
1135            MOE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1136        }
1137
1138        Ok(did_norm_fusion)
1139    }
1140
1141    fn moe_forward_stacked(
1142        &mut self,
1143        ctx: &mut B::Context,
1144        li: usize,
1145        tokens: usize,
1146        residual: &mut B::Buffer,
1147        next_layer_idx: Option<usize>,
1148    ) -> Result<()> {
1149        let cfg = &self.cfg;
1150        // `next_norm_w` is the next layer's `attn_layer.input_ln_w`.
1151        // We can't borrow `self.attn_layers[idx]` and pass &mut
1152        // self.scratch to the impl simultaneously, so collect the raw
1153        // pointer here. Safety: forward_layer holds &mut self for the
1154        // call; the borrow scopes are fully sequential.
1155        let next_norm_w_ptr: Option<*const B::Buffer> =
1156            next_layer_idx.map(|idx| &self.attn_layers[idx].input_ln_w as *const _);
1157        // SAFETY: pointer dereference is valid because:
1158        //   * The buffer lives in `self.attn_layers[idx]` which we
1159        //     borrowed immutably to take the pointer. We do not mutate
1160        //     `self.attn_layers` while `next_norm_w_ptr` is in use.
1161        //   * `&mut self.scratch` and `&self.moe_layers[li]` are disjoint
1162        //     fields from `self.attn_layers` so this is safe.
1163        let next_norm_w: Option<&B::Buffer> = next_norm_w_ptr.map(|p| unsafe { &*p });
1164        moe_forward_stacked_decode_impl::<B>(
1165            ctx,
1166            &self.moe_layers[li],
1167            &mut self.scratch,
1168            cfg.base.hidden_size,
1169            cfg.expert_intermediate_size,
1170            cfg.num_experts_per_tok,
1171            cfg.num_experts,
1172            cfg.norm_topk_prob,
1173            tokens,
1174            residual,
1175            next_norm_w,
1176            cfg.base.rms_norm_eps,
1177        )
1178    }
1179
1180    fn moe_forward_batched_prefill(
1181        &mut self,
1182        ctx: &mut B::Context,
1183        li: usize,
1184        tokens: usize,
1185    ) -> Result<()> {
1186        let cfg = &self.cfg;
1187        moe_forward_batched_prefill_impl::<B>(
1188            ctx,
1189            &self.moe_layers[li],
1190            &mut self.scratch,
1191            cfg.base.hidden_size,
1192            cfg.expert_intermediate_size,
1193            cfg.num_experts_per_tok,
1194            cfg.num_experts,
1195            cfg.norm_topk_prob,
1196            tokens,
1197        )
1198    }
1199
1200    /// Prefill: process `tokens` prompt tokens, return last-token logits.
1201    pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1202        let seq_len = tokens.len();
1203        assert!(seq_len > 0);
1204        self.ensure_scratch(seq_len);
1205        self.ensure_kv(cache_id);
1206
1207        let pos_offset = self
1208            .kv_caches
1209            .get(cache_id)
1210            .and_then(|layers| layers.first())
1211            .map(|c| c.len)
1212            .unwrap_or(0);
1213
1214        let h = self.cfg.base.hidden_size;
1215        let vocab = self.cfg.base.vocab_size;
1216        let mut ctx = B::new_context();
1217
1218        // FERRUM_DECODE_OP_PROFILE doubles as the prefill-profile gate
1219        // for Qwen3-MoE: when set, dump (attn-us, moe-us, total-us) at
1220        // the end of prefill so we can attribute the prefill bottleneck
1221        // between attention and MoE.
1222        let prefill_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1223            B::sync(&mut ctx);
1224            for c in [
1225                &ATTN_TIME_US,
1226                &ATTN_CALLS,
1227                &MOE_TIME_US,
1228                &MOE_CALLS,
1229                &MOE_PREFILL_HOST_TOPK_US,
1230                &MOE_PREFILL_HOST_TOPK_CALLS,
1231                &MOE_PREFILL_GATE_US,
1232                &MOE_PREFILL_GATE_CALLS,
1233                &MOE_PREFILL_UP_US,
1234                &MOE_PREFILL_UP_CALLS,
1235                &MOE_PREFILL_SILU_US,
1236                &MOE_PREFILL_SILU_CALLS,
1237                &MOE_PREFILL_DOWN_US,
1238                &MOE_PREFILL_DOWN_CALLS,
1239                &MOE_PREFILL_WSUM_US,
1240                &MOE_PREFILL_WSUM_CALLS,
1241            ] {
1242                c.store(0, std::sync::atomic::Ordering::Relaxed);
1243            }
1244            Some(std::time::Instant::now())
1245        } else {
1246            None
1247        };
1248
1249        let mut residual = self
1250            .scratch
1251            .residual
1252            .take()
1253            .expect("scratch residual missing (previous call didn't restore)");
1254        B::embedding_lookup(&mut ctx, &self.embed, tokens, &mut residual, h);
1255
1256        // For prefill (seq_len > 1) the cross-layer norm fusion does
1257        // not apply (it lives on the decode fast path). We still pass
1258        // `next_layer_idx = None` so forward_layer emits the regular
1259        // tail.
1260        let mut prev_did_norm_fusion = false;
1261        let num_layers = self.cfg.base.num_layers;
1262        for li in 0..num_layers {
1263            let next_layer_idx = if li + 1 < num_layers {
1264                Some(li + 1)
1265            } else {
1266                None
1267            };
1268            prev_did_norm_fusion = self
1269                .forward_layer(
1270                    &mut ctx,
1271                    li,
1272                    cache_id,
1273                    &mut residual,
1274                    pos_offset,
1275                    seq_len,
1276                    next_layer_idx,
1277                    prev_did_norm_fusion,
1278                )
1279                .expect("forward_layer");
1280        }
1281
1282        // Last-token slice → final RMSNorm → lm_head.
1283        B::copy_slice(
1284            &mut ctx,
1285            &residual,
1286            (seq_len - 1) * h,
1287            &mut self.scratch.last_hidden,
1288            0,
1289            h,
1290        );
1291        B::rms_norm(
1292            &mut ctx,
1293            &self.scratch.last_hidden,
1294            &self.final_norm_w,
1295            self.cfg.base.rms_norm_eps,
1296            &mut self.scratch.last_normed,
1297            1,
1298            h,
1299        );
1300        self.lm_head.forward(
1301            &mut ctx,
1302            &self.scratch.last_normed,
1303            &mut self.scratch.logits,
1304            1,
1305        );
1306
1307        B::sync(&mut ctx);
1308        if let Some(t0) = prefill_t0 {
1309            let total_us = t0.elapsed().as_micros() as u64;
1310            let attn_us = ATTN_TIME_US.load(std::sync::atomic::Ordering::Relaxed);
1311            let attn_n = ATTN_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1312            let moe_us = MOE_TIME_US.load(std::sync::atomic::Ordering::Relaxed);
1313            let moe_n = MOE_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1314            let other_us = total_us.saturating_sub(attn_us).saturating_sub(moe_us);
1315            eprintln!(
1316                "[prefill-profile] tokens={seq_len} total={} ms ({:.0} t/s)",
1317                total_us / 1000,
1318                seq_len as f64 * 1e6 / total_us as f64
1319            );
1320            let bucket = |label: &str, n: u64, us: u64| {
1321                if n > 0 {
1322                    eprintln!(
1323                        "  {label:>6}: {:7} ms ({:5.1}%) over {n:4} calls",
1324                        us / 1000,
1325                        us as f64 * 100.0 / total_us as f64
1326                    );
1327                }
1328            };
1329            bucket("attn", attn_n, attn_us);
1330            bucket("moe", moe_n, moe_us);
1331            bucket("other", 1, other_us);
1332            // MoE sub-stages — show as % of total prefill time so they
1333            // reconcile against the `moe` bucket above.
1334            let host_us = MOE_PREFILL_HOST_TOPK_US.load(std::sync::atomic::Ordering::Relaxed);
1335            let gate_us = MOE_PREFILL_GATE_US.load(std::sync::atomic::Ordering::Relaxed);
1336            let up_us = MOE_PREFILL_UP_US.load(std::sync::atomic::Ordering::Relaxed);
1337            let silu_us = MOE_PREFILL_SILU_US.load(std::sync::atomic::Ordering::Relaxed);
1338            let down_us = MOE_PREFILL_DOWN_US.load(std::sync::atomic::Ordering::Relaxed);
1339            let wsum_us = MOE_PREFILL_WSUM_US.load(std::sync::atomic::Ordering::Relaxed);
1340            let host_n = MOE_PREFILL_HOST_TOPK_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1341            let gate_n = MOE_PREFILL_GATE_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1342            let up_n = MOE_PREFILL_UP_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1343            let silu_n = MOE_PREFILL_SILU_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1344            let down_n = MOE_PREFILL_DOWN_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1345            let wsum_n = MOE_PREFILL_WSUM_CALLS.load(std::sync::atomic::Ordering::Relaxed);
1346            bucket("  host", host_n, host_us);
1347            bucket("  gate", gate_n, gate_us);
1348            bucket("  up", up_n, up_us);
1349            bucket("  silu", silu_n, silu_us);
1350            bucket("  down", down_n, down_us);
1351            bucket("  wsum", wsum_n, wsum_us);
1352        }
1353        self.scratch.residual = Some(residual);
1354        B::to_vec(&self.scratch.logits, vocab)
1355    }
1356
1357    /// Decode: 1 token at position `pos`, return next-step logits.
1358    pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1359        self.ensure_scratch(1);
1360        self.ensure_kv(cache_id);
1361
1362        let h = self.cfg.base.hidden_size;
1363        let vocab = self.cfg.base.vocab_size;
1364        let mut ctx = B::new_context();
1365
1366        let decode_t0 = if std::env::var("FERRUM_MOE_PROFILE").is_ok() {
1367            Some(std::time::Instant::now())
1368        } else {
1369            None
1370        };
1371
1372        // FERRUM_DECODE_OP_PROFILE gates the per-stage breakdown emitted
1373        // at the bottom of every decode token. Reuses the same atomic
1374        // counters that `forward_layer` already populates (ATTN_TIME_US,
1375        // MOE_TIME_US — drained here per-token instead of per-prefill).
1376        let stage_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1377            B::sync(&mut ctx);
1378            for c in [
1379                &ATTN_TIME_US,
1380                &ATTN_CALLS,
1381                &MOE_TIME_US,
1382                &MOE_CALLS,
1383                &DEC_ROUTE_US,
1384                &DEC_GATE_US,
1385                &DEC_UP_US,
1386                &DEC_SILU_US,
1387                &DEC_DOWN_US,
1388                &DEC_WSUM_US,
1389                &DEC_EMBED_US,
1390                &DEC_FINAL_NORM_US,
1391                &DEC_LM_HEAD_US,
1392            ] {
1393                c.store(0, std::sync::atomic::Ordering::Relaxed);
1394            }
1395            Some(std::time::Instant::now())
1396        } else {
1397            None
1398        };
1399        let prof = stage_t0.is_some();
1400        let mark = |ctx: &mut B::Context, c: &AtomicU64, t0: std::time::Instant| {
1401            if prof {
1402                B::sync(ctx);
1403                c.fetch_add(
1404                    t0.elapsed().as_micros() as u64,
1405                    std::sync::atomic::Ordering::Relaxed,
1406                );
1407            }
1408        };
1409        let mt0 = std::time::Instant::now();
1410
1411        let mut residual = self
1412            .scratch
1413            .residual
1414            .take()
1415            .expect("scratch residual missing (previous call didn't restore)");
1416        let t0 = std::time::Instant::now();
1417        B::embedding_lookup(&mut ctx, &self.embed, &[token], &mut residual, h);
1418        mark(&mut ctx, &DEC_EMBED_US, t0);
1419        let _ = mt0; // silence if unused on non-profile builds
1420
1421        // Cross-layer rms_norm fusion: layer L's MoE tail folds the
1422        // next layer's leading rms_norm into its weighted-sum-residual
1423        // when the decode fast path applies. The flag carries forward.
1424        let mut prev_did_norm_fusion = false;
1425        let num_layers = self.cfg.base.num_layers;
1426        for li in 0..num_layers {
1427            let next_layer_idx = if li + 1 < num_layers {
1428                Some(li + 1)
1429            } else {
1430                None
1431            };
1432            prev_did_norm_fusion = self
1433                .forward_layer(
1434                    &mut ctx,
1435                    li,
1436                    cache_id,
1437                    &mut residual,
1438                    pos as usize,
1439                    1,
1440                    next_layer_idx,
1441                    prev_did_norm_fusion,
1442                )
1443                .expect("forward_layer");
1444        }
1445
1446        let t0 = std::time::Instant::now();
1447        B::rms_norm(
1448            &mut ctx,
1449            &residual,
1450            &self.final_norm_w,
1451            self.cfg.base.rms_norm_eps,
1452            &mut self.scratch.last_normed,
1453            1,
1454            h,
1455        );
1456        mark(&mut ctx, &DEC_FINAL_NORM_US, t0);
1457
1458        let t0 = std::time::Instant::now();
1459        self.lm_head.forward(
1460            &mut ctx,
1461            &self.scratch.last_normed,
1462            &mut self.scratch.logits,
1463            1,
1464        );
1465        mark(&mut ctx, &DEC_LM_HEAD_US, t0);
1466
1467        B::sync(&mut ctx);
1468        self.scratch.residual = Some(residual);
1469
1470        // FERRUM_DECODE_OP_PROFILE: per-token decode breakdown.
1471        if let Some(t0) = stage_t0 {
1472            use std::sync::atomic::Ordering;
1473            let total_us = t0.elapsed().as_micros() as u64;
1474            let attn_us = ATTN_TIME_US.swap(0, Ordering::Relaxed);
1475            let moe_us = MOE_TIME_US.swap(0, Ordering::Relaxed);
1476            let route = DEC_ROUTE_US.swap(0, Ordering::Relaxed);
1477            let gate = DEC_GATE_US.swap(0, Ordering::Relaxed);
1478            let up = DEC_UP_US.swap(0, Ordering::Relaxed);
1479            let silu = DEC_SILU_US.swap(0, Ordering::Relaxed);
1480            let down = DEC_DOWN_US.swap(0, Ordering::Relaxed);
1481            let wsum = DEC_WSUM_US.swap(0, Ordering::Relaxed);
1482            let embed = DEC_EMBED_US.swap(0, Ordering::Relaxed);
1483            let fnorm = DEC_FINAL_NORM_US.swap(0, Ordering::Relaxed);
1484            let lmhead = DEC_LM_HEAD_US.swap(0, Ordering::Relaxed);
1485            let other = total_us.saturating_sub(attn_us + moe_us + embed + fnorm + lmhead);
1486            let pct = |us: u64| -> f64 {
1487                if total_us == 0 {
1488                    0.0
1489                } else {
1490                    100.0 * us as f64 / total_us as f64
1491                }
1492            };
1493            eprintln!(
1494                "[decode-prof] total={} ms | attn={} ({:.1}%) | moe={} ({:.1}%) [route={} gate={} up={} silu={} down={} wsum={}] | embed={} fnorm={} lmhead={} other={} ({:.1}%)",
1495                total_us / 1000,
1496                attn_us / 1000, pct(attn_us),
1497                moe_us / 1000, pct(moe_us),
1498                route / 1000, gate / 1000, up / 1000, silu / 1000, down / 1000, wsum / 1000,
1499                embed / 1000, fnorm / 1000, lmhead / 1000,
1500                other / 1000, pct(other),
1501            );
1502        }
1503
1504        // Drain MoE per-op counters every decode step. The counters
1505        // accumulate across all 48 layers; printing per-step gives a
1506        // per-token breakdown.
1507        if let Some(t0) = decode_t0 {
1508            use crate::moe::dispatch::*;
1509            use std::sync::atomic::Ordering;
1510            let total_us = t0.elapsed().as_micros() as u64;
1511            let sync_us = MOE_SYNC_US.swap(0, Ordering::Relaxed);
1512            let sync_n = MOE_SYNC_CALLS.swap(0, Ordering::Relaxed);
1513            let topk_us = MOE_HOST_TOPK_US.swap(0, Ordering::Relaxed);
1514            let topk_n = MOE_HOST_TOPK_CALLS.swap(0, Ordering::Relaxed);
1515            let gu_us = MOE_GEMV_GATE_UP_US.swap(0, Ordering::Relaxed);
1516            let gu_n = MOE_GEMV_GATE_UP_CALLS.swap(0, Ordering::Relaxed);
1517            let silu_us = MOE_SILU_US.swap(0, Ordering::Relaxed);
1518            let silu_n = MOE_SILU_CALLS.swap(0, Ordering::Relaxed);
1519            let dn_us = MOE_GEMV_DOWN_US.swap(0, Ordering::Relaxed);
1520            let dn_n = MOE_GEMV_DOWN_CALLS.swap(0, Ordering::Relaxed);
1521            let sa_us = MOE_SCALED_ADD_US.swap(0, Ordering::Relaxed);
1522            let sa_n = MOE_SCALED_ADD_CALLS.swap(0, Ordering::Relaxed);
1523            let cp_us = MOE_COPY_US.swap(0, Ordering::Relaxed);
1524            let cp_n = MOE_COPY_CALLS.swap(0, Ordering::Relaxed);
1525            eprintln!(
1526                "[moe-prof] decode total={} ms | sync={} ms ({}x) | host_topk={} ms ({}x) | gate_up={} ms ({}x) | silu={} ms ({}x) | down={} ms ({}x) | scaled_add={} ms ({}x) | copy={} ms ({}x)",
1527                total_us / 1000,
1528                sync_us / 1000, sync_n,
1529                topk_us / 1000, topk_n,
1530                gu_us / 1000, gu_n,
1531                silu_us / 1000, silu_n,
1532                dn_us / 1000, dn_n,
1533                sa_us / 1000, sa_n,
1534                cp_us / 1000, cp_n,
1535            );
1536        }
1537
1538        B::to_vec(&self.scratch.logits, vocab)
1539    }
1540
1541    /// Multi-sequence batched decode (Phase 4b for MoE).
1542    ///
1543    /// Mirrors `LlamaFamilyModel::decode_batch_internal` but adapted to
1544    /// the MoE forward. The wins come from running the GEMM-heavy ops
1545    /// (qkv_proj, o_proj, router, MoE expert mul_mm_id, lm_head) at
1546    /// m=M, even though attention stays a per-item loop because
1547    /// Qwen3-MoE uses contiguous KV — no paged path here.
1548    ///
1549    /// Cross-layer rms_norm fusion (the `weighted_sum_residual_norm_stacked`
1550    /// fast path) is disabled in batched mode: the prefill MoE path
1551    /// (`moe_forward_batched_prefill_impl`) writes to `moe_out` and we
1552    /// add to residual explicitly. Each layer's leading rms_norm runs
1553    /// at m=M, which is one fused dispatch on M rows — cheap.
1554    pub fn decode_batch_internal(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1555        let m = batch.len();
1556        if m == 0 {
1557            return Vec::new();
1558        }
1559        if m == 1 {
1560            let (cid, tok, pos) = &batch[0];
1561            return vec![self.decode_internal(cid, *tok, *pos)];
1562        }
1563
1564        let prof_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1565            Some(std::time::Instant::now())
1566        } else {
1567            None
1568        };
1569
1570        for (cid, _, _) in batch {
1571            self.ensure_kv(cid);
1572        }
1573        self.ensure_scratch(m);
1574        self.scratch.enable_batched_decode_scratch(&self.cfg);
1575
1576        let h = self.cfg.base.hidden_size;
1577        let vocab = self.cfg.base.vocab_size;
1578        let mut ctx = B::new_context();
1579
1580        // 0. Embed all M tokens into residual [M, H]
1581        let tokens: Vec<u32> = batch.iter().map(|(_, t, _)| *t).collect();
1582        let mut residual = self
1583            .scratch
1584            .residual
1585            .take()
1586            .expect("scratch residual missing (previous call didn't restore)");
1587        B::embedding_lookup(&mut ctx, &self.embed, &tokens, &mut residual, h);
1588
1589        // 1..num_layers: batched forward for each layer
1590        for li in 0..self.cfg.base.num_layers {
1591            self.forward_layer_batched_decode(&mut ctx, li, batch, &mut residual, m)
1592                .expect("forward_layer_batched_decode");
1593        }
1594
1595        // Final RMSNorm on [M, H] → norm_out [M, H]
1596        B::rms_norm(
1597            &mut ctx,
1598            &residual,
1599            &self.final_norm_w,
1600            self.cfg.base.rms_norm_eps,
1601            &mut self.scratch.norm_out,
1602            m,
1603            h,
1604        );
1605
1606        // LM head with m=M → batch_logits [M, vocab]
1607        self.lm_head.forward(
1608            &mut ctx,
1609            &self.scratch.norm_out,
1610            &mut self.scratch.batch_logits,
1611            m,
1612        );
1613
1614        B::sync(&mut ctx);
1615        self.scratch.residual = Some(residual);
1616
1617        let all = B::to_vec(&self.scratch.batch_logits, m * vocab);
1618
1619        // Profile dump (one decode_batch_internal call = one decode step
1620        // covering all m tokens).
1621        if let Some(t0) = prof_t0 {
1622            use std::sync::atomic::Ordering;
1623            let total_us = t0.elapsed().as_micros() as u64;
1624            let dense = BD_DENSE_US.swap(0, Ordering::Relaxed);
1625            let attn = BD_ATTN_PERITEM_US.swap(0, Ordering::Relaxed);
1626            let moe = BD_MOE_US.swap(0, Ordering::Relaxed);
1627            let layers = BD_LAYER_CALLS.swap(0, Ordering::Relaxed);
1628            let other = total_us.saturating_sub(dense + attn + moe);
1629            let pct = |us: u64| -> f64 {
1630                if total_us == 0 {
1631                    0.0
1632                } else {
1633                    100.0 * us as f64 / total_us as f64
1634                }
1635            };
1636            // MoE sub-stage breakdown — meaningful when
1637            // moe_forward_batched_decode_impl was used.
1638            let moe_route = MOE_BATCHED_DECODE_ROUTE_US.swap(0, Ordering::Relaxed);
1639            let moe_gate = MOE_BATCHED_DECODE_GATE_US.swap(0, Ordering::Relaxed);
1640            let moe_up = MOE_BATCHED_DECODE_UP_US.swap(0, Ordering::Relaxed);
1641            let moe_silu = MOE_BATCHED_DECODE_SILU_US.swap(0, Ordering::Relaxed);
1642            let moe_down = MOE_BATCHED_DECODE_DOWN_US.swap(0, Ordering::Relaxed);
1643            let moe_wsum = MOE_BATCHED_DECODE_WSUM_US.swap(0, Ordering::Relaxed);
1644            eprintln!(
1645                "[batched-decode-prof] m={} layers={} total={} ms | dense={} ({:.1}%) | attn_peritem={} ({:.1}%) | moe={} ({:.1}%) [route={} gate={} up={} silu={} down={} wsum={}] | other={} ({:.1}%)",
1646                m, layers, total_us / 1000,
1647                dense / 1000, pct(dense),
1648                attn / 1000, pct(attn),
1649                moe / 1000, pct(moe),
1650                moe_route / 1000, moe_gate / 1000, moe_up / 1000,
1651                moe_silu / 1000, moe_down / 1000, moe_wsum / 1000,
1652                other / 1000, pct(other),
1653            );
1654        }
1655
1656        (0..m)
1657            .map(|i| all[i * vocab..(i + 1) * vocab].to_vec())
1658            .collect()
1659    }
1660
1661    /// One transformer layer over M items: GEMMs at m=M, per-item
1662    /// attention loop, MoE FFN at m=M via the prefill batched path.
1663    /// Mirrors `LlamaFamilyModel::forward_layer_batched_decode` minus
1664    /// the paged branch.
1665    fn forward_layer_batched_decode(
1666        &mut self,
1667        ctx: &mut B::Context,
1668        li: usize,
1669        batch: &[(String, u32, u32)],
1670        residual: &mut B::Buffer,
1671        m: usize,
1672    ) -> Result<()> {
1673        let cfg_base = &self.cfg.base;
1674        let h = cfg_base.hidden_size;
1675        let nh = cfg_base.num_heads;
1676        let nkv = cfg_base.num_kv_heads;
1677        let hd = cfg_base.head_dim;
1678        let eps = cfg_base.rms_norm_eps;
1679        let q_dim = nh * hd;
1680        let kv_dim = nkv * hd;
1681
1682        let attn_layer = &self.attn_layers[li];
1683        let qk_mode: i32 = if cfg_base.has_qk_norm { 1 } else { 2 };
1684        let dummy_w = &attn_layer.input_ln_w;
1685        let q_norm_w = attn_layer.q_norm_w.as_ref().unwrap_or(dummy_w);
1686        let k_norm_w = attn_layer.k_norm_w.as_ref().unwrap_or(dummy_w);
1687
1688        let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
1689        let stage_t0 = || -> Option<std::time::Instant> {
1690            if prof {
1691                Some(std::time::Instant::now())
1692            } else {
1693                None
1694            }
1695        };
1696        let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
1697            if let Some(t) = t0 {
1698                B::sync(ctx);
1699                c.fetch_add(
1700                    t.elapsed().as_micros() as u64,
1701                    std::sync::atomic::Ordering::Relaxed,
1702                );
1703            }
1704        };
1705        if prof {
1706            BD_LAYER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1707        }
1708
1709        let dense_t0 = stage_t0();
1710
1711        // 1. rms_norm [M, H] → norm_out
1712        B::rms_norm(
1713            ctx,
1714            residual,
1715            &attn_layer.input_ln_w,
1716            eps,
1717            &mut self.scratch.norm_out,
1718            m,
1719            h,
1720        );
1721
1722        // 2. qkv_proj GEMM at m=M: norm_out [M, H] → qkv_out [M, QKV]
1723        attn_layer
1724            .qkv_proj
1725            .forward(ctx, &self.scratch.norm_out, &mut self.scratch.qkv_out, m);
1726
1727        // ── Paged batched attention path ───────────────────────────────
1728        //
1729        // Mirrors LlamaFamilyModel's Phase 4b paged batched-decode. When
1730        // `FERRUM_METAL_PAGED_KV=1` was set at ensure_kv time, each
1731        // cache_id has paged metadata (block_table + context_lens) and
1732        // K/V live in the shared `paged_pools[layer]` pool. This path:
1733        //   1. m × `split_qkv_norm_rope_into_paged_cache` writes K/V into
1734        //      the pool at each item's allocated blocks AND fills
1735        //      `paged_batch_q[i*q_dim ..]` with that item's head-major Q.
1736        //   2. Build `paged_batch_block_tables [m, max_blocks_per_seq]`
1737        //      and `paged_batch_context_lens [m]` host-side, upload.
1738        //   3. ONE `paged_decode_attention(num_seqs=m)` call reads all m
1739        //      sequences' K/V from the pool via per-seq block_tables,
1740        //      writes outputs to `paged_batch_o [m, q_dim]`.
1741        //   4. Per-item copy_slice paged_batch_o[i] → attn_flat[i*q_dim].
1742        //
1743        // This is the structural fix for the c=16 attn_peritem cliff
1744        // (~55 ms / round of 16 sequential m=1 flash_attn + plumbing).
1745        let is_paged = self.paged_pools.is_some();
1746        if is_paged {
1747            stage_end(dense_t0, ctx, &BD_DENSE_US);
1748            let attn_t0 = stage_t0();
1749
1750            let max_blocks_per_seq = self.scratch.paged_max_blocks_per_seq;
1751            let block_size = 16; // matches PAGED_BLOCK_SIZE in ensure_kv
1752            let qkv_stride = q_dim + 2 * kv_dim;
1753
1754            // Step 1: per-item paged write. Read each item's qkv out of
1755            // the batched qkv_out buffer; write its head-major Q into
1756            // paged_batch_q[i*q_dim ..]; write its K/V into the pool at
1757            // its allocated blocks via block_table.
1758            let q_head_major_size_bytes = (q_dim * std::mem::size_of::<f32>()) as u64;
1759            let qkv_stride_bytes = (qkv_stride * std::mem::size_of::<f32>()) as u64;
1760            let pool_ptr = {
1761                let pools = self.paged_pools.as_mut().unwrap();
1762                (
1763                    &mut pools[li].0 as *mut B::Buffer,
1764                    &mut pools[li].1 as *mut B::Buffer,
1765                )
1766            };
1767            // SAFETY: pools allocated-once, see paged_pools field comment.
1768            let (pool_k, pool_v) = unsafe { (&mut *pool_ptr.0, &mut *pool_ptr.1) };
1769            for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
1770                let pos_i = *pos as usize;
1771                let caches = self
1772                    .kv_caches
1773                    .get(cache_id)
1774                    .expect("paged batched: cache not present");
1775                let cache = &caches[li];
1776                let bt = cache
1777                    .block_table
1778                    .as_ref()
1779                    .expect("paged batched: block_table missing");
1780                let cache_len_before = cache.len;
1781                let bt_ptr = bt as *const B::Buffer;
1782                // SAFETY: bt is read-only during the dispatch; we don't
1783                // touch self.kv_caches between this raw deref and the
1784                // call below.
1785                let bt_safe: &B::Buffer = unsafe { &*bt_ptr };
1786                B::split_qkv_norm_rope_into_paged_cache(
1787                    ctx,
1788                    &self.scratch.qkv_out,
1789                    (i as u64) * qkv_stride_bytes,
1790                    q_norm_w,
1791                    k_norm_w,
1792                    &self.rope.cos,
1793                    &self.rope.sin,
1794                    self.scratch
1795                        .paged_batch_q
1796                        .as_mut()
1797                        .expect("paged_batch_q missing"),
1798                    (i as u64) * q_head_major_size_bytes,
1799                    pool_k,
1800                    pool_v,
1801                    bt_safe,
1802                    1,
1803                    nh,
1804                    nkv,
1805                    hd,
1806                    pos_i,
1807                    eps,
1808                    qk_mode,
1809                    cache_len_before,
1810                    block_size,
1811                    max_blocks_per_seq,
1812                )
1813                .expect("split_qkv_norm_rope_into_paged_cache (batched)");
1814            }
1815
1816            // Step 2: bump cache.len and stack block_tables + context_lens
1817            // host-side, then upload to device scratch.
1818            let mut stacked_bt: Vec<u32> = vec![0u32; m * max_blocks_per_seq];
1819            let mut stacked_cl: Vec<u32> = vec![0u32; m];
1820            for (i, (cache_id, _, _)) in batch.iter().enumerate() {
1821                let caches = self
1822                    .kv_caches
1823                    .get_mut(cache_id)
1824                    .expect("paged batched: cache not present");
1825                let cache = &mut caches[li];
1826                cache.len += 1;
1827                let len = cache.len as u32;
1828                stacked_cl[i] = len;
1829                let blocks = &cache.paged_block_indices;
1830                let n_to_copy = blocks.len().min(max_blocks_per_seq);
1831                stacked_bt[i * max_blocks_per_seq..i * max_blocks_per_seq + n_to_copy]
1832                    .copy_from_slice(&blocks[..n_to_copy]);
1833            }
1834            let bt_buf = self
1835                .scratch
1836                .paged_batch_block_tables
1837                .as_mut()
1838                .expect("paged_batch_block_tables missing");
1839            B::write_u32(ctx, bt_buf, &stacked_bt);
1840            let cl_buf = self
1841                .scratch
1842                .paged_batch_context_lens
1843                .as_mut()
1844                .expect("paged_batch_context_lens missing");
1845            B::write_u32(ctx, cl_buf, &stacked_cl);
1846
1847            // Step 3: one batched paged_decode_attention(num_seqs=m).
1848            let bt_ptr =
1849                self.scratch.paged_batch_block_tables.as_ref().unwrap() as *const B::Buffer;
1850            let cl_ptr =
1851                self.scratch.paged_batch_context_lens.as_ref().unwrap() as *const B::Buffer;
1852            let q_ptr = self.scratch.paged_batch_q.as_ref().unwrap() as *const B::Buffer;
1853            let o_ptr = self.scratch.paged_batch_o.as_mut().unwrap() as *mut B::Buffer;
1854            // SAFETY: scratch buffers are not aliased; we hold &mut self
1855            // through this entire block.
1856            let bt_safe = unsafe { &*bt_ptr };
1857            let cl_safe = unsafe { &*cl_ptr };
1858            let q_safe = unsafe { &*q_ptr };
1859            let o_safe = unsafe { &mut *o_ptr };
1860            B::paged_decode_attention(
1861                ctx,
1862                q_safe,
1863                pool_k,
1864                pool_v,
1865                o_safe,
1866                bt_safe,
1867                cl_safe,
1868                m,
1869                nh,
1870                nkv,
1871                hd,
1872                block_size,
1873                max_blocks_per_seq,
1874                1, // q_len
1875            )
1876            .expect("paged batched decode");
1877
1878            // Step 4: per-item copy paged_batch_o[i] → attn_flat[i*q_dim].
1879            for i in 0..m {
1880                B::copy_slice(
1881                    ctx,
1882                    self.scratch.paged_batch_o.as_ref().unwrap(),
1883                    i * q_dim,
1884                    &mut self.scratch.attn_flat,
1885                    i * q_dim,
1886                    q_dim,
1887                );
1888            }
1889
1890            stage_end(attn_t0, ctx, &BD_ATTN_PERITEM_US);
1891        } else {
1892            // 3. split_qkv [M, QKV] → q_buf [M, Q], k_buf [M, KV], v_buf [M, KV]
1893            B::split_qkv(
1894                ctx,
1895                &self.scratch.qkv_out,
1896                &mut self.scratch.q_buf,
1897                &mut self.scratch.k_buf,
1898                &mut self.scratch.v_buf,
1899                m,
1900                q_dim,
1901                kv_dim,
1902            );
1903
1904            // 4-6. Per-item loop: rope + kv_append + attention.
1905            //      Each item has its own cache_id + pos + kv_len.
1906            let q_single = self
1907                .scratch
1908                .q_single
1909                .as_ref()
1910                .expect("q_single missing — enable_batched_decode_scratch not called")
1911                as *const B::Buffer;
1912            let k_single =
1913                self.scratch.k_single.as_ref().expect("k_single missing") as *const B::Buffer;
1914            let v_single =
1915                self.scratch.v_single.as_ref().expect("v_single missing") as *const B::Buffer;
1916            let q_hm_single =
1917                self.scratch
1918                    .q_head_major_single
1919                    .as_mut()
1920                    .expect("q_head_major_single missing") as *mut B::Buffer;
1921            let k_hm_single =
1922                self.scratch
1923                    .k_head_major_single
1924                    .as_mut()
1925                    .expect("k_head_major_single missing") as *mut B::Buffer;
1926            let v_hm_single =
1927                self.scratch
1928                    .v_head_major_single
1929                    .as_mut()
1930                    .expect("v_head_major_single missing") as *mut B::Buffer;
1931            let attn_hm_single =
1932                self.scratch
1933                    .attn_head_major_single
1934                    .as_mut()
1935                    .expect("attn_head_major_single missing") as *mut B::Buffer;
1936            // SAFETY: each Option holds a stable B::Buffer; we don't mutate
1937            // self.scratch in a way that would invalidate them inside the loop
1938            // (the kv_caches mutation is on a disjoint field).
1939
1940            // End of dense block (rms_norm + qkv_proj + split_qkv); start
1941            // per-item attention loop instrumentation.
1942            stage_end(dense_t0, ctx, &BD_DENSE_US);
1943            let attn_t0 = stage_t0();
1944
1945            for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
1946                let pos_i = *pos as usize;
1947
1948                // SAFETY: borrows of disjoint scratch fields, see above.
1949                let q_single_ref = unsafe { &*q_single };
1950                let k_single_ref = unsafe { &*k_single };
1951                let v_single_ref = unsafe { &*v_single };
1952                let q_hm_single_mut = unsafe { &mut *q_hm_single };
1953                let k_hm_single_mut = unsafe { &mut *k_hm_single };
1954                let v_hm_single_mut = unsafe { &mut *v_hm_single };
1955                let attn_hm_single_mut = unsafe { &mut *attn_hm_single };
1956
1957                // Extract item i's Q/K/V slice from the batched buffers.
1958                B::copy_slice(
1959                    ctx,
1960                    &self.scratch.q_buf,
1961                    i * q_dim,
1962                    // copy_slice signature wants &mut for dst, but q_single
1963                    // is shared; we need a *mut variant — since enable_*
1964                    // gives us Option, we can do as_mut() here.
1965                    self.scratch.q_single.as_mut().unwrap(),
1966                    0,
1967                    q_dim,
1968                );
1969                B::copy_slice(
1970                    ctx,
1971                    &self.scratch.k_buf,
1972                    i * kv_dim,
1973                    self.scratch.k_single.as_mut().unwrap(),
1974                    0,
1975                    kv_dim,
1976                );
1977                B::copy_slice(
1978                    ctx,
1979                    &self.scratch.v_buf,
1980                    i * kv_dim,
1981                    self.scratch.v_single.as_mut().unwrap(),
1982                    0,
1983                    kv_dim,
1984                );
1985
1986                // qk_norm_rope with tokens=1, per-item pos.
1987                B::qk_norm_rope(
1988                    ctx,
1989                    q_single_ref,
1990                    q_norm_w,
1991                    &self.rope.cos,
1992                    &self.rope.sin,
1993                    q_hm_single_mut,
1994                    1,
1995                    nh,
1996                    hd,
1997                    pos_i,
1998                    eps,
1999                    qk_mode,
2000                );
2001                B::qk_norm_rope(
2002                    ctx,
2003                    k_single_ref,
2004                    k_norm_w,
2005                    &self.rope.cos,
2006                    &self.rope.sin,
2007                    k_hm_single_mut,
2008                    1,
2009                    nkv,
2010                    hd,
2011                    pos_i,
2012                    eps,
2013                    qk_mode,
2014                );
2015                B::qk_norm_rope(
2016                    ctx,
2017                    v_single_ref,
2018                    dummy_w,
2019                    &self.rope.cos,
2020                    &self.rope.sin,
2021                    v_hm_single_mut,
2022                    1,
2023                    nkv,
2024                    hd,
2025                    pos_i,
2026                    eps,
2027                    0,
2028                );
2029
2030                // KV append + attention for item i's cache.
2031                let caches = self
2032                    .kv_caches
2033                    .get_mut(cache_id)
2034                    .expect("ensure_kv must be called before forward_layer_batched");
2035                let cache = &mut caches[li];
2036                B::kv_cache_append_head_major(
2037                    ctx,
2038                    &mut cache.k,
2039                    &mut cache.v,
2040                    cache.len,
2041                    cache.capacity,
2042                    k_hm_single_mut,
2043                    v_hm_single_mut,
2044                    1,
2045                    nkv,
2046                    hd,
2047                );
2048                cache.len += 1;
2049                let kv_len = cache.len;
2050                let kv_stride = cache.capacity;
2051
2052                let attn_cfg = ferrum_kernels::backend::AttnConfig {
2053                    num_heads: nh,
2054                    num_kv_heads: nkv,
2055                    head_dim: hd,
2056                    causal: true,
2057                    scale: 1.0 / (hd as f32).sqrt(),
2058                    kv_seq_stride: kv_stride,
2059                    sliding_window: cfg_base.sliding_window,
2060                };
2061                B::flash_attention(
2062                    ctx,
2063                    q_hm_single_mut,
2064                    &cache.k,
2065                    &cache.v,
2066                    attn_hm_single_mut,
2067                    1,
2068                    1,
2069                    kv_len,
2070                    pos_i,
2071                    &attn_cfg,
2072                );
2073
2074                // Untranspose head-major → token-major: for tokens=1 the
2075                // layouts are byte-identical, so copy_slice straight into
2076                // attn_flat at the per-item offset (saves a transpose).
2077                B::copy_slice(
2078                    ctx,
2079                    attn_hm_single_mut,
2080                    0,
2081                    &mut self.scratch.attn_flat,
2082                    i * q_dim,
2083                    q_dim,
2084                );
2085            }
2086
2087            // End of per-item attention loop.
2088            stage_end(attn_t0, ctx, &BD_ATTN_PERITEM_US);
2089        } // end of `else` for non-paged path
2090
2091        let post_attn_t0 = stage_t0();
2092
2093        // 7. o_proj GEMM at m=M: attn_flat [M, Q] → o_proj_out [M, H]
2094        attn_layer.o_proj.forward(
2095            ctx,
2096            &self.scratch.attn_flat,
2097            &mut self.scratch.o_proj_out,
2098            m,
2099        );
2100
2101        // 8. fused residual_add + post_attention_layernorm
2102        B::fused_add_rms_norm(
2103            ctx,
2104            residual,
2105            &self.scratch.o_proj_out,
2106            &attn_layer.post_ln_w,
2107            eps,
2108            &mut self.scratch.norm_out,
2109            m,
2110            h,
2111        );
2112
2113        // o_proj + post-norm count under DENSE.
2114        stage_end(post_attn_t0, ctx, &BD_DENSE_US);
2115        let moe_t0 = stage_t0();
2116
2117        // 9. Router gemv: norm_out [M, H] → router_logits [M, n_exp]
2118        let moe_layer = &self.moe_layers[li];
2119        moe_layer.router.forward(
2120            ctx,
2121            &self.scratch.norm_out,
2122            &mut self.scratch.router_logits,
2123            m,
2124        );
2125
2126        // 10. MoE expert dispatch — per-item loop using the cheap
2127        //     stacked decode kernels (gemv_quant_moe_id + silu_mul_stacked
2128        //     + weighted_sum_batched). NOT the batched prefill path:
2129        //     `moe_forward_batched_prefill_impl` is tuned for large M
2130        //     (prefill) and the GPU bucketing overhead
2131        //     (`compute_ids_tpe_gpu` + indirect-dispatch arg-buffer
2132        //     setup) costs more than M sequential gemv calls at small M.
2133        //
2134        // Strategy: route ALL M tokens once via batched
2135        // `route_topk_softmax`, then loop M iterations of the stacked
2136        // decode kernels. Each iteration:
2137        //   - extract item i's selected ids + weights from the M-batch
2138        //     buffers via copy_slice
2139        //   - copy norm_out[i*h..(i+1)*h] → x_single
2140        //   - 3× gemv_quant_moe_id (gate/up/down) reading from x_single
2141        //   - silu_mul_stacked
2142        //   - weighted_sum_batched(batch=1) → acc_buf  (fresh write,
2143        //     no residual fusion)
2144        //   - copy_slice acc_buf → moe_out[i*h..(i+1)*h]
2145        // After the loop, single add_inplace residual += moe_out [M, H].
2146        let stacked_path_available = moe_layer.experts.gate_stacked.is_some()
2147            && moe_layer.experts.up_stacked.is_some()
2148            && moe_layer.experts.down_stacked.is_some();
2149        // MoE FFN dispatch tiers (m = batch size of this layer call):
2150        //
2151        //   m = 1          : `moe_forward_stacked_decode_impl`
2152        //                    (decode m=1 fast path, fused gate+up+silu)
2153        //   m ≥ 8 (default): `moe_forward_batched_prefill_impl`
2154        //                    (GEMM with simdgroup_matmul + GPU bucketing)
2155        //   else (m=2..7)  : per-item stacked decode loop
2156        //
2157        // EXPERIMENTAL — opt-in `FERRUM_MOE_BATCHED_DECODE=1` engages the
2158        // new `moe_forward_batched_decode_impl` for 2 ≤ m < 32. The
2159        // kernel itself is bitwise correct and ports llama.cpp's
2160        // `kernel_mul_mv_id` strategy to ferrum (one indirect-dispatch
2161        // GEMV per linear covering all m*top_k pairs). Empirically OFF
2162        // by default because the existing `forward_layer_batched_decode`
2163        // attention plumbing (per-item copy_slice × m × 6 dispatches)
2164        // scales linearly with m and overshadows the FFN savings —
2165        // regression measured at -19% (c=4) and -36% (c=16) on
2166        // Qwen3-30B-A3B Q4_K_M / M1 Max. Closing that gap requires a
2167        // batched attention path with offset-aware QKV slicing, which
2168        // is the next PR's job. Until then the kernel sits as
2169        // infrastructure.
2170        // Two independent thresholds:
2171        //   * `FERRUM_MOE_BATCH_THRESHOLD` (default 8) — m above which
2172        //     the LEGACY non-experimental path uses the prefill GEMM.
2173        //     Shared with `decode_batch`'s engine-level gate, so users
2174        //     who set it to a small value to engage batched decode
2175        //     don't accidentally also push the inner FFN to GEMM.
2176        //   * `FERRUM_MOE_PREFILL_THRESHOLD` (default 32) — m above
2177        //     which the EXPERIMENTAL batched-decode path defers to the
2178        //     prefill GEMM path. Mirrors llama.cpp's `ne21_mm_id_min=32`
2179        //     GEMV→GEMM boundary.
2180        let legacy_prefill_threshold: usize = std::env::var("FERRUM_MOE_BATCH_THRESHOLD")
2181            .ok()
2182            .and_then(|s| s.parse().ok())
2183            .unwrap_or(8);
2184        let new_prefill_threshold: usize = std::env::var("FERRUM_MOE_PREFILL_THRESHOLD")
2185            .ok()
2186            .and_then(|s| s.parse().ok())
2187            .unwrap_or(32);
2188        // 0.7.2: default to ON when paged-KV is also on (which is now
2189        // the default for Metal). The historical regression for this
2190        // flag (-19% c=4 / -36% c=16) was measured in the pre-paged-KV
2191        // world where `forward_layer_batched_decode`'s per-item
2192        // copy_slice × m × 6 attention dispatches cost more than the
2193        // batched MoE FFN saved. Once paged-KV is on, attention runs as
2194        // one `paged_decode_attention(num_seqs=m)` dispatch, the
2195        // plumbing cost drops, and the batched MoE GEMV's win net out
2196        // to ~+50% at c=16. `FERRUM_MOE_BATCHED_DECODE=0` forces off.
2197        let new_batched_default = stacked_path_available && B::supports_batched_moe_gemv();
2198        let new_batched_enabled = new_batched_default
2199            && std::env::var("FERRUM_MOE_BATCHED_DECODE")
2200                .map(|v| v != "0")
2201                .unwrap_or(true);
2202
2203        // When the new path is opted in, it owns the m=2..new_prefill_threshold
2204        // range; the legacy threshold is overridden upward.
2205        let use_prefill_batched = if new_batched_enabled {
2206            stacked_path_available && m >= new_prefill_threshold
2207        } else {
2208            stacked_path_available && m >= legacy_prefill_threshold
2209        };
2210        let use_batched_decode = new_batched_enabled && !use_prefill_batched && m >= 2;
2211
2212        if use_prefill_batched {
2213            moe_forward_batched_prefill_impl::<B>(
2214                ctx,
2215                moe_layer,
2216                &mut self.scratch,
2217                h,
2218                self.cfg.expert_intermediate_size,
2219                self.cfg.num_experts_per_tok,
2220                self.cfg.num_experts,
2221                self.cfg.norm_topk_prob,
2222                m,
2223            )?;
2224        } else if use_batched_decode {
2225            moe_forward_batched_decode_impl::<B>(
2226                ctx,
2227                moe_layer,
2228                &mut self.scratch,
2229                h,
2230                self.cfg.expert_intermediate_size,
2231                self.cfg.num_experts_per_tok,
2232                self.cfg.num_experts,
2233                self.cfg.norm_topk_prob,
2234                m,
2235            )?;
2236        } else if stacked_path_available {
2237            let inter = self.cfg.expert_intermediate_size;
2238            let top_k = self.cfg.num_experts_per_tok;
2239            let n_exp = self.cfg.num_experts;
2240            let norm_topk_prob = self.cfg.norm_topk_prob;
2241            let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2242            let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2243            let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2244
2245            // Single batched router pass: writes selected_ids_buf [M, top_k]
2246            // and weights_2d [M, top_k]. Replaces M individual route calls.
2247            B::route_topk_softmax(
2248                ctx,
2249                &self.scratch.router_logits,
2250                &mut self.scratch.selected_ids_buf,
2251                &mut self.scratch.weights_2d,
2252                m,
2253                n_exp,
2254                top_k,
2255                norm_topk_prob,
2256            )?;
2257
2258            // Per-item loop using offset-aware kernel APIs — eliminates
2259            // the 4 copy_slice round-trips per iteration that the
2260            // earlier implementation needed (ids, weights, x_single,
2261            // moe_out). At c=16 / 48 layers that's ~3,072 dispatches
2262            // saved per token. Uses `gemv_quant_moe_id_offset` to read
2263            // `selected_ids_buf` at the i-th `top_k` block and
2264            // `norm_out` at the i-th hidden row directly. Falls back
2265            // to copy_slice path if backend doesn't support offsets.
2266            for i in 0..m {
2267                let ids_offset = i * top_k;
2268                let activation_offset = i * h;
2269                let weights_offset = i * top_k;
2270                let moe_out_offset = i * h;
2271
2272                // Stacked gate / up gemvs — broadcast item i's row of
2273                // norm_out across top_k slots, read item i's ids.
2274                let gate_res = B::gemv_quant_moe_id_offset(
2275                    ctx,
2276                    &self.scratch.norm_out,
2277                    activation_offset,
2278                    gate_stacked,
2279                    &self.scratch.selected_ids_buf,
2280                    ids_offset,
2281                    &mut self.scratch.gate_out_stacked,
2282                    top_k,
2283                    0,
2284                );
2285                if gate_res.is_err() {
2286                    // Backend doesn't support offset variants — fall back
2287                    // to the legacy copy_slice path. Same as before.
2288                    B::copy_slice(
2289                        ctx,
2290                        &self.scratch.selected_ids_buf,
2291                        ids_offset,
2292                        &mut self.scratch.ids_buf,
2293                        0,
2294                        top_k,
2295                    );
2296                    B::copy_slice(
2297                        ctx,
2298                        &self.scratch.weights_2d,
2299                        weights_offset,
2300                        &mut self.scratch.weights_buf,
2301                        0,
2302                        top_k,
2303                    );
2304                    B::copy_slice(
2305                        ctx,
2306                        &self.scratch.norm_out,
2307                        activation_offset,
2308                        &mut self.scratch.x_single,
2309                        0,
2310                        h,
2311                    );
2312                    B::gemv_quant_moe_id(
2313                        ctx,
2314                        &self.scratch.x_single,
2315                        gate_stacked,
2316                        &self.scratch.ids_buf,
2317                        &mut self.scratch.gate_out_stacked,
2318                        top_k,
2319                        0,
2320                    )?;
2321                    B::gemv_quant_moe_id(
2322                        ctx,
2323                        &self.scratch.x_single,
2324                        up_stacked,
2325                        &self.scratch.ids_buf,
2326                        &mut self.scratch.up_out_stacked,
2327                        top_k,
2328                        0,
2329                    )?;
2330                    B::silu_mul_stacked(
2331                        ctx,
2332                        &self.scratch.gate_out_stacked,
2333                        &self.scratch.up_out_stacked,
2334                        &mut self.scratch.silu_stacked,
2335                        top_k,
2336                        inter,
2337                    )?;
2338                    B::gemv_quant_moe_id(
2339                        ctx,
2340                        &self.scratch.silu_stacked,
2341                        down_stacked,
2342                        &self.scratch.ids_buf,
2343                        &mut self.scratch.down_out_stacked,
2344                        top_k,
2345                        inter,
2346                    )?;
2347                    B::weighted_sum_batched(
2348                        ctx,
2349                        &self.scratch.down_out_stacked,
2350                        &self.scratch.weights_buf,
2351                        &mut self.scratch.acc_buf,
2352                        1,
2353                        top_k,
2354                        h,
2355                    )?;
2356                    B::copy_slice(
2357                        ctx,
2358                        &self.scratch.acc_buf,
2359                        0,
2360                        &mut self.scratch.moe_out,
2361                        moe_out_offset,
2362                        h,
2363                    );
2364                    continue;
2365                }
2366                // Fast path: offset-aware all the way through.
2367                B::gemv_quant_moe_id_offset(
2368                    ctx,
2369                    &self.scratch.norm_out,
2370                    activation_offset,
2371                    up_stacked,
2372                    &self.scratch.selected_ids_buf,
2373                    ids_offset,
2374                    &mut self.scratch.up_out_stacked,
2375                    top_k,
2376                    0,
2377                )?;
2378                B::silu_mul_stacked(
2379                    ctx,
2380                    &self.scratch.gate_out_stacked,
2381                    &self.scratch.up_out_stacked,
2382                    &mut self.scratch.silu_stacked,
2383                    top_k,
2384                    inter,
2385                )?;
2386                B::gemv_quant_moe_id_offset(
2387                    ctx,
2388                    &self.scratch.silu_stacked,
2389                    0, // silu_stacked itself stays at offset 0 each iter
2390                    down_stacked,
2391                    &self.scratch.selected_ids_buf,
2392                    ids_offset,
2393                    &mut self.scratch.down_out_stacked,
2394                    top_k,
2395                    inter,
2396                )?;
2397                // Write directly into moe_out at the per-item offset —
2398                // skips the copy_slice from acc_buf.
2399                B::weighted_sum_batched_offset(
2400                    ctx,
2401                    &self.scratch.down_out_stacked,
2402                    &self.scratch.weights_2d,
2403                    weights_offset,
2404                    &mut self.scratch.moe_out,
2405                    moe_out_offset,
2406                    1,
2407                    top_k,
2408                    h,
2409                )?;
2410            }
2411        } else {
2412            // Backend without stacked variants — fall back to the legacy
2413            // per-(token, expert) host-routed path.
2414            moe_forward::<B>(
2415                ctx,
2416                &self.scratch.norm_out,
2417                &self.scratch.router_logits,
2418                &mut self.scratch.moe_out,
2419                m,
2420                h,
2421                self.cfg.expert_intermediate_size,
2422                self.cfg.num_experts,
2423                self.cfg.num_experts_per_tok,
2424                self.cfg.norm_topk_prob,
2425                &moe_layer.experts,
2426                &mut self.scratch.x_single,
2427                &mut self.scratch.acc_buf,
2428                &mut self.scratch.gate_up_buf,
2429                &mut self.scratch.silu_buf,
2430                &mut self.scratch.down_buf,
2431                &self.scratch.zero_hidden,
2432            )?;
2433        }
2434
2435        // 11. residual += moe_out [M, H]
2436        B::add_inplace(ctx, residual, &self.scratch.moe_out, m * h);
2437
2438        // Close MoE-block instrumentation (router + FFN + residual add).
2439        stage_end(moe_t0, ctx, &BD_MOE_US);
2440
2441        Ok(())
2442    }
2443}
2444
2445impl<B: Backend> DecoderOnlyLLM for Qwen3MoeModel<B> {
2446    fn config(&self) -> &LlmRuntimeConfig {
2447        &self.runtime_cfg
2448    }
2449
2450    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2451        // Eager scratch + KV cache grow + a 1-token forward warmup so
2452        // the first real prefill / decode doesn't pay the cold-start
2453        // ~25-MTLBuffer scratch alloc + ~96-MTLBuffer KV alloc + Metal
2454        // pipeline-state first-bind costs (~265 ms total on Qwen3-MoE
2455        // 30B-A3B / M1 Max). Mirrors what llama-bench's --warmup does
2456        // (which runs a same-shape forward before the timer).
2457        self.ensure_scratch(max_tokens);
2458        self.ensure_kv(cache_id);
2459
2460        // Warmup forward through all 48 layers under a scratch cache_id
2461        // so the real `cache_id` starts at pos_offset=0. Token 0 is
2462        // valid for any tokenizer (BOS or pad).
2463        const WARMUP_CACHE: &str = "__ferrum_warmup__";
2464        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2465        // Drop the warmup KV cache slot — real cache_id is unaffected.
2466        if let Some(caches) = self.kv_caches.remove(WARMUP_CACHE) {
2467            self.kv_free_pool.push(caches);
2468        }
2469    }
2470
2471    fn kv_capacity(&self) -> usize {
2472        // Mirror the bound `ensure_kv` will use when allocating the cache.
2473        let model_max = self.cfg.base.max_seq_len;
2474        const DEFAULT_KV_CAPACITY: usize = 512;
2475        std::env::var("FERRUM_KV_CAPACITY")
2476            .ok()
2477            .and_then(|s| s.parse::<usize>().ok())
2478            .map(|cap| cap.min(model_max))
2479            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
2480    }
2481
2482    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2483        self.prefill_internal(cache_id, tokens)
2484    }
2485
2486    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2487        self.decode_internal(cache_id, token, pos)
2488    }
2489
2490    // decode_batch is gated to use the batched path only when it's a
2491    // measurable win. The crossover depends on M:
2492    //
2493    //   - At low M (≤ ~8) the per-item `decode_internal` loop wins
2494    //     because: (a) it stays at scratch offset 0 (no copy_slice
2495    //     overhead), (b) it preserves the cross-layer rms_norm fusion
2496    //     fast path (`weighted_sum_residual_norm_stacked`).
2497    //   - At high M (≥ ~12) the batched path wins because the dense
2498    //     GEMM batching (qkv_proj, o_proj, router, lm_head at m=M) and
2499    //     the prefill-batched MoE dispatch (one `gemm_quant_moe_id` for
2500    //     all tokens) amortise the ~48-dispatch lost-fusion penalty.
2501    //
2502    // Default opted out of FERRUM_MOE_BATCHED. When opted in, the
2503    // batched path engages only at M ≥ FERRUM_MOE_BATCH_THRESHOLD
2504    // (default 12). Below that we still go per-item.
2505    //
2506    // Empirical note 2026-05-02: a follow-up PR added a batched MoE
2507    // GEMV kernel (`gemv_quant_moe_id_batched`) that holds MoE
2508    // dispatch count flat as concurrency scales. Wiring it through
2509    // `decode_batch_internal` regressed throughput by 19% (c=4) /
2510    // 36% (c=16) — `forward_layer_batched_decode`'s per-item
2511    // attention plumbing (copy_slice × m × 6 dispatches) costs more
2512    // than the MoE save. The batched MoE kernel is shipped as opt-in
2513    // infrastructure (`FERRUM_MOE_BATCHED_DECODE=1`); flipping it on
2514    // by default has to wait until the attention plumbing is fixed.
2515    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2516        let m = batch.len();
2517        // Default ON in 0.7.2+. The threshold (default 8) keeps small-m
2518        // requests on the per-token loop where it still wins on this
2519        // hardware — see docs/bench/macos-2026-05-02 for the crossover
2520        // measurements (c=4 batched 39 < per_token 42; c=8 batched 59 >
2521        // per_token 47). `FERRUM_MOE_BATCHED=0` forces the legacy loop.
2522        let opted_in = std::env::var("FERRUM_MOE_BATCHED")
2523            .map(|v| v != "0")
2524            .unwrap_or(true);
2525        let threshold = std::env::var("FERRUM_MOE_BATCH_THRESHOLD")
2526            .ok()
2527            .and_then(|s| s.parse::<usize>().ok())
2528            .unwrap_or(8);
2529        if opted_in && m >= threshold {
2530            self.decode_batch_internal(batch)
2531        } else {
2532            batch
2533                .iter()
2534                .map(|(cid, tok, p)| self.decode(cid, *tok, *p))
2535                .collect()
2536        }
2537    }
2538
2539    fn release(&mut self, cache_id: &str) {
2540        let mut ctx = B::new_context();
2541        B::sync(&mut ctx);
2542        B::reset_graph(&mut ctx);
2543        B::sync(&mut ctx);
2544        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2545            // Paged mode: return the cache_id's blocks to the shared
2546            // allocator so other sequences can reuse them. Without this,
2547            // every request consumes max_blocks_per_seq blocks
2548            // permanently — pool exhausts after FERRUM_PAGED_MAX_SEQS
2549            // requests and subsequent ensure_kv panics with
2550            // "scratch residual missing" (the cascade panic from a
2551            // failed ensure_kv path leaving scratch poisoned).
2552            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2553                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2554                if let Some(c0) = caches.first() {
2555                    if !c0.paged_block_indices.is_empty() {
2556                        alloc.free(&c0.paged_block_indices);
2557                    }
2558                }
2559                for c in caches.iter_mut() {
2560                    c.paged_block_indices.clear();
2561                }
2562            }
2563            self.kv_free_pool.push(caches);
2564        }
2565    }
2566
2567    fn reset(&mut self) {
2568        let mut ctx = B::new_context();
2569        B::sync(&mut ctx);
2570        B::reset_graph(&mut ctx);
2571        B::sync(&mut ctx);
2572        self.kv_caches.clear();
2573        self.kv_free_pool.clear();
2574    }
2575}
2576
2577/// Batched MoE FFN — decode (m=1) and per-token-prefill (m>1 looped).
2578///
2579/// Three batched `gemv_quant_moe_id` dispatches per token: gate (broadcast
2580/// activation), up (broadcast activation), down (per-slot activation —
2581/// each expert sees its own silu·up). The per-(token, expert) outer loop
2582/// shrinks from `top_k * 4` dispatches per layer to **3 batched + 1
2583/// silu_mul_split + 1 weighted_sum_dispatch_loop**.
2584///
2585/// For prefill (m > 1) we loop over tokens externally — each token's
2586/// router output drives a single batched call. Still much faster than
2587/// the per-(token, expert) per-Linear path because the gemvs are batched.
2588///
2589/// Free function (not a method) so the caller can split the borrow on
2590/// `self` between `moe_layers[li]` (immutable) and `scratch` (mutable).
2591#[allow(clippy::too_many_arguments)]
2592fn moe_forward_stacked_decode_impl<B: Backend>(
2593    ctx: &mut B::Context,
2594    moe_layer: &Qwen3MoeLayerState<B>,
2595    scratch: &mut Qwen3MoeScratch<B>,
2596    h: usize,
2597    inter: usize,
2598    top_k: usize,
2599    n_exp: usize,
2600    norm_topk_prob: bool,
2601    tokens: usize,
2602    residual: &mut B::Buffer,
2603    // If `Some`, fold the NEXT layer's leading rms_norm into the
2604    // weighted-sum-residual tail using `weighted_sum_residual_norm_stacked`.
2605    next_norm_w: Option<&B::Buffer>,
2606    eps: f32,
2607) -> Result<()> {
2608    // GPU-side routing: one Metal launch reads router_logits and writes
2609    // selected ids + combine weights directly into device-side scratch
2610    // buffers. Eliminates the per-layer `B::sync + B::to_vec(router_logits)
2611    // + host route()` round trip — the dominant remaining cost in the
2612    // decode hot path (~10% of total decode latency).
2613    let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
2614    let stage_t0 = || -> Option<std::time::Instant> {
2615        if prof {
2616            Some(std::time::Instant::now())
2617        } else {
2618            None
2619        }
2620    };
2621    let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
2622        if let Some(t) = t0 {
2623            B::sync(ctx);
2624            c.fetch_add(
2625                t.elapsed().as_micros() as u64,
2626                std::sync::atomic::Ordering::Relaxed,
2627            );
2628        }
2629    };
2630
2631    let t0 = stage_t0();
2632    B::route_topk_softmax(
2633        ctx,
2634        &scratch.router_logits,
2635        &mut scratch.ids_buf,
2636        &mut scratch.weights_buf,
2637        tokens,
2638        n_exp,
2639        top_k,
2640        norm_topk_prob,
2641    )?;
2642    stage_end(t0, ctx, &DEC_ROUTE_US);
2643
2644    let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2645    let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2646    let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2647
2648    // moe_forward_stacked_decode_impl is only called when `tokens == 1`
2649    // (the branch in `forward_layer` routes prefill m>1 through
2650    // `moe_forward_batched_prefill_impl` instead). The for-b loop and
2651    // the copy norm_out[b*h] → x_single were vestigial scaffolding;
2652    // for tokens=1 norm_out[0..h] IS the activation row, and we can
2653    // pass it straight to the gemv kernel via src1_stride=0 broadcast.
2654    debug_assert_eq!(
2655        tokens, 1,
2656        "moe_forward_stacked_decode_impl expects tokens=1 (prefill goes through moe_forward_batched_prefill_impl)"
2657    );
2658    let _ = tokens; // silence unused-warning when assertion is compiled out
2659
2660    {
2661        // ids_buf and weights_buf populated by the GPU router above —
2662        // no host writes needed here in the decode path.
2663
2664        // Fused-vs-unfused gate+up+silu selection.
2665        //
2666        // Default: when the backend advertises support (Metal Q4KExperts),
2667        // run the single fused dispatch — saves 2 dispatches and the
2668        // entire round-trip through gate_out_stacked / up_out_stacked
2669        // scratch (≈4× [top_k, ffn] of intermediate bandwidth).
2670        //
2671        // Opt-out: `FERRUM_MOE_FUSED_GATE_UP_SILU=0` forces the legacy
2672        // 3-dispatch path. Used for A/B benchmarking and as a kill switch
2673        // if the fused kernel ever produces divergent outputs.
2674        // Cache the env-flag read once per process — the decode hot
2675        // path calls this fn ~48 layers × ~steps_per_run times.
2676        static FUSED_DISABLED: OnceLock<bool> = OnceLock::new();
2677        let fused_disabled = *FUSED_DISABLED
2678            .get_or_init(|| std::env::var("FERRUM_MOE_FUSED_GATE_UP_SILU").as_deref() == Ok("0"));
2679        let use_fused = B::supports_fused_moe_gate_up_silu() && !fused_disabled;
2680
2681        if use_fused {
2682            // 1+2+3 fused: silu_stacked = SiLU(gate · norm_out) * (up · norm_out)
2683            let t0 = stage_t0();
2684            B::gemv_quant_moe_id_gate_up_silu(
2685                ctx,
2686                &scratch.norm_out,
2687                gate_stacked,
2688                up_stacked,
2689                &scratch.ids_buf,
2690                &mut scratch.silu_stacked,
2691                top_k,
2692            )?;
2693            stage_end(t0, ctx, &DEC_SILU_US);
2694        } else {
2695            // 1. Batched gate gemv — broadcast input across top_k slots.
2696            //    src1 = norm_out (which has hidden floats at offset 0),
2697            //    src1_stride=0 → all slots read the same row.
2698            let t0 = stage_t0();
2699            B::gemv_quant_moe_id(
2700                ctx,
2701                &scratch.norm_out,
2702                gate_stacked,
2703                &scratch.ids_buf,
2704                &mut scratch.gate_out_stacked,
2705                top_k,
2706                0, // broadcast
2707            )?;
2708            stage_end(t0, ctx, &DEC_GATE_US);
2709
2710            // 2. Batched up gemv — also broadcast.
2711            let t0 = stage_t0();
2712            B::gemv_quant_moe_id(
2713                ctx,
2714                &scratch.norm_out,
2715                up_stacked,
2716                &scratch.ids_buf,
2717                &mut scratch.up_out_stacked,
2718                top_k,
2719                0,
2720            )?;
2721            stage_end(t0, ctx, &DEC_UP_US);
2722
2723            // 3. Stacked SiLU·gate → silu_stacked. Single dispatch covers
2724            //    all top_k slots — replaces the per-slot loop's
2725            //    (3 copy_slice + 1 silu_mul) × 8 = 32 dispatches.
2726            let t0 = stage_t0();
2727            B::silu_mul_stacked(
2728                ctx,
2729                &scratch.gate_out_stacked,
2730                &scratch.up_out_stacked,
2731                &mut scratch.silu_stacked,
2732                top_k,
2733                inter,
2734            )?;
2735            stage_end(t0, ctx, &DEC_SILU_US);
2736        }
2737
2738        // 4. Batched down gemv — per-slot input via src1_stride = inter.
2739        //    silu_stacked[k * inter ..] is the activation row for slot k.
2740        let t0 = stage_t0();
2741        B::gemv_quant_moe_id(
2742            ctx,
2743            &scratch.silu_stacked,
2744            down_stacked,
2745            &scratch.ids_buf,
2746            &mut scratch.down_out_stacked,
2747            top_k,
2748            inter,
2749        )?;
2750        stage_end(t0, ctx, &DEC_DOWN_US);
2751
2752        // 5. Fused weighted-sum + residual-add (+ optional next-layer
2753        //    rms_norm). Two paths:
2754        //
2755        //    * `next_norm_w = Some(_)` (cross-layer fusion): one kernel
2756        //      computes residual[i] += Σ_k w[k] · down[k, i] AND
2757        //      norm_out[i] = residual[i] · scale · next_norm_w[i].
2758        //      The next layer's leading rms_norm is skipped. Saves an
2759        //      additional dispatch per layer transition.
2760        //    * `next_norm_w = None` (last layer): just residual-add.
2761        let t0 = stage_t0();
2762        if let Some(nnw) = next_norm_w {
2763            B::weighted_sum_residual_norm_stacked(
2764                ctx,
2765                &scratch.down_out_stacked,
2766                &scratch.weights_buf,
2767                residual,
2768                nnw,
2769                &mut scratch.norm_out,
2770                top_k,
2771                h,
2772                eps,
2773            )?;
2774        } else {
2775            B::weighted_sum_residual_stacked(
2776                ctx,
2777                &scratch.down_out_stacked,
2778                &scratch.weights_buf,
2779                residual,
2780                top_k,
2781                h,
2782            )?;
2783        }
2784        stage_end(t0, ctx, &DEC_WSUM_US);
2785    }
2786
2787    Ok(())
2788}
2789
2790/// Batched MoE FFN for prefill (m > 1).
2791///
2792/// One pass through the expert dispatch — replaces the per-token loop
2793/// with three batched 2-D mul_mm_id dispatches (gate, up, down) where
2794/// each expert's slab of (token, slot) pairs runs as one gemm tile.
2795/// Per-layer dispatch count: ~6 (router + 3 mul_mm_id + silu + wsum)
2796/// independent of `tokens`. Compare to the decode-style stacked path
2797/// that emits ~10 per token.
2798///
2799/// Free function so the caller can split the borrow on `self` between
2800/// `moe_layers[li]` (immutable) and `scratch` (mutable).
2801#[allow(clippy::too_many_arguments)]
2802fn moe_forward_batched_prefill_impl<B: Backend>(
2803    ctx: &mut B::Context,
2804    moe_layer: &Qwen3MoeLayerState<B>,
2805    scratch: &mut Qwen3MoeScratch<B>,
2806    h: usize,
2807    inter: usize,
2808    top_k: usize,
2809    n_exp: usize,
2810    norm_topk_prob: bool,
2811    tokens: usize,
2812) -> Result<()> {
2813    let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
2814    let stage_t0 = || -> Option<std::time::Instant> {
2815        if prof {
2816            Some(std::time::Instant::now())
2817        } else {
2818            None
2819        }
2820    };
2821    let stage_end =
2822        |t0: Option<std::time::Instant>, ctx: &mut B::Context, us: &AtomicU64, n: &AtomicU64| {
2823            if let Some(t) = t0 {
2824                B::sync(ctx);
2825                us.fetch_add(
2826                    t.elapsed().as_micros() as u64,
2827                    std::sync::atomic::Ordering::Relaxed,
2828                );
2829                n.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2830            }
2831        };
2832
2833    // GPU-side routing: keep the whole pipeline device-resident. Two
2834    // dispatches replace the per-layer `B::sync + to_vec(router_logits)
2835    // + host route() + host compute_ids_tpe + write_back` round trip.
2836    //
2837    //   1. `route_topk_softmax` writes selected expert IDs (flat
2838    //      `[batch, top_k]`) into `selected_ids_buf` and the post-renorm
2839    //      combine weights directly into `weights_2d`.
2840    //   2. `compute_ids_tpe_gpu` buckets those pairs into `tpe_buf` and
2841    //      `ids_2d` using device-side atomic_fetch_add slot claims. The
2842    //      `ids_2d` row stride is the worst-case `tokens * top_k`; the
2843    //      consumer GEMM stops at `tpe[e]` so the over-strided columns
2844    //      cost only launch overhead, not real compute.
2845    //
2846    // `FERRUM_MOE_HOST_TOPK=1`        → legacy CPU softmax+topk+bucket
2847    // `FERRUM_MOE_DIRECT_DISPATCH=1`  → GPU topk but worst-case GEMM grid
2848    // (default)                       → GPU topk + indirect-dispatched GEMM
2849    //                                    (grid sized from max(tpe[e]))
2850    let use_gpu_topk = std::env::var("FERRUM_MOE_HOST_TOPK").as_deref() != Ok("1");
2851    let use_indirect_dispatch =
2852        use_gpu_topk && std::env::var("FERRUM_MOE_DIRECT_DISPATCH").as_deref() != Ok("1");
2853    let max_per_expert = if use_gpu_topk {
2854        let t0 = stage_t0();
2855        B::route_topk_softmax(
2856            ctx,
2857            &scratch.router_logits,
2858            &mut scratch.selected_ids_buf,
2859            &mut scratch.weights_2d,
2860            tokens,
2861            n_exp,
2862            top_k,
2863            norm_topk_prob,
2864        )?;
2865        B::compute_ids_tpe_gpu(
2866            ctx,
2867            &scratch.selected_ids_buf,
2868            &mut scratch.tpe_buf,
2869            &mut scratch.ids_2d,
2870            &mut scratch.gate_up_args_buf,
2871            &mut scratch.down_args_buf,
2872            tokens,
2873            n_exp,
2874            top_k,
2875            inter,
2876            h,
2877        )?;
2878        stage_end(
2879            t0,
2880            ctx,
2881            &MOE_PREFILL_HOST_TOPK_US,
2882            &MOE_PREFILL_HOST_TOPK_CALLS,
2883        );
2884        // Worst-case ids row stride; matches `dispatch_compute_ids_tpe`.
2885        tokens * top_k
2886    } else {
2887        use ferrum_kernels::moe_host::compute_ids_tpe;
2888        let t0 = stage_t0();
2889        B::sync(ctx);
2890        let logits_host = B::to_vec(&scratch.router_logits, tokens * n_exp);
2891        let route = crate::moe::router::route(&logits_host, tokens, n_exp, top_k, norm_topk_prob);
2892        let (tpe_host, ids_host, max_per_expert) =
2893            compute_ids_tpe(&route.expert_ids, n_exp, tokens, top_k);
2894        B::write_i32_into(&mut scratch.tpe_buf, &tpe_host);
2895        B::write_i32_into(&mut scratch.ids_2d, &ids_host);
2896        B::write_f32_into(&mut scratch.weights_2d, &route.expert_weights);
2897        stage_end(
2898            t0,
2899            ctx,
2900            &MOE_PREFILL_HOST_TOPK_US,
2901            &MOE_PREFILL_HOST_TOPK_CALLS,
2902        );
2903        max_per_expert
2904    };
2905
2906    let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
2907    let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
2908    let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
2909
2910    // 1. Batched gate gemm — one launch covers all (token, expert) pairs.
2911    //    src1 layout: [batch, ne11=1, K] (broadcast: each pair reads its
2912    //    token's row, slot index ignored).
2913    //    dst layout:  [batch, top_k, expert_inter] — natural.
2914    let t0 = stage_t0();
2915    if use_indirect_dispatch {
2916        B::gemm_quant_moe_id_indirect(
2917            ctx,
2918            &scratch.norm_out,
2919            gate_stacked,
2920            &scratch.ids_2d,
2921            &scratch.tpe_buf,
2922            &mut scratch.gate_out_stacked,
2923            &scratch.gate_up_args_buf,
2924            1, // ne11 = 1: broadcast
2925            top_k,
2926            max_per_expert,
2927            tokens,
2928        )?;
2929    } else {
2930        B::gemm_quant_moe_id(
2931            ctx,
2932            &scratch.norm_out,
2933            gate_stacked,
2934            &scratch.ids_2d,
2935            &scratch.tpe_buf,
2936            &mut scratch.gate_out_stacked,
2937            1,
2938            top_k,
2939            max_per_expert,
2940            tokens,
2941        )?;
2942    }
2943    stage_end(t0, ctx, &MOE_PREFILL_GATE_US, &MOE_PREFILL_GATE_CALLS);
2944
2945    // 2. Batched up gemm — same shape as gate.
2946    let t0 = stage_t0();
2947    if use_indirect_dispatch {
2948        B::gemm_quant_moe_id_indirect(
2949            ctx,
2950            &scratch.norm_out,
2951            up_stacked,
2952            &scratch.ids_2d,
2953            &scratch.tpe_buf,
2954            &mut scratch.up_out_stacked,
2955            &scratch.gate_up_args_buf,
2956            1,
2957            top_k,
2958            max_per_expert,
2959            tokens,
2960        )?;
2961    } else {
2962        B::gemm_quant_moe_id(
2963            ctx,
2964            &scratch.norm_out,
2965            up_stacked,
2966            &scratch.ids_2d,
2967            &scratch.tpe_buf,
2968            &mut scratch.up_out_stacked,
2969            1,
2970            top_k,
2971            max_per_expert,
2972            tokens,
2973        )?;
2974    }
2975    stage_end(t0, ctx, &MOE_PREFILL_UP_US, &MOE_PREFILL_UP_CALLS);
2976
2977    // 3. SiLU·gate over [tokens * top_k, expert_inter] flat layout.
2978    let total_pairs = tokens * top_k;
2979    let t0 = stage_t0();
2980    B::silu_mul_batched(
2981        ctx,
2982        &scratch.gate_out_stacked,
2983        &scratch.up_out_stacked,
2984        &mut scratch.silu_stacked,
2985        total_pairs,
2986        inter,
2987    )?;
2988    stage_end(t0, ctx, &MOE_PREFILL_SILU_US, &MOE_PREFILL_SILU_CALLS);
2989
2990    // 4. Batched down gemm — src1 is [batch, top_k, expert_inter] from
2991    //    silu_stacked. ne11 = top_k → each pair reads its own row.
2992    let t0 = stage_t0();
2993    if use_indirect_dispatch {
2994        B::gemm_quant_moe_id_indirect(
2995            ctx,
2996            &scratch.silu_stacked,
2997            down_stacked,
2998            &scratch.ids_2d,
2999            &scratch.tpe_buf,
3000            &mut scratch.down_out_stacked,
3001            &scratch.down_args_buf,
3002            top_k, // ne11 = top_k: per-slot
3003            top_k,
3004            max_per_expert,
3005            tokens,
3006        )?;
3007    } else {
3008        B::gemm_quant_moe_id(
3009            ctx,
3010            &scratch.silu_stacked,
3011            down_stacked,
3012            &scratch.ids_2d,
3013            &scratch.tpe_buf,
3014            &mut scratch.down_out_stacked,
3015            top_k,
3016            top_k,
3017            max_per_expert,
3018            tokens,
3019        )?;
3020    }
3021    stage_end(t0, ctx, &MOE_PREFILL_DOWN_US, &MOE_PREFILL_DOWN_CALLS);
3022
3023    // 5. Per-batch weighted sum: moe_out[b, h] = Σ_k w[b,k] · down[b,k,h]
3024    let t0 = stage_t0();
3025    B::weighted_sum_batched(
3026        ctx,
3027        &scratch.down_out_stacked,
3028        &scratch.weights_2d,
3029        &mut scratch.moe_out,
3030        tokens,
3031        top_k,
3032        h,
3033    )?;
3034    stage_end(t0, ctx, &MOE_PREFILL_WSUM_US, &MOE_PREFILL_WSUM_CALLS);
3035
3036    Ok(())
3037}
3038
3039/// Batched MoE FFN for the **small-m decode** range (typically c=2..32).
3040///
3041/// Mirrors llama.cpp's `kernel_mul_mv_id` strategy: hold the dispatch
3042/// count flat as concurrency scales by emitting **one** batched GEMV
3043/// per linear (gate / up / down) that covers all `m * top_k`
3044/// (token, expert) pairs in a single Metal launch. Replaces the
3045/// per-token outer loop in `forward_layer` (which emitted ~5
3046/// dispatches × m tokens per layer) with a fixed-shape pipeline.
3047///
3048/// Compared to [`moe_forward_batched_prefill_impl`]:
3049///   * no `compute_ids_tpe_gpu` bucketing kernel (the new pair-indexed
3050///     GEMV reads `selected_ids_buf` directly)
3051///   * uses GEMV not GEMM (better tile utilisation when tokens-per-expert
3052///     is small — at c=16 with top_k=8 each expert sees ~1-3 token rows,
3053///     well below the simdgroup_matmul tile width)
3054///   * fewer Metal dispatches per layer (5: route + 3 gemv + silu + wsum)
3055///
3056/// Per-layer dispatch budget: 5 (independent of m). At c=16 / 48 layers
3057/// that's 240 dispatches per decode step vs the per-token loop's ~3,840.
3058#[allow(clippy::too_many_arguments)]
3059fn moe_forward_batched_decode_impl<B: Backend>(
3060    ctx: &mut B::Context,
3061    moe_layer: &Qwen3MoeLayerState<B>,
3062    scratch: &mut Qwen3MoeScratch<B>,
3063    h: usize,
3064    inter: usize,
3065    top_k: usize,
3066    n_exp: usize,
3067    norm_topk_prob: bool,
3068    tokens: usize,
3069) -> Result<()> {
3070    let prof = std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok();
3071    let stage_t0 = || -> Option<std::time::Instant> {
3072        if prof {
3073            Some(std::time::Instant::now())
3074        } else {
3075            None
3076        }
3077    };
3078    let stage_end = |t0: Option<std::time::Instant>, ctx: &mut B::Context, c: &AtomicU64| {
3079        if let Some(t) = t0 {
3080            B::sync(ctx);
3081            c.fetch_add(
3082                t.elapsed().as_micros() as u64,
3083                std::sync::atomic::Ordering::Relaxed,
3084            );
3085        }
3086    };
3087
3088    let total_pairs = tokens * top_k;
3089
3090    // 1. Single batched router pass — fills selected_ids_buf [m * top_k]
3091    //    and weights_2d [m * top_k] in one Metal dispatch.
3092    let t0 = stage_t0();
3093    B::route_topk_softmax(
3094        ctx,
3095        &scratch.router_logits,
3096        &mut scratch.selected_ids_buf,
3097        &mut scratch.weights_2d,
3098        tokens,
3099        n_exp,
3100        top_k,
3101        norm_topk_prob,
3102    )?;
3103    stage_end(t0, ctx, &MOE_BATCHED_DECODE_ROUTE_US);
3104
3105    let gate_stacked = moe_layer.experts.gate_stacked.as_ref().unwrap();
3106    let up_stacked = moe_layer.experts.up_stacked.as_ref().unwrap();
3107    let down_stacked = moe_layer.experts.down_stacked.as_ref().unwrap();
3108
3109    // 2+3+4. Fused gate+up+silu — single Metal dispatch covers all
3110    // m*top_k pairs. Falls back to the 3-dispatch sequence on backends
3111    // that don't have the fused-batched kernel.
3112    if B::supports_batched_moe_gate_up_silu() {
3113        let t0 = stage_t0();
3114        B::gemv_quant_moe_id_gate_up_silu_batched(
3115            ctx,
3116            &scratch.norm_out,
3117            gate_stacked,
3118            up_stacked,
3119            &scratch.selected_ids_buf,
3120            &mut scratch.silu_stacked,
3121            tokens,
3122            top_k,
3123            h, // outer stride: K floats per token
3124            0, // inner stride: 0 (slots within a token broadcast)
3125        )?;
3126        // Charge the whole fused step to the SiLU bucket — keeps the
3127        // profile counter additive with the unfused path's silu line.
3128        stage_end(t0, ctx, &MOE_BATCHED_DECODE_SILU_US);
3129    } else {
3130        // 2. Batched gate gemv — one launch covers all m*top_k pairs.
3131        let t0 = stage_t0();
3132        B::gemv_quant_moe_id_batched(
3133            ctx,
3134            &scratch.norm_out,
3135            gate_stacked,
3136            &scratch.selected_ids_buf,
3137            &mut scratch.gate_out_stacked,
3138            tokens,
3139            top_k,
3140            h,
3141            0,
3142        )?;
3143        stage_end(t0, ctx, &MOE_BATCHED_DECODE_GATE_US);
3144
3145        // 3. Batched up gemv.
3146        let t0 = stage_t0();
3147        B::gemv_quant_moe_id_batched(
3148            ctx,
3149            &scratch.norm_out,
3150            up_stacked,
3151            &scratch.selected_ids_buf,
3152            &mut scratch.up_out_stacked,
3153            tokens,
3154            top_k,
3155            h,
3156            0,
3157        )?;
3158        stage_end(t0, ctx, &MOE_BATCHED_DECODE_UP_US);
3159
3160        // 4. SiLU·gate.
3161        let t0 = stage_t0();
3162        B::silu_mul_batched(
3163            ctx,
3164            &scratch.gate_out_stacked,
3165            &scratch.up_out_stacked,
3166            &mut scratch.silu_stacked,
3167            total_pairs,
3168            inter,
3169        )?;
3170        stage_end(t0, ctx, &MOE_BATCHED_DECODE_SILU_US);
3171    }
3172
3173    // 5. Batched down gemv — src1 = silu_stacked [m, top_k, ffn]: each
3174    //    pair has its own row, outer = top_k * ffn, inner = ffn.
3175    let t0 = stage_t0();
3176    B::gemv_quant_moe_id_batched(
3177        ctx,
3178        &scratch.silu_stacked,
3179        down_stacked,
3180        &scratch.selected_ids_buf,
3181        &mut scratch.down_out_stacked,
3182        tokens,
3183        top_k,
3184        top_k * inter, // outer: top_k * ffn floats per token
3185        inter,         // inner: ffn floats per slot
3186    )?;
3187    stage_end(t0, ctx, &MOE_BATCHED_DECODE_DOWN_US);
3188
3189    // 6. Per-token weighted sum across slots → moe_out [m, h]. Caller
3190    //    does residual += moe_out at the end of forward_layer.
3191    let t0 = stage_t0();
3192    B::weighted_sum_batched(
3193        ctx,
3194        &scratch.down_out_stacked,
3195        &scratch.weights_2d,
3196        &mut scratch.moe_out,
3197        tokens,
3198        top_k,
3199        h,
3200    )?;
3201    stage_end(t0, ctx, &MOE_BATCHED_DECODE_WSUM_US);
3202
3203    Ok(())
3204}
3205
3206/// Build a stub Linear<B> with the given shape but zero weights. Used to
3207/// fill the dense `gate_up_proj` / `down_proj` slots in `LlamaFamilyLayer`
3208/// for MoE models — those slots are never invoked because the MoE FFN
3209/// path runs through `moe_layer.experts` instead. The stub's only purpose
3210/// is to satisfy the struct's type signature with minimal memory cost.
3211fn stub_linear<B: Backend>(
3212    out_features: usize,
3213    in_features: usize,
3214) -> Box<dyn ferrum_quantization::Linear<B>> {
3215    // Zero-init: out_features * in_features f32. For a 30B-A3B layer
3216    // this is 2*768*2048 = 3.1M f32 = 12 MB → fine; per-layer overhead
3217    // ≈ 12 MB × 48 = 576 MB. Marginal vs the experts (~16 GB).
3218    let zeros = vec![0.0f32; out_features * in_features];
3219    Box::new(ferrum_quantization::DenseLinear::<B>::from_rows(
3220        &zeros,
3221        out_features,
3222        in_features,
3223    ))
3224}
3225
3226fn build_rope_cache<B: Backend>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
3227    let hd = cfg.head_dim;
3228    let half = hd / 2;
3229    let max = cfg.max_seq_len;
3230    let mut cos = vec![0.0f32; max * half];
3231    let mut sin = vec![0.0f32; max * half];
3232    for pos in 0..max {
3233        for i in 0..half {
3234            let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
3235            let angle = pos as f64 * freq;
3236            cos[pos * half + i] = angle.cos() as f32;
3237            sin[pos * half + i] = angle.sin() as f32;
3238        }
3239    }
3240    RopeCache {
3241        cos: B::from_slice(&cos),
3242        sin: B::from_slice(&sin),
3243    }
3244}