Skip to main content

ferrum_models/models/
llama_family.rs

1//! Llama-family decoder model as explicit code.
2//!
3//! Covers all "standard Llama-style GQA + SwiGLU + RoPE" decoders:
4//!   - Llama / Llama-2 / Llama-3  (no QK-norm)
5//!   - Qwen2 / Qwen2.5            (no QK-norm, structurally Llama)
6//!   - Qwen3                      (QK-norm per head, larger rope_theta)
7//!   - Mistral                    (sliding-window attention — not yet
8//!                                 supported on the forward path; will
9//!                                 require an `AttnConfig.sliding_window`
10//!                                 field + shader support in Phase D)
11//!
12//! Variant differences are controlled by `LlamaFamilyConfig::has_qk_norm`
13//! and RoPE theta. Weight loading accepts both fused (`qkv_proj`,
14//! `gate_up_proj`) and split (`q_proj`+`k_proj`+`v_proj`,
15//! `gate_proj`+`up_proj`) projection layouts — the loader (e.g.
16//! `CandleVarBuilderLoader`) fuses split weights on load so model code
17//! sees a uniform `qkv_proj` / `gate_up_proj` Linear.
18
19use std::collections::HashMap;
20use std::sync::atomic::AtomicU64;
21
22use ferrum_kernels::backend::{Backend, KvCache};
23
// Process-wide atomic counters, one `*_TIME_US` / `*_CALLS` pair per category
// (attention, QKR, matmul, norm, other). All start at zero.
// NOTE(review): the code that updates and reads these is outside this view —
// presumably `*_TIME_US` accumulates elapsed microseconds and `*_CALLS` counts
// invocations for profiling, and "QKR" refers to the fused `qk_norm_rope`
// step; confirm at the use sites.
static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
static QKR_TIME_US: AtomicU64 = AtomicU64::new(0);
static QKR_CALLS: AtomicU64 = AtomicU64::new(0);
static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0);
static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0);
static NORM_TIME_US: AtomicU64 = AtomicU64::new(0);
static NORM_CALLS: AtomicU64 = AtomicU64::new(0);
static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0);
static OTHER_CALLS: AtomicU64 = AtomicU64::new(0);
34use ferrum_quantization::{Linear, WeightLoader};
35use ferrum_types::Result;
36
37use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
38
/// Full architecture config for a Llama-family decoder (everything the model
/// code needs, not just the engine-facing subset in `LlmRuntimeConfig`).
///
/// Shared by every variant this file covers; the `*_from_def` constructors
/// (Llama, Qwen2, Qwen3, Mistral) differ only in `has_qk_norm` and in the
/// `rope_theta` fallback they apply.
#[derive(Clone, Debug, PartialEq)]
pub struct LlamaFamilyConfig {
    /// Hidden (residual-stream) width; taken from the definition's `hidden_size`.
    pub hidden_size: usize,
    /// MLP intermediate width; taken from the definition's `intermediate_size`.
    pub intermediate_size: usize,
    /// Number of attention (query) heads.
    pub num_heads: usize,
    /// Number of key/value heads. Falls back to `num_heads` when the
    /// definition does not specify `num_key_value_heads`.
    pub num_kv_heads: usize,
    /// Per-head dimension. Read from the `head_dim` extra param when present,
    /// otherwise `hidden_size / num_heads`.
    pub head_dim: usize,
    /// Number of decoder layers.
    pub num_layers: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Maximum sequence length; taken from the definition's
    /// `max_position_embeddings`.
    pub max_seq_len: usize,
    /// RMSNorm epsilon (narrowed to `f32` from the definition's `norm_eps`).
    pub rms_norm_eps: f32,
    /// RoPE base. Checkpoint value when set, otherwise the per-variant default
    /// chosen by the `*_from_def` constructor.
    pub rope_theta: f64,
    /// Whether the checkpoint has `q_norm` / `k_norm` per layer. All known
    /// Qwen3 checkpoints do; some derivatives may strip them.
    pub has_qk_norm: bool,
    /// Sliding-window attention size. `0` disables (full causal).
    /// Mistral v0.1 sets 4096; Mistral v0.2+ removed the limitation (0).
    pub sliding_window: usize,
}
60
61impl LlamaFamilyConfig {
62    pub fn to_runtime(&self) -> LlmRuntimeConfig {
63        LlmRuntimeConfig {
64            hidden_size: self.hidden_size,
65            num_layers: self.num_layers,
66            num_kv_heads: self.num_kv_heads,
67            head_dim: self.head_dim,
68            vocab_size: self.vocab_size,
69            max_seq_len: self.max_seq_len,
70        }
71    }
72
73    /// Build config from a `ModelDefinition`, shared field extraction.
74    /// Variant-specific constructors below set `has_qk_norm` and fall back
75    /// to different `rope_theta` defaults when the checkpoint doesn't set one.
76    fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
77        let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
78        let head_dim = def
79            .extra_params
80            .get("head_dim")
81            .and_then(|v| v.as_u64())
82            .map(|v| v as usize)
83            .unwrap_or(def.hidden_size / def.num_attention_heads);
84        // Mistral / Gemma: "sliding_window" may be null (v0.2+) or a positive
85        // integer (v0.1). Non-null value passes through; missing/null → 0.
86        let sliding_window = def
87            .extra_params
88            .get("sliding_window")
89            .and_then(|v| v.as_u64())
90            .map(|v| v as usize)
91            .unwrap_or(0);
92
93        LlamaFamilyConfigBase {
94            hidden_size: def.hidden_size,
95            intermediate_size: def.intermediate_size,
96            num_heads: def.num_attention_heads,
97            num_kv_heads,
98            head_dim,
99            num_layers: def.num_hidden_layers,
100            vocab_size: def.vocab_size,
101            max_seq_len: def.max_position_embeddings,
102            rms_norm_eps: def.norm_eps as f32,
103            rope_theta_opt: def.rope_theta,
104            sliding_window,
105        }
106    }
107
108    fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
109        Self {
110            hidden_size: b.hidden_size,
111            intermediate_size: b.intermediate_size,
112            num_heads: b.num_heads,
113            num_kv_heads: b.num_kv_heads,
114            head_dim: b.head_dim,
115            num_layers: b.num_layers,
116            vocab_size: b.vocab_size,
117            max_seq_len: b.max_seq_len,
118            rms_norm_eps: b.rms_norm_eps,
119            rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
120            has_qk_norm,
121            sliding_window: b.sliding_window,
122        }
123    }
124
125    /// Qwen3: QK-norm on, rope_theta default 1e6.
126    pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
127        Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
128    }
129
130    /// Llama / Llama-2 / Llama-3: no QK-norm; rope_theta varies by version
131    /// (10k for Llama-2, 500k for Llama-3.1+) — use the checkpoint value or
132    /// fall back to the most common modern value.
133    pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
134        Self::from_base(Self::from_def_base(def), 500_000.0, false)
135    }
136
137    /// Qwen2 / Qwen2.5: structurally Llama; no QK-norm; rope_theta default 1e6.
138    pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
139        Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
140    }
141
142    /// Mistral: no QK-norm; `rope_theta` commonly 10_000 (v0.1 / v0.2).
143    /// Picks up `sliding_window` from the checkpoint's config.json
144    /// (Mistral v0.1: 4096; Mistral v0.2+: 0 / null).
145    pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
146        Self::from_base(Self::from_def_base(def), 10_000.0, false)
147    }
148}
149
/// Variant-independent fields extracted from a `ModelDefinition` by
/// `LlamaFamilyConfig::from_def_base`.
///
/// Mirrors `LlamaFamilyConfig` except that `has_qk_norm` is absent and the
/// RoPE base is still optional; `LlamaFamilyConfig::from_base` supplies both.
struct LlamaFamilyConfigBase {
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    num_layers: usize,
    vocab_size: usize,
    max_seq_len: usize,
    rms_norm_eps: f32,
    /// RoPE base as given by the checkpoint; `None` means "use the variant's
    /// default".
    rope_theta_opt: Option<f64>,
    /// Sliding-window size from `extra_params`; `0` when missing or null.
    sliding_window: usize,
}
163
/// Per-layer weights. `Box<dyn Linear<B>>` means each projection can be
/// Dense / GPTQ / AWQ / GGUF without the surrounding code caring.
pub struct LlamaFamilyLayer<B: Backend> {
    /// Norm weight loaded from `input_layernorm.weight`.
    pub input_ln_w: B::Buffer,
    /// Fused Q/K/V projection, loaded from `self_attn.qkv_proj`.
    pub qkv_proj: Box<dyn Linear<B>>,
    /// QK-norm weight per head: `[head_dim]`. Optional for non-Qwen3 derivatives.
    pub q_norm_w: Option<B::Buffer>,
    /// K-side counterpart of `q_norm_w`, loaded from `self_attn.k_norm.weight`.
    /// `None` when the config has no QK-norm or the tensor failed to load.
    pub k_norm_w: Option<B::Buffer>,
    /// Attention output projection, loaded from `self_attn.o_proj`.
    pub o_proj: Box<dyn Linear<B>>,
    /// Norm weight loaded from `post_attention_layernorm.weight`.
    pub post_ln_w: B::Buffer,
    /// Fused gate/up MLP projection, loaded from `mlp.gate_up_proj`.
    pub gate_up_proj: Box<dyn Linear<B>>,
    /// MLP down projection, loaded from `mlp.down_proj`.
    pub down_proj: Box<dyn Linear<B>>,
}
177
/// Precomputed RoPE cos/sin tables (shape `[max_seq, head_dim / 2]` each).
///
/// Built once per model by `build_rope_cache` in the constructors.
pub struct RopeCache<B: Backend> {
    /// Cosine table.
    pub cos: B::Buffer,
    /// Sine table.
    pub sin: B::Buffer,
}
183
/// Reusable per-layer scratch buffers sized for `max_tokens` tokens of a
/// single forward pass (prefill or decode step).
///
/// Sized lazily on first use so tiny decode steps don't pay for prefill-sized
/// buffers. Grows monotonically when a larger prefill arrives.
///
/// Sizes quoted on the fields below are the element counts passed to the
/// backend allocator in `alloc`, with `q_dim = num_heads * head_dim` and
/// `kv_dim = num_kv_heads * head_dim`.
pub struct LlamaFamilyScratch<B: Backend> {
    /// Residual stream — wrapped in Option so decode_internal can
    /// `.take()` it without needing an alloc placeholder.
    ///
    /// Why this matters for graph capture: the old pattern was
    /// `mem::replace(&mut scratch.residual, B::alloc(1))` which creates a
    /// 1-element buffer at every decode step. When graph capture is on,
    /// that alloc-during-capture + drop-after-capture pair surfaces as
    /// cuMemFreeAsync(INVALID_VALUE) because the free tries to release a
    /// pointer the captured graph may still reference. Option::take leaves
    /// None and moves the real buffer into a local — no spurious alloc.
    pub residual: Option<B::Buffer>,
    /// Norm output, sized `max_tokens * hidden_size`.
    pub norm_out: B::Buffer,
    /// Fused QKV projection output, sized `max_tokens * (q_dim + 2 * kv_dim)`.
    pub qkv_out: B::Buffer,
    // ── Per-item scratch for batched decode path ──────────────────────
    // decode_batch_internal runs tokens=M batched ops for the GEMM-heavy
    // half (norm, qkv_proj, split_qkv, o_proj, post_norm, gate_up, silu,
    // down, residual_add) but must loop per-item for rope + KV append +
    // attention (each item has its own KV cache at a different kv_len).
    // These single-item buffers hold item i's slice during that loop.
    /// Item-scope q_buf slice, sized `q_dim`.
    pub q_single: B::Buffer,
    /// Item-scope k_buf slice, sized `kv_dim`.
    pub k_single: B::Buffer,
    /// Item-scope v_buf slice, sized `kv_dim`.
    pub v_single: B::Buffer,
    /// Single-item counterpart of `q_head_major`, sized `q_dim`.
    pub q_head_major_single: B::Buffer,
    /// Single-item counterpart of `k_head_major`, sized `kv_dim`.
    pub k_head_major_single: B::Buffer,
    /// Single-item counterpart of `v_head_major`, sized `kv_dim`.
    pub v_head_major_single: B::Buffer,
    /// Single-item counterpart of `attn_head_major_out`, sized `q_dim`.
    pub attn_head_major_single: B::Buffer,
    /// Single-item counterpart of `attn_flat`, sized `q_dim`.
    pub attn_flat_single: B::Buffer,
    /// Batched logits output, sized `max_tokens * vocab_size`. Used only
    /// in decode_batch; prefill/single-decode use the regular `logits`.
    pub batch_logits: B::Buffer,
    /// Token-major Q/K/V right after `split_qkv`. Stride: heads * hd per row.
    /// `q_buf` is sized `max_tokens * q_dim`.
    pub q_buf: B::Buffer,
    /// Token-major K, sized `max_tokens * kv_dim`.
    pub k_buf: B::Buffer,
    /// Token-major V, sized `max_tokens * kv_dim`.
    pub v_buf: B::Buffer,
    /// Head-major Q produced by `qk_norm_rope` — fed into `flash_attention`.
    pub q_head_major: B::Buffer,
    /// Head-major K/V staging — produced by `qk_norm_rope`, consumed by
    /// `kv_cache_append_head_major` (no reuse after append).
    pub k_head_major: B::Buffer,
    /// Head-major V staging; same sizing as `k_head_major`
    /// (`num_kv_heads * max_tokens * head_dim`).
    pub v_head_major: B::Buffer,
    /// Head-major attention output from `flash_attention`.
    pub attn_head_major_out: B::Buffer,
    /// Token-major attention output after `transpose_head_to_token`.
    pub attn_flat: B::Buffer,
    /// Attention output-projection result, sized `max_tokens * hidden_size`.
    pub o_proj_out: B::Buffer,
    /// Fused gate/up projection output, sized `max_tokens * 2 * intermediate_size`.
    pub gate_up_out: B::Buffer,
    /// Activation output, sized `max_tokens * intermediate_size`.
    pub silu_out: B::Buffer,
    /// MLP (down-projection) output, sized `max_tokens * hidden_size`.
    pub mlp_out: B::Buffer,
    /// Paged batched dispatch scratch (Phase 4b). Sized for
    /// `FERRUM_PAGED_MAX_SEQS × q_dim` so multi-seq decode can fan
    /// in M items' Q into a single buffer for one batched
    /// `paged_decode_attention(num_seqs=M)` call. `None` when paged
    /// mode is off.
    pub paged_batch_q: Option<B::Buffer>,
    /// Output-side companion of `paged_batch_q`; allocated together with it
    /// at the same size (`max_seqs * q_dim`).
    pub paged_batch_o: Option<B::Buffer>,
    /// Stacked per-seq block tables for batched paged dispatch.
    /// Layout: `[max_M, max_blocks_per_seq]` u32. Written
    /// host-side per decode_batch step.
    pub paged_batch_block_tables: Option<B::Buffer>,
    /// Stacked per-seq context lengths for batched paged dispatch
    /// (`[max_M]` u32).
    pub paged_batch_context_lens: Option<B::Buffer>,
    /// `max_blocks_per_seq` value baked into the stacked block_tables
    /// stride. Set when `paged_batch_block_tables` is allocated.
    pub paged_max_blocks_per_seq: usize,
    /// Last token's hidden state (`[h]`). For prefill this is populated via
    /// `copy_slice(residual, (seq_len-1)*h, ..)`; for decode `residual` already
    /// holds only 1 row so `last_hidden` is unused on that path.
    pub last_hidden: B::Buffer,
    /// Final-norm output for the last token (`[h]`).
    pub last_normed: B::Buffer,
    /// lm_head logits (`[vocab]`).
    pub logits: B::Buffer,
    /// The max tokens-per-step this scratch has been sized for.
    pub max_tokens: usize,
}
267
268impl<B: Backend> LlamaFamilyScratch<B> {
269    fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
270        let h = cfg.hidden_size;
271        let im = cfg.intermediate_size;
272        let q_dim = cfg.num_heads * cfg.head_dim;
273        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
274        let qkv_dim = q_dim + 2 * kv_dim;
275        let t = max_tokens;
276        Self {
277            residual: Some(B::alloc(t * h)),
278            norm_out: B::alloc(t * h),
279            qkv_out: B::alloc(t * qkv_dim),
280            q_buf: B::alloc(t * q_dim),
281            k_buf: B::alloc(t * kv_dim),
282            v_buf: B::alloc(t * kv_dim),
283            q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
284            k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
285            v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
286            attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
287            attn_flat: B::alloc(t * q_dim),
288            o_proj_out: B::alloc(t * h),
289            gate_up_out: B::alloc(t * 2 * im),
290            silu_out: B::alloc(t * im),
291            mlp_out: B::alloc(t * h),
292            last_hidden: B::alloc(h),
293            last_normed: B::alloc(h),
294            logits: B::alloc(cfg.vocab_size),
295            q_single: B::alloc(q_dim),
296            k_single: B::alloc(kv_dim),
297            v_single: B::alloc(kv_dim),
298            q_head_major_single: B::alloc(q_dim),
299            k_head_major_single: B::alloc(kv_dim),
300            v_head_major_single: B::alloc(kv_dim),
301            attn_head_major_single: B::alloc(q_dim),
302            attn_flat_single: B::alloc(q_dim),
303            batch_logits: B::alloc(t * cfg.vocab_size),
304            // Paged batched dispatch scratch. None until `enable_paged_batch`
305            // is called from `ensure_kv` once the model knows max_seqs +
306            // max_blocks_per_seq. This avoids paying the alloc cost when
307            // paged mode is off.
308            paged_batch_q: None,
309            paged_batch_o: None,
310            paged_batch_block_tables: None,
311            paged_batch_context_lens: None,
312            paged_max_blocks_per_seq: 0,
313            max_tokens: t,
314        }
315    }
316
317    /// Allocate scratch for batched paged dispatch (Phase 4b). Called
318    /// lazily from `ensure_kv` once paged mode is enabled and we know
319    /// the pool dimensions. Idempotent.
320    fn enable_paged_batch(
321        &mut self,
322        cfg: &LlamaFamilyConfig,
323        max_seqs: usize,
324        max_blocks_per_seq: usize,
325    ) {
326        if self.paged_batch_q.is_some() {
327            return;
328        }
329        let q_dim = cfg.num_heads * cfg.head_dim;
330        self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
331        self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
332        self.paged_batch_block_tables = Some(B::alloc_u32(max_seqs * max_blocks_per_seq));
333        self.paged_batch_context_lens = Some(B::alloc_u32(max_seqs));
334        self.paged_max_blocks_per_seq = max_blocks_per_seq;
335    }
336}
337
/// Llama-family model (Llama / Qwen2 / Qwen3 / Mistral) — decoder-only LLM,
/// one per (backend, weights) combination.
///
/// Holds all parameters, scratch space, RoPE cache, and per-sequence KV caches.
pub struct LlamaFamilyModel<B: Backend> {
    /// Full architecture config this model was built with.
    pub cfg: LlamaFamilyConfig,
    /// Engine-facing subset of `cfg`, derived via `LlamaFamilyConfig::to_runtime`.
    pub runtime_cfg: LlmRuntimeConfig,

    /// Token embedding table. `None` for backbone-only models (e.g. the
    /// Qwen3-TTS Talker, which embeds inputs externally and feeds via
    /// `prefill_from_embeds`).
    pub embed: Option<B::Buffer>,
    /// Decoder layers, `cfg.num_layers` of them, in order.
    pub layers: Vec<LlamaFamilyLayer<B>>,
    /// Final norm weight, loaded from `model.norm.weight`.
    pub final_norm_w: B::Buffer,
    /// LM output head. `None` for backbone-only models.
    pub lm_head: Option<Box<dyn Linear<B>>>,

    /// Precomputed RoPE cos/sin tables.
    pub rope: RopeCache<B>,
    /// Reusable forward-pass scratch; starts decode-sized (1 token) and is
    /// regrown by `ensure_scratch`.
    pub scratch: LlamaFamilyScratch<B>,

    /// Per-sequence KV caches, one `Vec<KvCache<B>>` of length `num_layers`.
    ///
    /// Two layouts overlay this same map:
    /// - **Contiguous mode** (default): each cache holds its own
    ///   `[num_kv_heads, capacity, head_dim]` k/v buffers.
    /// - **Paged mode** (`FERRUM_METAL_PAGED_KV=1`): k/v are unused
    ///   placeholders; the real K/V live in [`Self::paged_pools`] and
    ///   the cache's `block_table` + `context_lens` index into them.
    pub kv_caches: HashMap<String, Vec<KvCache<B>>>,
    /// Free pool of pre-allocated KV cache slots. Released caches return
    /// here instead of being dropped, so their device pointers stay valid
    /// across requests — critical for graph capture (pointers baked into
    /// the captured graph would otherwise dangle).
    kv_free_pool: Vec<Vec<KvCache<B>>>,

    // ── Paged-KV multi-seq state (Phase 4) ─────────────────────────────
    //
    // Only populated when `FERRUM_METAL_PAGED_KV=1`. When set, every
    // `kv_caches` entry becomes a "view" into the shared pool: its
    // `k` / `v` buffers are placeholders; reads / writes go through
    // `paged_pools[layer].k` / `.v` indexed via the cache's
    // `block_table`. Multiple cache_ids share the same pool, with
    // disjoint physical block sets owned by `paged_block_alloc`.
    //
    /// Shared K/V pools, one per layer. Sized at model load time for the
    /// configured `MAX_RUNNING_REQUESTS × max_blocks_per_seq` blocks.
    // NOTE(review): the visible `ensure_kv` allocates these lazily on the
    // first paged call, sized `FERRUM_PAGED_MAX_SEQS × max_blocks_per_seq` —
    // the "model load time" / `MAX_RUNNING_REQUESTS` wording above may be stale.
    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
    /// Block allocator hands out physical block indices from the pool.
    /// `Mutex` because the engine can call `ensure_kv` / `release_kv`
    /// from multiple threads in concurrent serving.
    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,

    // ── Graph capture state (CUDA only; harmless no-op on other backends) ──
    /// Count of eager decode steps run so far. After `GRAPH_WARMUP`, the
    /// next step captures the decode flow as a graph.
    graph_warmup: usize,
    /// True if capture was attempted but failed (e.g. backend doesn't
    /// support graph capture). Stops further attempts, falls back to eager.
    graph_capture_failed: bool,
}
397
398impl<B: Backend> LlamaFamilyModel<B> {
399    /// Build a Qwen3 model from weights provided by the loader.
400    ///
401    /// The loader decides per-projection whether to instantiate DenseLinear,
402    /// GptqLinear, etc. — this code doesn't care.
403    pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
404        // Invalidate any graph from a previously-loaded model. The captured
405        // graph references the old model's scratch buffers; a fresh model
406        // gets fresh scratch, so reusing the graph would read/write freed
407        // pointers. Matters for test suites where multiple models coexist.
408        {
409            let mut ctx = B::new_context();
410            B::reset_graph(&mut ctx);
411        }
412        let rope = build_rope_cache::<B>(&cfg);
413        let scratch = LlamaFamilyScratch::alloc(&cfg, 1); // decode-sized; prefill resizes
414
415        // Embedding: plain tensor (no projection math, just lookup).
416        let embed = loader.load_tensor("model.embed_tokens.weight")?;
417
418        // Per-layer weights.
419        let mut layers = Vec::with_capacity(cfg.num_layers);
420        for li in 0..cfg.num_layers {
421            let prefix = format!("model.layers.{li}");
422            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
423            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
424            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
425            let post_ln_w =
426                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
427            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
428            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
429
430            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
431                let q = loader
432                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
433                    .ok();
434                let k = loader
435                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
436                    .ok();
437                (q, k)
438            } else {
439                (None, None)
440            };
441
442            layers.push(LlamaFamilyLayer {
443                input_ln_w,
444                qkv_proj,
445                q_norm_w,
446                k_norm_w,
447                o_proj,
448                post_ln_w,
449                gate_up_proj,
450                down_proj,
451            });
452        }
453
454        let final_norm_w = loader.load_tensor("model.norm.weight")?;
455
456        // LM head: either dedicated `lm_head.weight` or tied to embedding.
457        // Many models (Qwen3-4B, Llama-3.2-1B, some Qwen2.5) use TIED
458        // embeddings — lm_head shares weights with model.embed_tokens. When
459        // no dedicated lm_head tensor exists, re-load the embed tensor as a
460        // DenseLinear. This duplicates the buffer (memory cost = vocab*h*2
461        // bytes, e.g. ~770MB for Qwen3-4B) but keeps the Linear trait's
462        // owned-weights invariant. Sharing via Arc is a future optimisation.
463        let lm_head = if loader.has_tensor("lm_head.weight") {
464            loader.load_linear("lm_head")?
465        } else {
466            tracing::info!(
467                "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
468            );
469            let as_linear = loader.load_linear("model.embed_tokens")?;
470            // Sanity check: shape must be [vocab, hidden].
471            if as_linear.out_features() != cfg.vocab_size
472                || as_linear.in_features() != cfg.hidden_size
473            {
474                return Err(ferrum_types::FerrumError::model(format!(
475                    "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
476                    as_linear.out_features(),
477                    as_linear.in_features(),
478                    cfg.vocab_size,
479                    cfg.hidden_size
480                )));
481            }
482            as_linear
483        };
484
485        let runtime_cfg = cfg.to_runtime();
486        Ok(Self {
487            cfg,
488            runtime_cfg,
489            embed: Some(embed),
490            layers,
491            final_norm_w,
492            lm_head: Some(lm_head),
493            rope,
494            scratch,
495            kv_caches: HashMap::new(),
496            kv_free_pool: Vec::new(),
497            paged_pools: None,
498            paged_block_alloc: None,
499            graph_warmup: 0,
500            graph_capture_failed: false,
501        })
502    }
503
504    /// Build a backbone-only Qwen3 transformer stack (no embed, no lm_head).
505    ///
506    /// Intended for composing the transformer inside a larger model where
507    /// embedding and output-head logic differs from the standard LLM path —
508    /// e.g. Qwen3-TTS Talker uses dual text/codec embeddings with a projection
509    /// MLP, and a codec_head output. The caller drives forward via
510    /// `prefill_from_embeds` / `decode_from_embed`.
511    ///
512    /// Loader must provide: per-layer weights under `model.layers.{i}.*` and
513    /// the final `model.norm.weight`. `model.embed_tokens` and `lm_head`
514    /// are NOT read.
515    pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
516        // See `new` — invalidate stale graph referring to prior model's scratch.
517        {
518            let mut ctx = B::new_context();
519            B::reset_graph(&mut ctx);
520        }
521        let rope = build_rope_cache::<B>(&cfg);
522        let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
523
524        let mut layers = Vec::with_capacity(cfg.num_layers);
525        for li in 0..cfg.num_layers {
526            let prefix = format!("model.layers.{li}");
527            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
528            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
529            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
530            let post_ln_w =
531                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
532            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
533            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
534
535            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
536                let q = loader
537                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
538                    .ok();
539                let k = loader
540                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
541                    .ok();
542                (q, k)
543            } else {
544                (None, None)
545            };
546
547            layers.push(LlamaFamilyLayer {
548                input_ln_w,
549                qkv_proj,
550                q_norm_w,
551                k_norm_w,
552                o_proj,
553                post_ln_w,
554                gate_up_proj,
555                down_proj,
556            });
557        }
558
559        let final_norm_w = loader.load_tensor("model.norm.weight")?;
560
561        let runtime_cfg = cfg.to_runtime();
562        Ok(Self {
563            cfg,
564            runtime_cfg,
565            embed: None,
566            layers,
567            final_norm_w,
568            lm_head: None,
569            rope,
570            scratch,
571            kv_caches: HashMap::new(),
572            kv_free_pool: Vec::new(),
573            paged_pools: None,
574            paged_block_alloc: None,
575            graph_warmup: 0,
576            graph_capture_failed: false,
577        })
578    }
579
580    /// Grow scratch buffers if `tokens` exceeds the current sizing.
581    pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
582        if self.scratch.max_tokens < tokens {
583            // Any captured decode graph holds pointers to the old scratch
584            // buffers; those are about to be freed. Invalidate first so the
585            // next decode falls back to eager + re-captures with fresh ptrs.
586            // Critical for multi-turn chat (turn N+1's prefill may grow scratch).
587            {
588                let mut ctx = B::new_context();
589                B::reset_graph(&mut ctx);
590            }
591            self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
592            self.graph_warmup = 0;
593            self.graph_capture_failed = false;
594        }
595    }
596
597    /// Ensure per-layer KV caches exist for `cache_id`, pre-allocated to
598    /// `max_seq_len` slots per head. Enables the in-place
599    /// `kv_cache_append_head_major` path — no realloc per layer.
600    pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
601        if self.kv_caches.contains_key(cache_id) {
602            return;
603        }
604        let nkv = self.cfg.num_kv_heads;
605        let hd = self.cfg.head_dim;
606        // KV capacity defaults to a chat-friendly 4096 to keep the working
607        // set sane on a 32 GB Mac (a 48-layer / 4-kv-head / 128-head_dim
608        // model spends ~786 MB on 4096 KV slots, vs ~6 GB on the model's
609        // declared 32K which would push the 17 GB Qwen3-30B-A3B model into
610        // swap). `FERRUM_KV_CAPACITY=N` overrides; clamp to the model's
611        // declared max so we never lie to the model about its window.
612        let model_max = self.cfg.max_seq_len;
613        const DEFAULT_KV_CAPACITY: usize = 4096;
614        let max = std::env::var("FERRUM_KV_CAPACITY")
615            .ok()
616            .and_then(|s| s.parse::<usize>().ok())
617            .map(|cap| cap.min(model_max))
618            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY));
619
620        // Paged-KV mode: `FERRUM_METAL_PAGED_KV=1` switches every cache
621        // for this model into block-table-indirect layout. Kernels from
622        // PR #68 (decode read) + PR #69 (decode write) handle the
623        // indirect addressing; the LlamaFamily decode path below picks
624        // them up automatically by checking `cache.block_size > 0`.
625        //
626        // Pool sizing: round capacity up to a multiple of block_size,
627        // identity-assign logical→physical block. Memory footprint is
628        // the same as contiguous (within block_size rounding); the
629        // benefit only shows up under multi-seq sharing in Phase 4+.
630        let paged = std::env::var("FERRUM_METAL_PAGED_KV")
631            .map(|v| v == "1")
632            .unwrap_or(false);
633        const PAGED_BLOCK_SIZE: usize = 16;
634
635        // Phase 4 shared-pool sizing. The pool sees ALL concurrent
636        // sequences; per-cache_id state just owns indices into it.
637        let max_seqs = std::env::var("FERRUM_PAGED_MAX_SEQS")
638            .ok()
639            .and_then(|s| s.parse::<usize>().ok())
640            .unwrap_or(16);
641        let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
642        let total_pool_blocks = max_seqs * max_blocks_per_seq;
643
644        // Lazy-allocate the shared paged pools on the FIRST paged
645        // ensure_kv call. Pools are big — for Llama-8B (8 kv_heads,
646        // head_dim=128) at 16 seqs × 256 blocks × 16 slots = 65536 KV
647        // slots: 65536 * 8 * 128 * 4 = 256 MB per layer × 32 layers
648        // = 8 GB total. Sized this large only because `max_seqs=16`
649        // is the default; lower it via env to shrink.
650        if paged && self.paged_pools.is_none() {
651            let mut pools = Vec::with_capacity(self.cfg.num_layers);
652            for _ in 0..self.cfg.num_layers {
653                let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
654                pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
655            }
656            self.paged_pools = Some(pools);
657            self.paged_block_alloc = Some(std::sync::Mutex::new(
658                crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
659            ));
660        }
661        // Phase 4b: ensure batched-dispatch scratch is allocated whenever
662        // paged is on. Idempotent — re-init is a no-op if already
663        // sized. Has to live outside the `paged_pools.is_none()` branch
664        // because `ensure_scratch` may have replaced the scratch struct
665        // since the pools were first allocated.
666        if paged {
667            self.scratch
668                .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
669        }
670
671        // Try pool first — reused buffers have stable device pointers,
672        // so a captured decode graph can be replayed for this request too.
673        let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
674            (0..self.cfg.num_layers)
675                .map(|_| {
676                    if paged {
677                        // Paged mode: cache holds metadata only. K/V
678                        // are 1-element placeholders (allocated cheaply
679                        // since Backend::alloc requires a non-zero
680                        // size on most backends). The real data lives
681                        // in `self.paged_pools[li].{k,v}`.
682                        let mut block_table = B::alloc_u32(max_blocks_per_seq);
683                        let mut context_lens = B::alloc_u32(1);
684                        let mut bt_ctx = B::new_context();
685                        B::write_u32(&mut bt_ctx, &mut context_lens, &[0u32]);
686                        B::sync(&mut bt_ctx);
687                        KvCache {
688                            k: B::alloc(1),
689                            v: B::alloc(1),
690                            len: 0,
691                            capacity: max_blocks_per_seq * PAGED_BLOCK_SIZE,
692                            num_kv_heads: nkv,
693                            head_dim: hd,
694                            block_size: PAGED_BLOCK_SIZE,
695                            block_table: Some(block_table),
696                            context_lens: Some(context_lens),
697                            paged_block_indices: Vec::new(),
698                        }
699                    } else {
700                        KvCache {
701                            k: B::alloc(nkv * max * hd),
702                            v: B::alloc(nkv * max * hd),
703                            len: 0,
704                            capacity: max,
705                            num_kv_heads: nkv,
706                            head_dim: hd,
707                            block_size: 0,
708                            block_table: None,
709                            context_lens: None,
710                            paged_block_indices: Vec::new(),
711                        }
712                    }
713                })
714                .collect()
715        });
716
717        // Allocate physical blocks for THIS cache_id from the shared
718        // pool. We allocate all `max_blocks_per_seq` upfront for
719        // simplicity (matches contig's "pre-alloc to capacity"
720        // semantics); a smarter Phase 4b can grow on demand to save
721        // pool occupancy.
722        if paged {
723            let alloc_arc = self
724                .paged_block_alloc
725                .as_ref()
726                .expect("paged_block_alloc must be initialised when paged=true");
727            // Recover from a previously-poisoned mutex instead of panicking
728            // (poison just means a prior holder panicked; the BlockAllocator
729            // state is still intact since allocate_n is fail-safe).
730            let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
731            let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
732                Ok(idx) => idx,
733                Err(e) => {
734                    // Pool exhaustion is a back-pressure signal, not a crash.
735                    // Drop the lock, return the cache to the free pool, and
736                    // bail before inserting it into kv_caches. The downstream
737                    // call will then fail with a clean per-request error
738                    // ("ensure_kv must be called before ...") instead of
739                    // dragging every other in-flight request down with it.
740                    drop(alloc);
741                    self.kv_free_pool.push(caches);
742                    eprintln!(
743                        "[ferrum] paged KV pool exhausted on ensure_kv for \
744                         cache_id={cache_id:?}: {e}. Increase \
745                         FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
746                         throttle concurrent requests.",
747                    );
748                    return;
749                }
750            };
751            // Write the block table to each layer's cache. All layers
752            // share the same logical→physical mapping for this seq.
753            // Also stash the host-side index list so release_kv can
754            // return them to the allocator without a device readback.
755            let mut padded = block_indices.clone();
756            padded.resize(max_blocks_per_seq, 0);
757            let mut ctx_tmp = B::new_context();
758            for c in caches.iter_mut() {
759                if let Some(bt) = c.block_table.as_mut() {
760                    B::write_u32(&mut ctx_tmp, bt, &padded);
761                }
762                c.paged_block_indices = block_indices.clone();
763            }
764            B::sync(&mut ctx_tmp);
765        }
766
767        // Reset logical length; buffers stay. No need to zero the memory —
768        // the kv_cache_append writes new K/V in place, and attention only
769        // reads up to `cache_len`.
770        for c in caches.iter_mut() {
771            c.len = 0;
772            if let Some(cl) = c.context_lens.as_mut() {
773                let mut ctx_tmp = B::new_context();
774                B::write_u32(&mut ctx_tmp, cl, &[0u32]);
775                B::sync(&mut ctx_tmp);
776            }
777        }
778        self.kv_caches.insert(cache_id.to_string(), caches);
779    }
780
781    /// Run one transformer layer. Mutates `residual` in place.
782    ///
783    /// `pos_offset` is the absolute position of token 0 in this batch
784    /// (decode: `pos`; prefill: 0). `tokens` is the batch size.
785    #[allow(clippy::too_many_arguments)]
786    pub(crate) fn forward_layer(
787        &mut self,
788        ctx: &mut B::Context,
789        li: usize,
790        cache_id: &str,
791        residual: &mut B::Buffer,
792        pos_offset: usize,
793        tokens: usize,
794    ) {
795        let layer = &self.layers[li];
796        let cfg = &self.cfg;
797        let h = cfg.hidden_size;
798        let nh = cfg.num_heads;
799        let nkv = cfg.num_kv_heads;
800        let hd = cfg.head_dim;
801        let im = cfg.intermediate_size;
802        let eps = cfg.rms_norm_eps;
803        let q_dim = nh * hd;
804        let kv_dim = nkv * hd;
805
806        // 1. Input RMSNorm
807        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
808            B::sync(ctx);
809            Some(std::time::Instant::now())
810        } else {
811            None
812        };
813        B::rms_norm(
814            ctx,
815            residual,
816            &layer.input_ln_w,
817            eps,
818            &mut self.scratch.norm_out,
819            tokens,
820            h,
821        );
822        if let Some(t0) = _t0 {
823            B::sync(ctx);
824            NORM_TIME_US.fetch_add(
825                t0.elapsed().as_micros() as u64,
826                std::sync::atomic::Ordering::Relaxed,
827            );
828            NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
829        }
830
831        // 2. Fused QKV projection (Linear dispatches to Dense/GPTQ/AWQ/GGUF)
832        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
833            B::sync(ctx);
834            Some(std::time::Instant::now())
835        } else {
836            None
837        };
838        layer.qkv_proj.forward(
839            ctx,
840            &self.scratch.norm_out,
841            &mut self.scratch.qkv_out,
842            tokens,
843        );
844        if let Some(t0) = _t0 {
845            B::sync(ctx);
846            MATMUL_TIME_US.fetch_add(
847                t0.elapsed().as_micros() as u64,
848                std::sync::atomic::Ordering::Relaxed,
849            );
850            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
851        }
852
853        // 3-5. Fused split-QKV + QK-norm + RoPE + cache-write.
854        //
855        // Single Metal dispatch replaces the (split_qkv → 3× qk_norm_rope
856        // → kv_cache_append_head_major) five-launch chain on the decode
857        // hot path. Reads qkv_out once, writes Q to head-major scratch
858        // and K/V straight into the pre-allocated KV cache slot at
859        // `cache_len + tok`. Saves 4 dispatches per layer when the
860        // backend implements the fused kernel; CPU and other backends
861        // keep using the unfused chain via the Unsupported fallbacks.
862        //
863        // qk_mode: 1 = norm + RoPE (Qwen3); 2 = RoPE only (Llama).
864        // V always passes apply_norm=0.
865        let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
866        let dummy = &layer.input_ln_w;
867        let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
868        let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);
869
870        // Grab the per-layer KV cache up front so the deepest fusion can
871        // write K/V straight into it.
872        //
873        // Paged mode: also need this layer's shared pool buffers
874        // (self.paged_pools[li]). The pool is a separate field from
875        // kv_caches, so we take a raw pointer to its (k, v) here while
876        // we still hold &mut self, then deref via unsafe inside the
877        // paged dispatch below. Safety: paged_pools is allocated once
878        // and never resized; we don't touch self.paged_pools while the
879        // pointer is in use.
880        let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
881            if let Some(pools) = self.paged_pools.as_mut() {
882                let pool = &mut pools[li];
883                Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
884            } else {
885                None
886            };
887        let caches = self
888            .kv_caches
889            .get_mut(cache_id)
890            .expect("ensure_kv must be called before forward_layer");
891        let cache = &mut caches[li];
892        let cache_len_before = cache.len;
893        let cache_capacity = cache.capacity;
894
895        // Defense in depth: refuse to write past the KV buffer. The
896        // graceful path is the caller pre-checking via `kv_capacity()`
897        // and either compacting or refusing the request; this panic only
898        // fires when that contract is broken (and silent overflow would
899        // otherwise corrupt the cache + adjacent allocations).
900        if cache_len_before + tokens > cache_capacity {
901            panic!(
902                "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
903                cache_len_before + tokens
904            );
905        }
906
907        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
908            B::sync(ctx);
909            Some(std::time::Instant::now())
910        } else {
911            None
912        };
913        // Paged-KV path: when the cache was allocated with paged
914        // metadata (`block_size > 0`), use the paged write kernel
915        // which fans out into the block pool via `block_table`.
916        // Falls back to contiguous if Backend doesn't implement it.
917        let used_qkv_into_cache = if cache.block_size > 0 {
918            let bt = cache
919                .block_table
920                .as_ref()
921                .expect("paged cache missing block_table");
922            let num_blocks_per_seq = cache.capacity / cache.block_size;
923            // Paged mode: K/V live in the shared pool, not cache.k/.v.
924            let (pool_k_ptr, pool_v_ptr) =
925                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
926            // SAFETY: paged_pools is allocated once and never resized;
927            // we do not touch self.paged_pools concurrently.
928            let pool_k = unsafe { &mut *pool_k_ptr };
929            let pool_v = unsafe { &mut *pool_v_ptr };
930            B::split_qkv_norm_rope_into_paged_cache(
931                ctx,
932                &self.scratch.qkv_out,
933                0, // qkv_byte_offset: single-seq dispatch reads from start
934                q_norm_w,
935                k_norm_w,
936                &self.rope.cos,
937                &self.rope.sin,
938                &mut self.scratch.q_head_major,
939                0, // q_out_byte_offset: writes to start of head-major scratch
940                pool_k,
941                pool_v,
942                bt,
943                tokens,
944                nh,
945                nkv,
946                hd,
947                pos_offset,
948                eps,
949                qk_mode,
950                cache_len_before,
951                cache.block_size,
952                num_blocks_per_seq,
953            )
954            .is_ok()
955        } else {
956            B::split_qkv_norm_rope_into_cache(
957                ctx,
958                &self.scratch.qkv_out,
959                q_norm_w,
960                k_norm_w,
961                &self.rope.cos,
962                &self.rope.sin,
963                &mut self.scratch.q_head_major,
964                &mut cache.k,
965                &mut cache.v,
966                tokens,
967                nh,
968                nkv,
969                hd,
970                pos_offset,
971                eps,
972                qk_mode,
973                cache_len_before,
974                cache_capacity,
975            )
976            .is_ok()
977        };
978        if !used_qkv_into_cache {
979            // Fallback 1: fused split-QKV-norm-rope to head-major scratch
980            // (PR #47 path).
981            let used_fused_qkv = B::split_qkv_norm_rope(
982                ctx,
983                &self.scratch.qkv_out,
984                q_norm_w,
985                k_norm_w,
986                &self.rope.cos,
987                &self.rope.sin,
988                &mut self.scratch.q_head_major,
989                &mut self.scratch.k_head_major,
990                &mut self.scratch.v_head_major,
991                tokens,
992                nh,
993                nkv,
994                hd,
995                pos_offset,
996                eps,
997                qk_mode,
998            )
999            .is_ok();
1000            if !used_fused_qkv {
1001                // Fallback 2: original four-launch chain.
1002                B::split_qkv(
1003                    ctx,
1004                    &self.scratch.qkv_out,
1005                    &mut self.scratch.q_buf,
1006                    &mut self.scratch.k_buf,
1007                    &mut self.scratch.v_buf,
1008                    tokens,
1009                    q_dim,
1010                    kv_dim,
1011                );
1012                B::qk_norm_rope(
1013                    ctx,
1014                    &self.scratch.q_buf,
1015                    q_norm_w,
1016                    &self.rope.cos,
1017                    &self.rope.sin,
1018                    &mut self.scratch.q_head_major,
1019                    tokens,
1020                    nh,
1021                    hd,
1022                    pos_offset,
1023                    eps,
1024                    qk_mode,
1025                );
1026                B::qk_norm_rope(
1027                    ctx,
1028                    &self.scratch.k_buf,
1029                    k_norm_w,
1030                    &self.rope.cos,
1031                    &self.rope.sin,
1032                    &mut self.scratch.k_head_major,
1033                    tokens,
1034                    nkv,
1035                    hd,
1036                    pos_offset,
1037                    eps,
1038                    qk_mode,
1039                );
1040                B::qk_norm_rope(
1041                    ctx,
1042                    &self.scratch.v_buf,
1043                    dummy,
1044                    &self.rope.cos,
1045                    &self.rope.sin,
1046                    &mut self.scratch.v_head_major,
1047                    tokens,
1048                    nkv,
1049                    hd,
1050                    pos_offset,
1051                    eps,
1052                    0,
1053                );
1054            }
1055            B::kv_cache_append_head_major(
1056                ctx,
1057                &mut cache.k,
1058                &mut cache.v,
1059                cache.len,
1060                cache.capacity,
1061                &self.scratch.k_head_major,
1062                &self.scratch.v_head_major,
1063                tokens,
1064                nkv,
1065                hd,
1066            );
1067        }
1068        if let Some(t0) = _t0 {
1069            B::sync(ctx);
1070            QKR_TIME_US.fetch_add(
1071                t0.elapsed().as_micros() as u64,
1072                std::sync::atomic::Ordering::Relaxed,
1073            );
1074            QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1075        }
1076        cache.len += tokens;
1077        let kv_len = cache.len;
1078        let kv_stride = cache.capacity;
1079
1080        // 6. Flash attention.
1081        //    Paged path: when the cache uses block layout, dispatch the
1082        //    paged_decode_attention kernel; for q_len > 1 (prefill),
1083        //    iterate token-by-token (kernel only handles q_len=1 right
1084        //    now — Phase 4 will add a paged Q-tiled path).
1085        //    Contiguous path: existing flash_attention dispatch.
1086        let _attn_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1087            B::sync(ctx);
1088            Some(std::time::Instant::now())
1089        } else {
1090            None
1091        };
1092        if cache.block_size > 0 {
1093            let bt = cache
1094                .block_table
1095                .as_ref()
1096                .expect("paged cache missing block_table");
1097            let cl_buf = cache
1098                .context_lens
1099                .as_mut()
1100                .expect("paged cache missing context_lens");
1101            let num_blocks_per_seq = cache.capacity / cache.block_size;
1102            // Paged mode: K/V come from the shared pool.
1103            let (pool_k_ptr, pool_v_ptr) =
1104                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
1105            // SAFETY: same as the write-side above; pool buffers are
1106            // allocated-once and never moved while we hold the pointer.
1107            let pool_k = unsafe { &*pool_k_ptr };
1108            let pool_v = unsafe { &*pool_v_ptr };
1109            // Single dispatch handles both decode (q_len=1) and causal
1110            // prefill (q_len>1). The kernel computes per-token causal
1111            // limit as `context_lens[seq] - (q_len - 1 - q_token_idx)`,
1112            // so we set context_lens to the FINAL kv_len after this
1113            // batch's writes.
1114            let final_kv_len = cache.len as u32;
1115            B::write_u32(ctx, cl_buf, &[final_kv_len]);
1116            B::paged_decode_attention(
1117                ctx,
1118                &self.scratch.q_head_major,
1119                pool_k,
1120                pool_v,
1121                &mut self.scratch.attn_head_major_out,
1122                bt,
1123                cl_buf,
1124                1, // num_seqs (single-seq dispatch; multi-seq is fan-in via forward_layer_batched, Phase 4b)
1125                nh,
1126                nkv,
1127                hd,
1128                cache.block_size,
1129                num_blocks_per_seq,
1130                tokens, // q_len
1131            )
1132            .expect("paged_decode_attention");
1133        } else {
1134            //    `causal` is always true for decoder-only LLMs — every query must
1135            //    mask out future tokens. Sliding-window models (Mistral v0.1) narrow
1136            //    the lower bound via `sliding_window`.
1137            let attn_cfg = ferrum_kernels::backend::AttnConfig {
1138                num_heads: nh,
1139                num_kv_heads: nkv,
1140                head_dim: hd,
1141                causal: true,
1142                scale: 1.0 / (hd as f32).sqrt(),
1143                kv_seq_stride: kv_stride,
1144                sliding_window: cfg.sliding_window,
1145            };
1146            B::flash_attention(
1147                ctx,
1148                &self.scratch.q_head_major,
1149                &cache.k,
1150                &cache.v,
1151                &mut self.scratch.attn_head_major_out,
1152                1,
1153                tokens,
1154                kv_len,
1155                pos_offset,
1156                &attn_cfg,
1157            );
1158        }
1159        if let Some(t0) = _attn_t0 {
1160            B::sync(ctx);
1161            ATTN_TIME_US.fetch_add(
1162                t0.elapsed().as_micros() as u64,
1163                std::sync::atomic::Ordering::Relaxed,
1164            );
1165            ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1166        }
1167
1168        // 7. Untranspose head-major → token-major for O-proj input.
1169        //
1170        // For tokens=1 the head-major and token-major layouts collapse
1171        // to the same flat [heads * head_dim] vector, so the dispatch is
1172        // an identity memcpy — skip it and point o_proj at the
1173        // head-major buffer directly.
1174        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1175            B::sync(ctx);
1176            Some(std::time::Instant::now())
1177        } else {
1178            None
1179        };
1180        let attn_token_major = if tokens == 1 {
1181            &self.scratch.attn_head_major_out
1182        } else {
1183            B::transpose_head_to_token(
1184                ctx,
1185                &self.scratch.attn_head_major_out,
1186                &mut self.scratch.attn_flat,
1187                tokens,
1188                nh,
1189                hd,
1190            );
1191            &self.scratch.attn_flat
1192        };
1193        if let Some(t0) = _t0 {
1194            B::sync(ctx);
1195            OTHER_TIME_US.fetch_add(
1196                t0.elapsed().as_micros() as u64,
1197                std::sync::atomic::Ordering::Relaxed,
1198            );
1199            OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1200        }
1201
1202        // 8. O projection
1203        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1204            B::sync(ctx);
1205            Some(std::time::Instant::now())
1206        } else {
1207            None
1208        };
1209        layer
1210            .o_proj
1211            .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1212        if let Some(t0) = _t0 {
1213            B::sync(ctx);
1214            MATMUL_TIME_US.fetch_add(
1215                t0.elapsed().as_micros() as u64,
1216                std::sync::atomic::Ordering::Relaxed,
1217            );
1218            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1219        }
1220
1221        // 9. Fused residual-add + post-attention RMSNorm.
1222        //    Writes the new residual back into `residual` and the normed
1223        //    value into `norm_out`.
1224        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1225            B::sync(ctx);
1226            Some(std::time::Instant::now())
1227        } else {
1228            None
1229        };
1230        B::fused_add_rms_norm(
1231            ctx,
1232            residual,
1233            &self.scratch.o_proj_out,
1234            &layer.post_ln_w,
1235            eps,
1236            &mut self.scratch.norm_out,
1237            tokens,
1238            h,
1239        );
1240        if let Some(t0) = _t0 {
1241            B::sync(ctx);
1242            NORM_TIME_US.fetch_add(
1243                t0.elapsed().as_micros() as u64,
1244                std::sync::atomic::Ordering::Relaxed,
1245            );
1246            NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1247        }
1248
1249        // 10. Fused gate+up projection
1250        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1251            B::sync(ctx);
1252            Some(std::time::Instant::now())
1253        } else {
1254            None
1255        };
1256        layer.gate_up_proj.forward(
1257            ctx,
1258            &self.scratch.norm_out,
1259            &mut self.scratch.gate_up_out,
1260            tokens,
1261        );
1262        if let Some(t0) = _t0 {
1263            B::sync(ctx);
1264            MATMUL_TIME_US.fetch_add(
1265                t0.elapsed().as_micros() as u64,
1266                std::sync::atomic::Ordering::Relaxed,
1267            );
1268            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1269        }
1270
1271        // 11. SwiGLU: silu(gate) * up
1272        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1273            B::sync(ctx);
1274            Some(std::time::Instant::now())
1275        } else {
1276            None
1277        };
1278        B::fused_silu_mul_split(
1279            ctx,
1280            &self.scratch.gate_up_out,
1281            &mut self.scratch.silu_out,
1282            tokens,
1283            im,
1284        );
1285        if let Some(t0) = _t0 {
1286            B::sync(ctx);
1287            OTHER_TIME_US.fetch_add(
1288                t0.elapsed().as_micros() as u64,
1289                std::sync::atomic::Ordering::Relaxed,
1290            );
1291            OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1292        }
1293
1294        // 12. Down projection
1295        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1296            B::sync(ctx);
1297            Some(std::time::Instant::now())
1298        } else {
1299            None
1300        };
1301        layer.down_proj.forward(
1302            ctx,
1303            &self.scratch.silu_out,
1304            &mut self.scratch.mlp_out,
1305            tokens,
1306        );
1307        if let Some(t0) = _t0 {
1308            B::sync(ctx);
1309            MATMUL_TIME_US.fetch_add(
1310                t0.elapsed().as_micros() as u64,
1311                std::sync::atomic::Ordering::Relaxed,
1312            );
1313            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1314        }
1315
1316        // 13. Final residual add
1317        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1318            B::sync(ctx);
1319            Some(std::time::Instant::now())
1320        } else {
1321            None
1322        };
1323        B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
1324        if let Some(t0) = _t0 {
1325            B::sync(ctx);
1326            OTHER_TIME_US.fetch_add(
1327                t0.elapsed().as_micros() as u64,
1328                std::sync::atomic::Ordering::Relaxed,
1329            );
1330            OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1331        }
1332    }
1333
1334    /// Multi-position decode-verify: run one forward pass over `tokens`
1335    /// starting at the cache's current end position, write their K/V
1336    /// into the KV cache, and return logits for ALL `tokens.len()`
1337    /// positions as a flat `Vec<f32>` of length `seq_len * vocab_size`.
1338    ///
1339    /// Used by speculative decoding: target receives
1340    /// `[last_token, draft_0, ..., draft_{N-1}]` (N+1 inputs) and produces
1341    /// N+1 logit rows in a single forward instead of N+1 sequential
1342    /// decode() calls. Positions are implicit — the model looks up
1343    /// `pos_offset = cache.len` the same way prefill_internal does, so
1344    /// chunked prefill semantics carry over for free.
1345    pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1346        let seq_len = tokens.len();
1347        assert!(seq_len > 0, "forward_verify called with empty tokens");
1348        self.ensure_scratch(seq_len);
1349        self.ensure_kv(cache_id);
1350
1351        let h = self.cfg.hidden_size;
1352        let vocab = self.cfg.vocab_size;
1353
1354        let pos_offset = self
1355            .kv_caches
1356            .get(cache_id)
1357            .and_then(|layers| layers.first())
1358            .map(|c| c.len)
1359            .unwrap_or(0);
1360
1361        let mut ctx = B::new_context();
1362        let mut residual = self
1363            .scratch
1364            .residual
1365            .take()
1366            .expect("scratch residual missing (previous call didn't restore)");
1367
1368        let embed = self
1369            .embed
1370            .as_ref()
1371            .expect("forward_verify called on backbone-only model (no embed)");
1372        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1373
1374        for li in 0..self.cfg.num_layers {
1375            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1376        }
1377
1378        // RMSNorm on ALL seq_len positions (prefill_internal only norms
1379        // the last one; verify needs the full grid).
1380        B::rms_norm(
1381            &mut ctx,
1382            &residual,
1383            &self.final_norm_w,
1384            self.cfg.rms_norm_eps,
1385            &mut self.scratch.norm_out,
1386            seq_len,
1387            h,
1388        );
1389
1390        // LM head applied to all positions → `seq_len * vocab` logits.
1391        // Reuses the existing `batch_logits` scratch (sized max_tokens *
1392        // vocab) so no extra allocation.
1393        let lm_head = self
1394            .lm_head
1395            .as_ref()
1396            .expect("forward_verify called on backbone-only model (no lm_head)");
1397        lm_head.forward(
1398            &mut ctx,
1399            &self.scratch.norm_out,
1400            &mut self.scratch.batch_logits,
1401            seq_len,
1402        );
1403
1404        B::sync(&mut ctx);
1405        self.scratch.residual = Some(residual);
1406        B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
1407    }
1408
1409    /// Prefill: process `tokens` prompt tokens in a single batch, return
1410    /// `[vocab_size]` logits for the last position.
1411    ///
1412    /// Supports incremental prefill: if the KV cache for `cache_id` already
1413    /// contains earlier tokens, the new chunk's positions are computed as
1414    /// `[kv_len, kv_len + tokens.len())` so RoPE and causal masking stay
1415    /// aligned. Used by the engine's chunked-prefill path.
1416    pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1417        let seq_len = tokens.len();
1418        assert!(seq_len > 0, "prefill called with empty token list");
1419        self.ensure_scratch(seq_len);
1420        self.ensure_kv(cache_id);
1421
1422        // Starting position for this chunk — 0 for a fresh prefill, kv_len
1423        // for the second+ chunk of a split prefill.
1424        let pos_offset = self
1425            .kv_caches
1426            .get(cache_id)
1427            .and_then(|layers| layers.first())
1428            .map(|c| c.len)
1429            .unwrap_or(0);
1430
1431        let h = self.cfg.hidden_size;
1432        let vocab = self.cfg.vocab_size;
1433        let mut ctx = B::new_context();
1434
1435        // Move `residual` out of `scratch` to work around the borrow checker:
1436        // `forward_layer` re-borrows `&mut self` to reach `self.layers` /
1437        // `self.kv_caches`, which would conflict with an outstanding
1438        // `&mut self.scratch.residual`. Use Option::take to move it out
1439        // (no placeholder alloc → no transient cuMemFreeAsync that could
1440        // corrupt stream pool state after graph ops on Blackwell).
1441        let mut residual = self
1442            .scratch
1443            .residual
1444            .take()
1445            .expect("scratch residual missing (previous call didn't restore)");
1446        let embed = self
1447            .embed
1448            .as_ref()
1449            .expect("prefill_internal called on backbone-only model (no embed)");
1450        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1451
1452        let prefill_profile = std::env::var("FERRUM_PREFILL_OP_PROFILE").is_ok();
1453        let prefill_t0 = if prefill_profile {
1454            B::sync(&mut ctx);
1455            Some(std::time::Instant::now())
1456        } else {
1457            None
1458        };
1459
1460        for li in 0..self.cfg.num_layers {
1461            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1462        }
1463
1464        if let Some(t0) = prefill_t0 {
1465            B::sync(&mut ctx);
1466            let total_us = t0.elapsed().as_micros() as u64;
1467            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1468            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1469            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1470            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1471            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1472            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1473            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1474            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1475            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1476            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1477            eprintln!(
1478                "[prefill-profile] tokens={} layers total={} ms",
1479                seq_len,
1480                total_us / 1000
1481            );
1482            let bucket = |label: &str, n: u64, us: u64| {
1483                if n > 0 {
1484                    eprintln!(
1485                        "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
1486                        n,
1487                        us / 1000,
1488                        us / n
1489                    );
1490                }
1491            };
1492            bucket("flash_attn", attn_n, attn_us);
1493            bucket("qk_norm_rope", qkr_n, qkr_us);
1494            bucket("matmuls", mm_n, mm_us);
1495            bucket("norms", norm_n, norm_us);
1496            bucket("other", other_n, other_us);
1497        }
1498
1499        // Take the last token's hidden state: residual[(seq_len-1)*h .. seq_len*h]
1500        B::copy_slice(
1501            &mut ctx,
1502            &residual,
1503            (seq_len - 1) * h,
1504            &mut self.scratch.last_hidden,
1505            0,
1506            h,
1507        );
1508
1509        // Final RMSNorm on the last hidden.
1510        B::rms_norm(
1511            &mut ctx,
1512            &self.scratch.last_hidden,
1513            &self.final_norm_w,
1514            self.cfg.rms_norm_eps,
1515            &mut self.scratch.last_normed,
1516            1,
1517            h,
1518        );
1519
1520        // LM head (m=1 — triggers GEMV on MetalBackend).
1521        let lm_head = self
1522            .lm_head
1523            .as_ref()
1524            .expect("prefill_internal called on backbone-only model (no lm_head)");
1525        lm_head.forward(
1526            &mut ctx,
1527            &self.scratch.last_normed,
1528            &mut self.scratch.logits,
1529            1,
1530        );
1531
1532        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
1533        // buffer's CPU pointer without flushing the command buffer, so the
1534        // GPU must complete all pending work first or we read stale/random
1535        // data. CUDA's to_vec does an internal stream.synchronize, making
1536        // the call redundant there (~50µs/step cost), but correctness on
1537        // Metal requires the explicit flush here.
1538        B::sync(&mut ctx);
1539
1540        // Restore residual into scratch for reuse on the next call.
1541        self.scratch.residual = Some(residual);
1542
1543        B::to_vec(&self.scratch.logits, vocab)
1544    }
1545
1546    /// Decode: process 1 token at position `pos`, return `[vocab_size]` logits.
1547    pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1548        self.ensure_scratch(1);
1549        self.ensure_kv(cache_id);
1550
1551        let h = self.cfg.hidden_size;
1552        let vocab = self.cfg.vocab_size;
1553
1554        // Context creation is cheap (CUDA reuses the process-global stream).
1555        // The captured graph lives in a process-global slot, not on ctx.
1556        let mut ctx = B::new_context();
1557
1558        // Graph capture is opt-in via FERRUM_CUDA_GRAPH=1. Replay is currently
1559        // single-request-only on Blackwell + CUDA 12.8 (see
1560        // docs/phase-e-cuda-status.md). In pure eager mode, we skip the
1561        // per-step device-state memcpy_htod trio entirely.
1562        const GRAPH_WARMUP: usize = 3;
1563        let graph_enabled = std::env::var("FERRUM_CUDA_GRAPH").is_ok();
1564
1565        if graph_enabled {
1566            // Refresh device-side dynamic state (token/pos/kv_len) before
1567            // replay — captured graph reads these from device buffers.
1568            B::set_decode_state(&mut ctx, token, pos);
1569
1570            // Fast path: graph replay (if available).
1571            match B::replay_last_graph(&mut ctx) {
1572                Ok(true) => {
1573                    B::sync(&mut ctx);
1574                    return B::to_vec(&self.scratch.logits, vocab);
1575                }
1576                Ok(false) => { /* no graph yet, fall through to eager */ }
1577                Err(_) => { /* backend error or unsupported, eager */ }
1578            }
1579        }
1580
1581        let should_capture =
1582            graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
1583
1584        if should_capture {
1585            B::set_dev_state_mode(&mut ctx, true);
1586            if B::begin_graph_capture(&mut ctx).is_err() {
1587                self.graph_capture_failed = true;
1588                B::set_dev_state_mode(&mut ctx, false);
1589            }
1590        }
1591
1592        // Eager forward (records into graph if capture is active).
1593        // mem::replace needs a placeholder. B::alloc(0) was our choice but
1594        // cuMemAllocFromPoolAsync(stream, 0) can return CUDA_ERROR_INVALID_VALUE
1595        // on Blackwell after graph replay corrupts the pool state. Size-1 is
1596        // always valid and costs 2 bytes of transient VRAM per decode step.
1597        let mut residual = self
1598            .scratch
1599            .residual
1600            .take()
1601            .expect("scratch residual missing (previous call didn't restore)");
1602        let embed = self
1603            .embed
1604            .as_ref()
1605            .expect("decode_internal called on backbone-only model (no embed)");
1606        B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
1607
1608        // Per-layer wall-time profile (env-gated, off by default — adds
1609        // a B::sync between layers which serializes the pipeline). Helps
1610        // localise non-matmul bottlenecks during perf work.
1611        let layer_profile = std::env::var("FERRUM_DECODE_LAYER_PROFILE").is_ok();
1612        let mut layer_times = if layer_profile {
1613            Some(Vec::with_capacity(self.cfg.num_layers))
1614        } else {
1615            None
1616        };
1617
1618        for li in 0..self.cfg.num_layers {
1619            if layer_profile {
1620                let t0 = std::time::Instant::now();
1621                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1622                B::sync(&mut ctx);
1623                let elapsed_us = t0.elapsed().as_micros() as u64;
1624                if let Some(v) = layer_times.as_mut() {
1625                    v.push(elapsed_us);
1626                }
1627            } else {
1628                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1629            }
1630        }
1631        if let Some(times) = layer_times.take() {
1632            let sum: u64 = times.iter().sum();
1633            let avg = sum / times.len() as u64;
1634            let mn = *times.iter().min().unwrap_or(&0);
1635            let mx = *times.iter().max().unwrap_or(&0);
1636            eprintln!(
1637                "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
1638                times.len(),
1639                sum / 1000,
1640                avg,
1641                mn,
1642                mx
1643            );
1644            for (i, t) in times.iter().enumerate() {
1645                eprint!("L{i}={}ms ", t / 1000);
1646                if (i + 1) % 6 == 0 {
1647                    eprintln!();
1648                }
1649            }
1650            eprintln!();
1651            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1652            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1653            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1654            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1655            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1656            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1657            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1658            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1659            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1660            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1661            eprintln!(
1662                "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
1663                attn_n,
1664                attn_us / 1000,
1665                if attn_n > 0 { attn_us / attn_n } else { 0 }
1666            );
1667            eprintln!(
1668                "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
1669                qkr_n,
1670                qkr_us / 1000,
1671                if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
1672            );
1673            eprintln!(
1674                "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
1675                mm_n,
1676                mm_us / 1000,
1677                if mm_n > 0 { mm_us / mm_n } else { 0 }
1678            );
1679            eprintln!(
1680                "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
1681                norm_n,
1682                norm_us / 1000,
1683                if norm_n > 0 { norm_us / norm_n } else { 0 }
1684            );
1685            eprintln!(
1686                "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
1687                other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
1688            );
1689        }
1690
1691        B::rms_norm(
1692            &mut ctx,
1693            &residual,
1694            &self.final_norm_w,
1695            self.cfg.rms_norm_eps,
1696            &mut self.scratch.last_normed,
1697            1,
1698            h,
1699        );
1700
1701        let lm_head = self
1702            .lm_head
1703            .as_ref()
1704            .expect("decode_internal called on backbone-only model (no lm_head)");
1705        lm_head.forward(
1706            &mut ctx,
1707            &self.scratch.last_normed,
1708            &mut self.scratch.logits,
1709            1,
1710        );
1711
1712        if should_capture && !self.graph_capture_failed {
1713            if B::end_graph_capture(&mut ctx).is_err() {
1714                self.graph_capture_failed = true;
1715            } else {
1716                // Stream capture mode RECORDS ops into the graph without
1717                // executing them. scratch.logits still holds the previous
1718                // step's value. Replay the just-captured graph once to
1719                // actually execute and produce this step's logits. Without
1720                // this, the capture step's to_vec returns stale logits,
1721                // yielding a 1-token offset in the generated sequence.
1722                if B::replay_last_graph(&mut ctx).is_err() {
1723                    self.graph_capture_failed = true;
1724                }
1725            }
1726            B::set_dev_state_mode(&mut ctx, false);
1727        } else {
1728            self.graph_warmup += 1;
1729        }
1730
1731        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
1732        // buffer's CPU pointer without flushing the command buffer, so the
1733        // GPU must complete all pending work first or we read stale/random
1734        // data. CUDA's to_vec does an internal stream.synchronize, making
1735        // the call redundant there (~50µs/step cost), but correctness on
1736        // Metal requires the explicit flush here.
1737        B::sync(&mut ctx);
1738        self.scratch.residual = Some(residual);
1739
1740        B::to_vec(&self.scratch.logits, vocab)
1741    }
1742
1743    /// Prefill with pre-computed embeddings instead of token IDs.
1744    ///
1745    /// Used by models that embed inputs outside the LLM (e.g. Qwen3-TTS
1746    /// mixes text-embedding + codec-embedding before feeding the LM).
1747    /// Skips `final_norm` + `lm_head`; returns the last position's pre-norm
1748    /// hidden state. Caller applies its own output head.
1749    ///
1750    /// `embeds` is row-major `[seq_len * hidden_size]`, f32.
1751    pub fn prefill_from_embeds(
1752        &mut self,
1753        cache_id: &str,
1754        embeds: &[f32],
1755        seq_len: usize,
1756    ) -> Vec<f32> {
1757        let h = self.cfg.hidden_size;
1758        assert_eq!(
1759            embeds.len(),
1760            seq_len * h,
1761            "embeds length {} != seq_len * hidden_size {}",
1762            embeds.len(),
1763            seq_len * h
1764        );
1765        assert!(seq_len > 0, "prefill_from_embeds called with zero length");
1766
1767        self.ensure_scratch(seq_len);
1768        self.ensure_kv(cache_id);
1769
1770        let mut ctx = B::new_context();
1771        let mut residual = self
1772            .scratch
1773            .residual
1774            .take()
1775            .expect("scratch residual missing (previous call didn't restore)");
1776
1777        // Upload embeds → residual[0 .. seq_len*h].
1778        let embed_buf = B::from_slice(embeds);
1779        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1780
1781        for li in 0..self.cfg.num_layers {
1782            self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
1783        }
1784
1785        B::copy_slice(
1786            &mut ctx,
1787            &residual,
1788            (seq_len - 1) * h,
1789            &mut self.scratch.last_hidden,
1790            0,
1791            h,
1792        );
1793        B::sync(&mut ctx);
1794        self.scratch.residual = Some(residual);
1795        B::to_vec(&self.scratch.last_hidden, h)
1796    }
1797
1798    /// Decode with a single pre-computed embedding (shape `[hidden]`).
1799    /// Returns the pre-norm hidden state for the position `pos`. Caller
1800    /// applies final norm + its own output head.
1801    pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
1802        let h = self.cfg.hidden_size;
1803        assert_eq!(
1804            embed.len(),
1805            h,
1806            "embed length {} != hidden_size {}",
1807            embed.len(),
1808            h
1809        );
1810
1811        self.ensure_scratch(1);
1812        self.ensure_kv(cache_id);
1813
1814        let mut ctx = B::new_context();
1815        let mut residual = self
1816            .scratch
1817            .residual
1818            .take()
1819            .expect("scratch residual missing (previous call didn't restore)");
1820
1821        let embed_buf = B::from_slice(embed);
1822        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1823
1824        for li in 0..self.cfg.num_layers {
1825            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1826        }
1827
1828        B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
1829        B::sync(&mut ctx);
1830        self.scratch.residual = Some(residual);
1831        B::to_vec(&self.scratch.last_hidden, h)
1832    }
1833
1834    /// Variant of `prefill_from_embeds` that applies `final_norm` to every
1835    /// position and returns the whole `[seq_len * hidden_size]` vector.
1836    /// Accepts `pos_offset` so callers can continue an existing sequence
1837    /// (e.g. Qwen3-TTS voice-clone: one prefill for the role prefix, a
1838    /// follow-up prefill for the reference-audio ICL block, then
1839    /// autoregressive decoding — all against the same KV cache).
1840    ///
1841    /// Used by TTS where `forward_step` in the candle-based wrapper is
1842    /// expected to return **post-norm all-positions** hidden state so
1843    /// `codec_head` can be applied on candle side.
1844    pub fn prefill_all_post_norm(
1845        &mut self,
1846        cache_id: &str,
1847        embeds: &[f32],
1848        seq_len: usize,
1849        pos_offset: usize,
1850    ) -> Vec<f32> {
1851        let h = self.cfg.hidden_size;
1852        assert_eq!(
1853            embeds.len(),
1854            seq_len * h,
1855            "embeds length {} != seq_len * hidden_size {}",
1856            embeds.len(),
1857            seq_len * h
1858        );
1859        assert!(seq_len > 0);
1860
1861        self.ensure_scratch(seq_len);
1862        self.ensure_kv(cache_id);
1863
1864        let mut ctx = B::new_context();
1865        let mut residual = self
1866            .scratch
1867            .residual
1868            .take()
1869            .expect("scratch residual missing (previous call didn't restore)");
1870
1871        let embed_buf = B::from_slice(embeds);
1872        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1873
1874        for li in 0..self.cfg.num_layers {
1875            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1876        }
1877
1878        // Apply final_norm over all seq_len positions → scratch.norm_out.
1879        B::rms_norm(
1880            &mut ctx,
1881            &residual,
1882            &self.final_norm_w,
1883            self.cfg.rms_norm_eps,
1884            &mut self.scratch.norm_out,
1885            seq_len,
1886            h,
1887        );
1888        B::sync(&mut ctx);
1889        self.scratch.residual = Some(residual);
1890        B::to_vec(&self.scratch.norm_out, seq_len * h)
1891    }
1892
1893    /// Decode-side companion to `prefill_all_post_norm`. Runs a single-token
1894    /// decode step at `pos`, applies `final_norm`, and returns the post-norm
1895    /// hidden state `[hidden_size]`.
1896    pub fn decode_post_norm_from_embed(
1897        &mut self,
1898        cache_id: &str,
1899        embed: &[f32],
1900        pos: u32,
1901    ) -> Vec<f32> {
1902        let h = self.cfg.hidden_size;
1903        assert_eq!(embed.len(), h);
1904
1905        self.ensure_scratch(1);
1906        self.ensure_kv(cache_id);
1907
1908        let mut ctx = B::new_context();
1909        let mut residual = self
1910            .scratch
1911            .residual
1912            .take()
1913            .expect("scratch residual missing (previous call didn't restore)");
1914
1915        let embed_buf = B::from_slice(embed);
1916        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1917
1918        for li in 0..self.cfg.num_layers {
1919            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1920        }
1921
1922        B::rms_norm(
1923            &mut ctx,
1924            &residual,
1925            &self.final_norm_w,
1926            self.cfg.rms_norm_eps,
1927            &mut self.scratch.last_normed,
1928            1,
1929            h,
1930        );
1931        B::sync(&mut ctx);
1932        self.scratch.residual = Some(residual);
1933        B::to_vec(&self.scratch.last_normed, h)
1934    }
1935
1936    /// Batched decode: process M concurrent requests at potentially different
1937    /// positions in one forward pass. GEMM-heavy ops (qkv_proj, o_proj,
1938    /// gate_up, down) run with m=M for natural batching; rope + KV append +
1939    /// attention loop per-item (each has its own KV cache at a different
1940    /// kv_len, and potentially different pos).
1941    ///
1942    /// Returns M logit vectors in the same order as `batch`.
1943    pub fn decode_batch_internal(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1944        let m = batch.len();
1945        if m == 0 {
1946            return Vec::new();
1947        }
1948        if m == 1 {
1949            let (cid, tok, pos) = &batch[0];
1950            return vec![self.decode_internal(cid, *tok, *pos)];
1951        }
1952
1953        // Ensure all caches exist and scratch is sized for M tokens.
1954        for (cid, _, _) in batch {
1955            self.ensure_kv(cid);
1956        }
1957        self.ensure_scratch(m);
1958        // Phase 4b: when paged mode is on, ensure_kv has already
1959        // populated the batched scratch buffers (paged_batch_q etc.).
1960        // The forward path branches on `paged_pools.is_some()` inside
1961        // each layer.
1962
1963        let h = self.cfg.hidden_size;
1964        let vocab = self.cfg.vocab_size;
1965        let mut ctx = B::new_context();
1966
1967        // 0. Embed all M tokens into residual [M, H]
1968        let tokens: Vec<u32> = batch.iter().map(|(_, t, _)| *t).collect();
1969        let mut residual = self
1970            .scratch
1971            .residual
1972            .take()
1973            .expect("scratch residual missing (previous call didn't restore)");
1974        let embed = self
1975            .embed
1976            .as_ref()
1977            .expect("decode_batch_internal called on backbone-only model (no embed)");
1978        B::embedding_lookup(&mut ctx, embed, &tokens, &mut residual, h);
1979
1980        // 1..num_layers: batched forward for each layer
1981        for li in 0..self.cfg.num_layers {
1982            self.forward_layer_batched_decode(&mut ctx, li, batch, &mut residual, m);
1983        }
1984
1985        // Final RMSNorm on [M, H] → norm_out [M, H]
1986        B::rms_norm(
1987            &mut ctx,
1988            &residual,
1989            &self.final_norm_w,
1990            self.cfg.rms_norm_eps,
1991            &mut self.scratch.norm_out,
1992            m,
1993            h,
1994        );
1995
1996        // LM head with m=M → batch_logits [M, vocab]
1997        let lm_head = self
1998            .lm_head
1999            .as_ref()
2000            .expect("decode_batch_internal called on backbone-only model (no lm_head)");
2001        lm_head.forward(
2002            &mut ctx,
2003            &self.scratch.norm_out,
2004            &mut self.scratch.batch_logits,
2005            m,
2006        );
2007
2008        // Sync before to_vec (Metal: no internal sync on buffer read).
2009        B::sync(&mut ctx);
2010        self.scratch.residual = Some(residual);
2011
2012        // Extract M logit vectors from the flat buffer.
2013        let all = B::to_vec(&self.scratch.batch_logits, m * vocab);
2014        (0..m)
2015            .map(|i| all[i * vocab..(i + 1) * vocab].to_vec())
2016            .collect()
2017    }
2018
2019    /// One transformer layer over M items, GEMMs batched + per-item attention.
2020    fn forward_layer_batched_decode(
2021        &mut self,
2022        ctx: &mut B::Context,
2023        li: usize,
2024        batch: &[(String, u32, u32)],
2025        residual: &mut B::Buffer,
2026        m: usize,
2027    ) {
2028        let cfg = &self.cfg;
2029        let h = cfg.hidden_size;
2030        let nh = cfg.num_heads;
2031        let nkv = cfg.num_kv_heads;
2032        let hd = cfg.head_dim;
2033        let im = cfg.intermediate_size;
2034        let eps = cfg.rms_norm_eps;
2035        let q_dim = nh * hd;
2036        let kv_dim = nkv * hd;
2037
2038        let layer = &self.layers[li];
2039        let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
2040        let dummy_w = &layer.input_ln_w;
2041        let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy_w);
2042        let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy_w);
2043
2044        // 1. rms_norm [M, H]  → norm_out
2045        B::rms_norm(
2046            ctx,
2047            residual,
2048            &layer.input_ln_w,
2049            eps,
2050            &mut self.scratch.norm_out,
2051            m,
2052            h,
2053        );
2054
2055        // 2. qkv_proj (GEMM m=M): norm_out [M, H] → qkv_out [M, QKV]
2056        layer
2057            .qkv_proj
2058            .forward(ctx, &self.scratch.norm_out, &mut self.scratch.qkv_out, m);
2059
2060        // ── Paged-KV batched path (Phase 4b) ──────────────────────────
2061        // When paged is on, we skip the contig split_qkv + per-item
2062        // qk_norm_rope + kv_append + flash_attention loop entirely.
2063        // Instead:
2064        //   1. Per item: split_qkv_norm_rope_into_paged_cache with
2065        //      qkv_byte_offset = i * qkv_stride * 4 reads item i's
2066        //      slice of qkv_out, writes K/V into the shared pool at
2067        //      its block_table-resolved position, and stores the
2068        //      RoPE'd Q at paged_batch_q[i * q_dim .. (i+1) * q_dim].
2069        //   2. Build batched block_tables [M, max_blocks_per_seq] +
2070        //      context_lens [M] host-side, write to scratch device
2071        //      buffers.
2072        //   3. Single paged_decode_attention(num_seqs=M) reads all M
2073        //      seqs' K/V via per-seq block_tables, writes to
2074        //      paged_batch_o.
2075        //   4. Per item: copy paged_batch_o[i] → attn_flat[i * q_dim].
2076        //
2077        // This is the "real" multi-seq decode — one heavy attention
2078        // dispatch covering all sequences instead of M sequential ones.
2079        if let Some(pools) = self.paged_pools.as_mut() {
2080            let pool_ptr = (
2081                &mut pools[li].0 as *mut B::Buffer,
2082                &mut pools[li].1 as *mut B::Buffer,
2083            );
2084            // SAFETY: pools allocated once; not concurrently mutated.
2085            let (pool_k, pool_v) = unsafe { (&mut *pool_ptr.0, &mut *pool_ptr.1) };
2086
2087            let qkv_stride = q_dim + 2 * kv_dim;
2088            let max_blocks_per_seq = self.scratch.paged_max_blocks_per_seq;
2089            let block_size = 16; // matches PAGED_BLOCK_SIZE in ensure_kv
2090
2091            // Step 1: per-item paged write. We collect cache_len + block_indices
2092            // up front for step 2. Note: this loop borrows self.kv_caches mutably
2093            // per iteration, so we extract the batched-write parameters first then
2094            // do the dispatches.
2095            let mut item_state: Vec<(u32, Vec<u32>)> = Vec::with_capacity(m);
2096            for (cache_id, _, _) in batch.iter() {
2097                let caches = self
2098                    .kv_caches
2099                    .get(cache_id)
2100                    .expect("ensure_kv must be called before forward_layer_batched");
2101                let cache = &caches[li];
2102                item_state.push((cache.len as u32, cache.paged_block_indices.clone()));
2103            }
2104
2105            // Take block_table buffer ptrs ahead of the dispatch loop —
2106            // we need both per-cache block_table (to write into) and
2107            // self.scratch.paged_batch_q (to write Q stacks into).
2108            let q_head_major_size_bytes = (q_dim * std::mem::size_of::<f32>()) as u64;
2109            let qkv_stride_bytes = (qkv_stride * std::mem::size_of::<f32>()) as u64;
2110            for (i, (cache_id, _, pos)) in batch.iter().enumerate() {
2111                let pos_i = *pos as usize;
2112                let caches = self
2113                    .kv_caches
2114                    .get(cache_id)
2115                    .expect("paged batched: cache not present");
2116                let cache = &caches[li];
2117                let bt = cache
2118                    .block_table
2119                    .as_ref()
2120                    .expect("paged batched: block_table missing");
2121                let cache_len_before = cache.len;
2122                let block_table_ref = bt as *const B::Buffer;
2123                // SAFETY: bt is read-only in the dispatch; we don't
2124                // mutate self.kv_caches between this raw deref and the
2125                // call.
2126                let bt_safe: &B::Buffer = unsafe { &*block_table_ref };
2127                B::split_qkv_norm_rope_into_paged_cache(
2128                    ctx,
2129                    &self.scratch.qkv_out,
2130                    (i as u64) * qkv_stride_bytes,
2131                    q_norm_w,
2132                    k_norm_w,
2133                    &self.rope.cos,
2134                    &self.rope.sin,
2135                    self.scratch
2136                        .paged_batch_q
2137                        .as_mut()
2138                        .expect("paged_batch_q missing"),
2139                    (i as u64) * q_head_major_size_bytes,
2140                    pool_k,
2141                    pool_v,
2142                    bt_safe,
2143                    1,
2144                    nh,
2145                    nkv,
2146                    hd,
2147                    pos_i,
2148                    eps,
2149                    qk_mode,
2150                    cache_len_before,
2151                    block_size,
2152                    max_blocks_per_seq,
2153                )
2154                .expect("paged batched write");
2155            }
2156
2157            // Step 2: bump cache.len and build the stacked block_tables +
2158            // context_lens host-side, then upload to device scratch.
2159            let mut stacked_bt: Vec<u32> = vec![0u32; m * max_blocks_per_seq];
2160            let mut stacked_cl: Vec<u32> = vec![0u32; m];
2161            for (i, (cache_id, _, _)) in batch.iter().enumerate() {
2162                let caches = self
2163                    .kv_caches
2164                    .get_mut(cache_id)
2165                    .expect("paged batched: cache not present");
2166                let cache = &mut caches[li];
2167                cache.len += 1;
2168                let len = cache.len as u32;
2169                stacked_cl[i] = len;
2170                let blocks = &cache.paged_block_indices;
2171                let n_to_copy = blocks.len().min(max_blocks_per_seq);
2172                stacked_bt[i * max_blocks_per_seq..i * max_blocks_per_seq + n_to_copy]
2173                    .copy_from_slice(&blocks[..n_to_copy]);
2174            }
2175            let bt_buf = self
2176                .scratch
2177                .paged_batch_block_tables
2178                .as_mut()
2179                .expect("paged_batch_block_tables missing");
2180            B::write_u32(ctx, bt_buf, &stacked_bt);
2181            let cl_buf = self
2182                .scratch
2183                .paged_batch_context_lens
2184                .as_mut()
2185                .expect("paged_batch_context_lens missing");
2186            B::write_u32(ctx, cl_buf, &stacked_cl);
2187
2188            // Step 3: one batched paged_decode_attention(num_seqs=m).
2189            let bt_ptr =
2190                self.scratch.paged_batch_block_tables.as_ref().unwrap() as *const B::Buffer;
2191            let cl_ptr =
2192                self.scratch.paged_batch_context_lens.as_ref().unwrap() as *const B::Buffer;
2193            let q_ptr = self.scratch.paged_batch_q.as_ref().unwrap() as *const B::Buffer;
2194            let o_ptr = self.scratch.paged_batch_o.as_mut().unwrap() as *mut B::Buffer;
2195            // SAFETY: the four scratch buffers above are not aliased
2196            // by anything else; we only deref while &mut self is held.
2197            let bt_safe = unsafe { &*bt_ptr };
2198            let cl_safe = unsafe { &*cl_ptr };
2199            let q_safe = unsafe { &*q_ptr };
2200            let o_safe = unsafe { &mut *o_ptr };
2201            B::paged_decode_attention(
2202                ctx,
2203                q_safe,
2204                pool_k,
2205                pool_v,
2206                o_safe,
2207                bt_safe,
2208                cl_safe,
2209                m,
2210                nh,
2211                nkv,
2212                hd,
2213                block_size,
2214                max_blocks_per_seq,
2215                1, // q_len
2216            )
2217            .expect("paged batched decode");
2218
2219            // Step 4: per-item copy paged_batch_o[i] → attn_flat[i * q_dim].
2220            // Both have q_dim floats per item; same head-major-equals-token-major
2221            // identity collapse used in the contig path.
2222            for i in 0..m {
2223                B::copy_slice(
2224                    ctx,
2225                    self.scratch.paged_batch_o.as_ref().unwrap(),
2226                    i * q_dim,
2227                    &mut self.scratch.attn_flat,
2228                    i * q_dim,
2229                    q_dim,
2230                );
2231            }
2232
2233            // Skip the contig split_qkv + per-item loop below.
2234            return self.forward_layer_batched_decode_post_attn(ctx, li, residual, m);
2235        }
2236
2237        // 3. split_qkv [M, QKV] → q_buf [M, Q], k_buf [M, KV], v_buf [M, KV]
2238        B::split_qkv(
2239            ctx,
2240            &self.scratch.qkv_out,
2241            &mut self.scratch.q_buf,
2242            &mut self.scratch.k_buf,
2243            &mut self.scratch.v_buf,
2244            m,
2245            q_dim,
2246            kv_dim,
2247        );
2248
2249        // 4-6. Per-item loop for rope + kv_append + attention.
2250        //      Each item has its own cache_id + pos + kv_len.
2251        for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
2252            let pos_i = *pos as usize;
2253
2254            // Extract item i's Q/K/V from batched buffers.
2255            B::copy_slice(
2256                ctx,
2257                &self.scratch.q_buf,
2258                i * q_dim,
2259                &mut self.scratch.q_single,
2260                0,
2261                q_dim,
2262            );
2263            B::copy_slice(
2264                ctx,
2265                &self.scratch.k_buf,
2266                i * kv_dim,
2267                &mut self.scratch.k_single,
2268                0,
2269                kv_dim,
2270            );
2271            B::copy_slice(
2272                ctx,
2273                &self.scratch.v_buf,
2274                i * kv_dim,
2275                &mut self.scratch.v_single,
2276                0,
2277                kv_dim,
2278            );
2279
2280            // qk_norm_rope with tokens=1, per-item pos.
2281            B::qk_norm_rope(
2282                ctx,
2283                &self.scratch.q_single,
2284                q_norm_w,
2285                &self.rope.cos,
2286                &self.rope.sin,
2287                &mut self.scratch.q_head_major_single,
2288                1,
2289                nh,
2290                hd,
2291                pos_i,
2292                eps,
2293                qk_mode,
2294            );
2295            B::qk_norm_rope(
2296                ctx,
2297                &self.scratch.k_single,
2298                k_norm_w,
2299                &self.rope.cos,
2300                &self.rope.sin,
2301                &mut self.scratch.k_head_major_single,
2302                1,
2303                nkv,
2304                hd,
2305                pos_i,
2306                eps,
2307                qk_mode,
2308            );
2309            B::qk_norm_rope(
2310                ctx,
2311                &self.scratch.v_single,
2312                dummy_w,
2313                &self.rope.cos,
2314                &self.rope.sin,
2315                &mut self.scratch.v_head_major_single,
2316                1,
2317                nkv,
2318                hd,
2319                pos_i,
2320                eps,
2321                0,
2322            );
2323
2324            // KV append + attention for item i's cache.
2325            let caches = self
2326                .kv_caches
2327                .get_mut(cache_id)
2328                .expect("ensure_kv must be called before forward_layer_batched");
2329            let cache = &mut caches[li];
2330            B::kv_cache_append_head_major(
2331                ctx,
2332                &mut cache.k,
2333                &mut cache.v,
2334                cache.len,
2335                cache.capacity,
2336                &self.scratch.k_head_major_single,
2337                &self.scratch.v_head_major_single,
2338                1,
2339                nkv,
2340                hd,
2341            );
2342            cache.len += 1;
2343            let kv_len = cache.len;
2344            let kv_stride = cache.capacity;
2345
2346            let attn_cfg = ferrum_kernels::backend::AttnConfig {
2347                num_heads: nh,
2348                num_kv_heads: nkv,
2349                head_dim: hd,
2350                causal: true,
2351                scale: 1.0 / (hd as f32).sqrt(),
2352                kv_seq_stride: kv_stride,
2353                sliding_window: cfg.sliding_window,
2354            };
2355            B::flash_attention(
2356                ctx,
2357                &self.scratch.q_head_major_single,
2358                &cache.k,
2359                &cache.v,
2360                &mut self.scratch.attn_head_major_single,
2361                1,
2362                1,
2363                kv_len,
2364                pos_i,
2365                &attn_cfg,
2366            );
2367
2368            // Untranspose head-major → token-major + inject into batched
2369            // attn_flat[M, Q]. For tokens=1 the head-major and
2370            // token-major layouts are byte-identical (both flat to
2371            // [heads * head_dim] = [q_dim] floats), so we skip the
2372            // transpose dispatch entirely and copy attn_head_major_single
2373            // straight into the per-item slot. Saves 1 dispatch per
2374            // batch-item per layer.
2375            B::copy_slice(
2376                ctx,
2377                &self.scratch.attn_head_major_single,
2378                0,
2379                &mut self.scratch.attn_flat,
2380                i * q_dim,
2381                q_dim,
2382            );
2383        }
2384
2385        self.forward_layer_batched_decode_post_attn(ctx, li, residual, m);
2386    }
2387
2388    fn forward_layer_batched_decode_post_attn(
2389        &mut self,
2390        ctx: &mut B::Context,
2391        li: usize,
2392        residual: &mut B::Buffer,
2393        m: usize,
2394    ) {
2395        let cfg = &self.cfg;
2396        let h = cfg.hidden_size;
2397        let im = cfg.intermediate_size;
2398        let eps = cfg.rms_norm_eps;
2399        let layer = &self.layers[li];
2400
2401        // 7. o_proj (GEMM m=M): attn_flat [M, Q] → o_proj_out [M, H]
2402        layer.o_proj.forward(
2403            ctx,
2404            &self.scratch.attn_flat,
2405            &mut self.scratch.o_proj_out,
2406            m,
2407        );
2408
2409        // 8. Fused residual add + post-attention RMSNorm.
2410        B::fused_add_rms_norm(
2411            ctx,
2412            residual,
2413            &self.scratch.o_proj_out,
2414            &layer.post_ln_w,
2415            eps,
2416            &mut self.scratch.norm_out,
2417            m,
2418            h,
2419        );
2420
2421        // 9. gate_up_proj (GEMM m=M)
2422        layer.gate_up_proj.forward(
2423            ctx,
2424            &self.scratch.norm_out,
2425            &mut self.scratch.gate_up_out,
2426            m,
2427        );
2428
2429        // 10. SwiGLU
2430        B::fused_silu_mul_split(
2431            ctx,
2432            &self.scratch.gate_up_out,
2433            &mut self.scratch.silu_out,
2434            m,
2435            im,
2436        );
2437
2438        // 11. down_proj (GEMM m=M)
2439        layer
2440            .down_proj
2441            .forward(ctx, &self.scratch.silu_out, &mut self.scratch.mlp_out, m);
2442
2443        // 12. Residual add
2444        B::add_inplace(ctx, residual, &self.scratch.mlp_out, m * h);
2445    }
2446}
2447
2448impl<B: Backend> DecoderOnlyLLM for LlamaFamilyModel<B> {
2449    fn config(&self) -> &LlmRuntimeConfig {
2450        &self.runtime_cfg
2451    }
2452
2453    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2454        // Eager scratch + KV cache grow + a 1-token forward warmup —
2455        // see the Qwen3MoeModel::prepare comment for the rationale.
2456        // Without the warmup forward, the first real prefill pays
2457        // Metal pipeline first-bind costs inside the timer window.
2458        self.ensure_scratch(max_tokens);
2459        self.ensure_kv(cache_id);
2460
2461        const WARMUP_CACHE: &str = "__ferrum_warmup__";
2462        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2463        // Release via the same path as `release` so paged blocks
2464        // return to the shared allocator. Otherwise warmup leaks
2465        // 256 blocks (the full per-seq quota) into the pool.
2466        if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2467            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2468                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2469                if let Some(c0) = caches.first() {
2470                    if !c0.paged_block_indices.is_empty() {
2471                        alloc.free(&c0.paged_block_indices);
2472                    }
2473                }
2474                for c in caches.iter_mut() {
2475                    c.paged_block_indices.clear();
2476                }
2477            }
2478            self.kv_free_pool.push(caches);
2479        }
2480    }
2481
2482    fn kv_capacity(&self) -> usize {
2483        // Mirror the bound `ensure_kv` will use when allocating the cache.
2484        let model_max = self.cfg.max_seq_len;
2485        const DEFAULT_KV_CAPACITY: usize = 4096;
2486        std::env::var("FERRUM_KV_CAPACITY")
2487            .ok()
2488            .and_then(|s| s.parse::<usize>().ok())
2489            .map(|cap| cap.min(model_max))
2490            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
2491    }
2492
2493    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2494        self.prefill_internal(cache_id, tokens)
2495    }
2496
2497    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2498        self.decode_internal(cache_id, token, pos)
2499    }
2500
2501    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2502        self.decode_batch_internal(batch)
2503    }
2504
2505    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2506        // Delegate to the inherent implementation on LlamaFamilyModel.
2507        LlamaFamilyModel::<B>::forward_verify(self, cache_id, tokens)
2508    }
2509
2510    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2511        if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2512            for c in caches.iter_mut() {
2513                if new_len < c.len {
2514                    c.len = new_len;
2515                }
2516            }
2517        }
2518        // Captured graph expects a specific cache layout; roll it back too.
2519        let mut ctx = B::new_context();
2520        B::reset_graph(&mut ctx);
2521        self.graph_warmup = 0;
2522        self.graph_capture_failed = false;
2523    }
2524
2525    fn release(&mut self, cache_id: &str) {
2526        // Sync + drop graph BEFORE touching cache buffers. The graph was
2527        // actively running replays up to this point; destroying the graph
2528        // while the allocator pool still has in-flight references from the
2529        // graph's kernels corrupts stream state. Sync first to drain, then
2530        // destroy graph, then sync again to ensure cleanup completes.
2531        let mut ctx = B::new_context();
2532        B::sync(&mut ctx);
2533        B::reset_graph(&mut ctx);
2534        B::sync(&mut ctx);
2535        self.graph_warmup = 0;
2536        self.graph_capture_failed = false;
2537
2538        // Return the cache's buffers to the free pool instead of dropping.
2539        // Pointers stay stable for the next request's captured graph.
2540        // Paged mode: also free the cache's blocks back to the shared
2541        // allocator so other sequences can reuse them.
2542        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2543            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2544                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2545                // All caches share the same block_indices set (one per
2546                // cache_id), so freeing once via the first layer's cache
2547                // is enough.
2548                if let Some(c0) = caches.first() {
2549                    if !c0.paged_block_indices.is_empty() {
2550                        alloc.free(&c0.paged_block_indices);
2551                    }
2552                }
2553                // Clear the host-side mirror on every layer so a
2554                // free-pool reuse re-allocates fresh blocks.
2555                for c in caches.iter_mut() {
2556                    c.paged_block_indices.clear();
2557                }
2558            }
2559            self.kv_free_pool.push(caches);
2560        }
2561    }
2562
2563    fn reset(&mut self) {
2564        // Hard reset: drop all caches AND the pool, invalidate graph.
2565        let mut ctx = B::new_context();
2566        B::sync(&mut ctx);
2567        B::reset_graph(&mut ctx);
2568        B::sync(&mut ctx);
2569        self.graph_warmup = 0;
2570        self.graph_capture_failed = false;
2571        self.kv_caches.clear();
2572        self.kv_free_pool.clear();
2573    }
2574}
2575
2576fn build_rope_cache<B: Backend>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
2577    let hd = cfg.head_dim;
2578    let half = hd / 2;
2579    let max = cfg.max_seq_len;
2580    let mut cos = vec![0.0f32; max * half];
2581    let mut sin = vec![0.0f32; max * half];
2582    for pos in 0..max {
2583        for i in 0..half {
2584            let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
2585            let angle = pos as f64 * freq;
2586            cos[pos * half + i] = angle.cos() as f32;
2587            sin[pos * half + i] = angle.sin() as f32;
2588        }
2589    }
2590    RopeCache {
2591        cos: B::from_slice(&cos),
2592        sin: B::from_slice(&sin),
2593    }
2594}