Skip to main content

ferrum_models/models/
llama_family.rs

1//! Llama-family decoder model as explicit code.
2//!
3//! Covers all "standard Llama-style GQA + SwiGLU + RoPE" decoders:
4//!   - Llama / Llama-2 / Llama-3  (no QK-norm)
5//!   - Qwen2 / Qwen2.5            (no QK-norm, structurally Llama)
6//!   - Qwen3                      (QK-norm per head, larger rope_theta)
7//!   - Mistral                    (sliding-window attention — not yet
8//!                                 supported on the forward path; will
9//!                                 require an `AttnConfig.sliding_window`
10//!                                 field + shader support in Phase D)
11//!
12//! Variant differences are controlled by `LlamaFamilyConfig::has_qk_norm`
13//! and RoPE theta. Weight loading accepts both fused (`qkv_proj`,
14//! `gate_up_proj`) and split (`q_proj`+`k_proj`+`v_proj`,
15//! `gate_proj`+`up_proj`) projection layouts — the loader (e.g.
16//! `CandleVarBuilderLoader`) fuses split weights on load so model code
17//! sees a uniform `qkv_proj` / `gate_up_proj` Linear.
18
19use std::collections::HashMap;
20use std::sync::atomic::AtomicU64;
21
22use ferrum_kernels::backend::{Backend, KvCache};
23
// Process-wide counters grouped by forward-pass phase, all starting at zero.
// NOTE(review): the code that increments/reports these is outside this
// excerpt; the names suggest `*_TIME_US` accumulates microseconds and
// `*_CALLS` counts invocations — confirm against the update sites.

// Attention.
static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0_u64);
static ATTN_CALLS: AtomicU64 = AtomicU64::new(0_u64);
// "QKR" phase (presumably QK-norm + RoPE — TODO confirm).
static QKR_TIME_US: AtomicU64 = AtomicU64::new(0_u64);
static QKR_CALLS: AtomicU64 = AtomicU64::new(0_u64);
// Matmuls / linear projections.
static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0_u64);
static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0_u64);
// Normalisation.
static NORM_TIME_US: AtomicU64 = AtomicU64::new(0_u64);
static NORM_CALLS: AtomicU64 = AtomicU64::new(0_u64);
// Everything else.
static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0_u64);
static OTHER_CALLS: AtomicU64 = AtomicU64::new(0_u64);
34use ferrum_quantization::{Linear, WeightLoader};
35use ferrum_types::Result;
36
37use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
38
/// Full Llama-family architecture config: everything the model code needs,
/// not just the engine-facing subset exposed as `LlmRuntimeConfig`
/// (see `LlamaFamilyConfig::to_runtime`).
///
/// One struct serves every supported variant (Llama, Qwen2, Qwen3, Mistral);
/// the per-variant `*_from_def` constructors choose `has_qk_norm` and the
/// `rope_theta` fallback.
#[derive(Clone, Debug, PartialEq)]
pub struct LlamaFamilyConfig {
    /// Width of the residual stream.
    pub hidden_size: usize,
    /// MLP inner width; the fused `gate_up_proj` emits `2 * intermediate_size`
    /// values per token.
    pub intermediate_size: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads (GQA). Equals `num_heads` when the checkpoint
    /// does not specify `num_key_value_heads`.
    pub num_kv_heads: usize,
    /// Per-head dimension.
    pub head_dim: usize,
    /// Number of decoder layers.
    pub num_layers: usize,
    /// Vocabulary size (rows of the embedding table / LM head).
    pub vocab_size: usize,
    /// Maximum sequence length declared by the checkpoint
    /// (`max_position_embeddings`).
    pub max_seq_len: usize,
    /// RMSNorm epsilon.
    pub rms_norm_eps: f32,
    /// RoPE base frequency.
    pub rope_theta: f64,
    /// Whether the checkpoint has `q_norm` / `k_norm` per layer. All known
    /// Qwen3 checkpoints do; some derivatives may strip them.
    pub has_qk_norm: bool,
    /// Sliding-window attention size. `0` disables (full causal).
    /// Mistral v0.1 sets 4096; Mistral v0.2+ removed the limitation (0).
    /// Per the module docs, the forward path does not apply it yet.
    pub sliding_window: usize,
}
60
61impl LlamaFamilyConfig {
62    pub fn to_runtime(&self) -> LlmRuntimeConfig {
63        LlmRuntimeConfig {
64            hidden_size: self.hidden_size,
65            num_layers: self.num_layers,
66            num_kv_heads: self.num_kv_heads,
67            head_dim: self.head_dim,
68            vocab_size: self.vocab_size,
69            max_seq_len: self.max_seq_len,
70        }
71    }
72
73    /// Build config from a `ModelDefinition`, shared field extraction.
74    /// Variant-specific constructors below set `has_qk_norm` and fall back
75    /// to different `rope_theta` defaults when the checkpoint doesn't set one.
76    fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
77        let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
78        let head_dim = def
79            .extra_params
80            .get("head_dim")
81            .and_then(|v| v.as_u64())
82            .map(|v| v as usize)
83            .unwrap_or(def.hidden_size / def.num_attention_heads);
84        // Mistral / Gemma: "sliding_window" may be null (v0.2+) or a positive
85        // integer (v0.1). Non-null value passes through; missing/null → 0.
86        let sliding_window = def
87            .extra_params
88            .get("sliding_window")
89            .and_then(|v| v.as_u64())
90            .map(|v| v as usize)
91            .unwrap_or(0);
92
93        LlamaFamilyConfigBase {
94            hidden_size: def.hidden_size,
95            intermediate_size: def.intermediate_size,
96            num_heads: def.num_attention_heads,
97            num_kv_heads,
98            head_dim,
99            num_layers: def.num_hidden_layers,
100            vocab_size: def.vocab_size,
101            max_seq_len: def.max_position_embeddings,
102            rms_norm_eps: def.norm_eps as f32,
103            rope_theta_opt: def.rope_theta,
104            sliding_window,
105        }
106    }
107
108    fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
109        Self {
110            hidden_size: b.hidden_size,
111            intermediate_size: b.intermediate_size,
112            num_heads: b.num_heads,
113            num_kv_heads: b.num_kv_heads,
114            head_dim: b.head_dim,
115            num_layers: b.num_layers,
116            vocab_size: b.vocab_size,
117            max_seq_len: b.max_seq_len,
118            rms_norm_eps: b.rms_norm_eps,
119            rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
120            has_qk_norm,
121            sliding_window: b.sliding_window,
122        }
123    }
124
125    /// Qwen3: QK-norm on, rope_theta default 1e6.
126    pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
127        Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
128    }
129
130    /// Llama / Llama-2 / Llama-3: no QK-norm; rope_theta varies by version
131    /// (10k for Llama-2, 500k for Llama-3.1+) — use the checkpoint value or
132    /// fall back to the most common modern value.
133    pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
134        Self::from_base(Self::from_def_base(def), 500_000.0, false)
135    }
136
137    /// Qwen2 / Qwen2.5: structurally Llama; no QK-norm; rope_theta default 1e6.
138    pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
139        Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
140    }
141
142    /// Mistral: no QK-norm; `rope_theta` commonly 10_000 (v0.1 / v0.2).
143    /// Picks up `sliding_window` from the checkpoint's config.json
144    /// (Mistral v0.1: 4096; Mistral v0.2+: 0 / null).
145    pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
146        Self::from_base(Self::from_def_base(def), 10_000.0, false)
147    }
148}
149
/// Variant-agnostic fields pulled out of a `ModelDefinition` by
/// `LlamaFamilyConfig::from_def_base`, before a variant constructor
/// resolves `rope_theta` and `has_qk_norm`.
struct LlamaFamilyConfigBase {
    // ── Overall shape ──
    hidden_size: usize,
    intermediate_size: usize,
    num_layers: usize,
    vocab_size: usize,
    max_seq_len: usize,
    // ── Attention geometry ──
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    /// `0` = full causal attention.
    sliding_window: usize,
    // ── Numerics ──
    rms_norm_eps: f32,
    /// `rope_theta` as found in the checkpoint; `None` → variant default.
    rope_theta_opt: Option<f64>,
}
163
164/// Per-layer weights. `Box<dyn Linear<B>>` means each projection can be
165/// Dense / GPTQ / AWQ / GGUF without the surrounding code caring.
166pub struct LlamaFamilyLayer<B: Backend> {
167    pub input_ln_w: B::Buffer,
168    pub qkv_proj: Box<dyn Linear<B>>,
169    /// QK-norm weight per head: `[head_dim]`. Optional for non-Qwen3 derivatives.
170    pub q_norm_w: Option<B::Buffer>,
171    pub k_norm_w: Option<B::Buffer>,
172    pub o_proj: Box<dyn Linear<B>>,
173    pub post_ln_w: B::Buffer,
174    pub gate_up_proj: Box<dyn Linear<B>>,
175    pub down_proj: Box<dyn Linear<B>>,
176}
177
178/// Precomputed RoPE cos/sin tables (shape `[max_seq, head_dim / 2]` each).
179pub struct RopeCache<B: Backend> {
180    pub cos: B::Buffer,
181    pub sin: B::Buffer,
182}
183
184/// Reusable per-layer scratch buffers sized for `max_tokens` tokens of a
185/// single forward pass (prefill or decode step).
186///
187/// Sized lazily on first use so tiny decode steps don't pay for prefill-sized
188/// buffers. Grows monotonically when a larger prefill arrives.
189pub struct LlamaFamilyScratch<B: Backend> {
190    /// Residual stream — wrapped in Option so decode_internal can
191    /// `.take()` it without needing an alloc placeholder.
192    ///
193    /// Why this matters for graph capture: the old pattern was
194    /// `mem::replace(&mut scratch.residual, B::alloc(1))` which creates a
195    /// 1-element buffer at every decode step. When graph capture is on,
196    /// that alloc-during-capture + drop-after-capture pair surfaces as
197    /// cuMemFreeAsync(INVALID_VALUE) because the free tries to release a
198    /// pointer the captured graph may still reference. Option::take leaves
199    /// None and moves the real buffer into a local — no spurious alloc.
200    pub residual: Option<B::Buffer>,
201    pub norm_out: B::Buffer,
202    pub qkv_out: B::Buffer,
203    // ── Per-item scratch for batched decode path ──────────────────────
204    // decode_batch_internal runs tokens=M batched ops for the GEMM-heavy
205    // half (norm, qkv_proj, split_qkv, o_proj, post_norm, gate_up, silu,
206    // down, residual_add) but must loop per-item for rope + KV append +
207    // attention (each item has its own KV cache at a different kv_len).
208    // These single-item buffers hold item i's slice during that loop.
209    /// Item-scope q_buf slice, sized `q_dim`.
210    pub q_single: B::Buffer,
211    pub k_single: B::Buffer,
212    pub v_single: B::Buffer,
213    pub q_head_major_single: B::Buffer,
214    pub k_head_major_single: B::Buffer,
215    pub v_head_major_single: B::Buffer,
216    pub attn_head_major_single: B::Buffer,
217    pub attn_flat_single: B::Buffer,
218    /// Batched logits output, sized `max_tokens * vocab_size`. Used only
219    /// in decode_batch; prefill/single-decode use the regular `logits`.
220    pub batch_logits: B::Buffer,
221    /// Token-major Q/K/V right after `split_qkv`. Stride: heads * hd per row.
222    pub q_buf: B::Buffer,
223    pub k_buf: B::Buffer,
224    pub v_buf: B::Buffer,
225    /// Head-major Q produced by `qk_norm_rope` — fed into `flash_attention`.
226    pub q_head_major: B::Buffer,
227    /// Head-major K/V staging — produced by `qk_norm_rope`, consumed by
228    /// `kv_cache_append_head_major` (no reuse after append).
229    pub k_head_major: B::Buffer,
230    pub v_head_major: B::Buffer,
231    /// Head-major attention output from `flash_attention`.
232    pub attn_head_major_out: B::Buffer,
233    /// Token-major attention output after `transpose_head_to_token`.
234    pub attn_flat: B::Buffer,
235    pub o_proj_out: B::Buffer,
236    pub gate_up_out: B::Buffer,
237    pub silu_out: B::Buffer,
238    pub mlp_out: B::Buffer,
239    /// Paged batched dispatch scratch (Phase 4b). Sized for
240    /// `FERRUM_PAGED_MAX_SEQS × q_dim` so multi-seq decode can fan
241    /// in M items' Q into a single buffer for one batched
242    /// `paged_decode_attention(num_seqs=M)` call. `None` when paged
243    /// mode is off.
244    pub paged_batch_q: Option<B::Buffer>,
245    pub paged_batch_o: Option<B::Buffer>,
246    /// Stacked per-seq block tables for batched paged dispatch.
247    /// Layout: `[max_M, max_blocks_per_seq]` u32. Written
248    /// host-side per decode_batch step.
249    pub paged_batch_block_tables: Option<B::Buffer>,
250    /// Stacked per-seq context lengths for batched paged dispatch
251    /// (`[max_M]` u32).
252    pub paged_batch_context_lens: Option<B::Buffer>,
253    /// `max_blocks_per_seq` value baked into the stacked block_tables
254    /// stride. Set when `paged_batch_block_tables` is allocated.
255    pub paged_max_blocks_per_seq: usize,
256    /// Last token's hidden state (`[h]`). For prefill this is populated via
257    /// `copy_slice(residual, (seq_len-1)*h, ..)`; for decode `residual` already
258    /// holds only 1 row so `last_hidden` is unused on that path.
259    pub last_hidden: B::Buffer,
260    /// Final-norm output for the last token (`[h]`).
261    pub last_normed: B::Buffer,
262    /// lm_head logits (`[vocab]`).
263    pub logits: B::Buffer,
264    /// The max tokens-per-step this scratch has been sized for.
265    pub max_tokens: usize,
266}
267
268impl<B: Backend> LlamaFamilyScratch<B> {
269    fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
270        let h = cfg.hidden_size;
271        let im = cfg.intermediate_size;
272        let q_dim = cfg.num_heads * cfg.head_dim;
273        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
274        let qkv_dim = q_dim + 2 * kv_dim;
275        let t = max_tokens;
276        Self {
277            residual: Some(B::alloc(t * h)),
278            norm_out: B::alloc(t * h),
279            qkv_out: B::alloc(t * qkv_dim),
280            q_buf: B::alloc(t * q_dim),
281            k_buf: B::alloc(t * kv_dim),
282            v_buf: B::alloc(t * kv_dim),
283            q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
284            k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
285            v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
286            attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
287            attn_flat: B::alloc(t * q_dim),
288            o_proj_out: B::alloc(t * h),
289            gate_up_out: B::alloc(t * 2 * im),
290            silu_out: B::alloc(t * im),
291            mlp_out: B::alloc(t * h),
292            last_hidden: B::alloc(h),
293            last_normed: B::alloc(h),
294            logits: B::alloc(cfg.vocab_size),
295            q_single: B::alloc(q_dim),
296            k_single: B::alloc(kv_dim),
297            v_single: B::alloc(kv_dim),
298            q_head_major_single: B::alloc(q_dim),
299            k_head_major_single: B::alloc(kv_dim),
300            v_head_major_single: B::alloc(kv_dim),
301            attn_head_major_single: B::alloc(q_dim),
302            attn_flat_single: B::alloc(q_dim),
303            batch_logits: B::alloc(t * cfg.vocab_size),
304            // Paged batched dispatch scratch. None until `enable_paged_batch`
305            // is called from `ensure_kv` once the model knows max_seqs +
306            // max_blocks_per_seq. This avoids paying the alloc cost when
307            // paged mode is off.
308            paged_batch_q: None,
309            paged_batch_o: None,
310            paged_batch_block_tables: None,
311            paged_batch_context_lens: None,
312            paged_max_blocks_per_seq: 0,
313            max_tokens: t,
314        }
315    }
316
317    /// Allocate scratch for batched paged dispatch (Phase 4b). Called
318    /// lazily from `ensure_kv` once paged mode is enabled and we know
319    /// the pool dimensions. Idempotent.
320    fn enable_paged_batch(
321        &mut self,
322        cfg: &LlamaFamilyConfig,
323        max_seqs: usize,
324        max_blocks_per_seq: usize,
325    ) {
326        if self.paged_batch_q.is_some() {
327            return;
328        }
329        let q_dim = cfg.num_heads * cfg.head_dim;
330        self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
331        self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
332        self.paged_batch_block_tables = Some(B::alloc_u32(max_seqs * max_blocks_per_seq));
333        self.paged_batch_context_lens = Some(B::alloc_u32(max_seqs));
334        self.paged_max_blocks_per_seq = max_blocks_per_seq;
335    }
336}
337
338/// Qwen3 model — decoder-only LLM, one per (backend, weights) combination.
339///
340/// Holds all parameters, scratch space, RoPE cache, and per-sequence KV caches.
341pub struct LlamaFamilyModel<B: Backend> {
342    pub cfg: LlamaFamilyConfig,
343    pub runtime_cfg: LlmRuntimeConfig,
344
345    /// Token embedding table. `None` for backbone-only models (e.g. the
346    /// Qwen3-TTS Talker, which embeds inputs externally and feeds via
347    /// `prefill_from_embeds`).
348    pub embed: Option<B::Buffer>,
349    pub layers: Vec<LlamaFamilyLayer<B>>,
350    pub final_norm_w: B::Buffer,
351    /// LM output head. `None` for backbone-only models.
352    pub lm_head: Option<Box<dyn Linear<B>>>,
353
354    pub rope: RopeCache<B>,
355    pub scratch: LlamaFamilyScratch<B>,
356
357    /// Per-sequence KV caches, one `Vec<KvCache<B>>` of length `num_layers`.
358    ///
359    /// Two layouts overlay this same map:
360    /// - **Contiguous mode** (default): each cache holds its own
361    ///   `[num_kv_heads, capacity, head_dim]` k/v buffers.
362    /// - **Paged mode** (`FERRUM_METAL_PAGED_KV=1`): k/v are unused
363    ///   placeholders; the real K/V live in [`Self::paged_pools`] and
364    ///   the cache's `block_table` + `context_lens` index into them.
365    pub kv_caches: HashMap<String, Vec<KvCache<B>>>,
366    /// Free pool of pre-allocated KV cache slots. Released caches return
367    /// here instead of being dropped, so their device pointers stay valid
368    /// across requests — critical for graph capture (pointers baked into
369    /// the captured graph would otherwise dangle).
370    kv_free_pool: Vec<Vec<KvCache<B>>>,
371
372    // ── Paged-KV multi-seq state (Phase 4) ─────────────────────────────
373    //
374    // Only populated when `FERRUM_METAL_PAGED_KV=1`. When set, every
375    // `kv_caches` entry becomes a "view" into the shared pool: its
376    // `k` / `v` buffers are placeholders; reads / writes go through
377    // `paged_pools[layer].k` / `.v` indexed via the cache's
378    // `block_table`. Multiple cache_ids share the same pool, with
379    // disjoint physical block sets owned by `paged_block_alloc`.
380    //
381    /// Shared K/V pools, one per layer. Sized at model load time for the
382    /// configured `MAX_RUNNING_REQUESTS × max_blocks_per_seq` blocks.
383    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
384    /// Block allocator hands out physical block indices from the pool.
385    /// `Mutex` because the engine can call `ensure_kv` / `release_kv`
386    /// from multiple threads in concurrent serving.
387    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
388
389    // ── Graph capture state (CUDA only; harmless no-op on other backends) ──
390    /// Count of eager decode steps run so far. After `GRAPH_WARMUP`, the
391    /// next step captures the decode flow as a graph.
392    graph_warmup: usize,
393    /// True if capture was attempted but failed (e.g. backend doesn't
394    /// support graph capture). Stops further attempts, falls back to eager.
395    graph_capture_failed: bool,
396}
397
398impl<B: Backend> LlamaFamilyModel<B> {
399    /// Build a Qwen3 model from weights provided by the loader.
400    ///
401    /// The loader decides per-projection whether to instantiate DenseLinear,
402    /// GptqLinear, etc. — this code doesn't care.
403    pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
404        // Invalidate any graph from a previously-loaded model. The captured
405        // graph references the old model's scratch buffers; a fresh model
406        // gets fresh scratch, so reusing the graph would read/write freed
407        // pointers. Matters for test suites where multiple models coexist.
408        {
409            let mut ctx = B::new_context();
410            B::reset_graph(&mut ctx);
411        }
412        let rope = build_rope_cache::<B>(&cfg);
413        let scratch = LlamaFamilyScratch::alloc(&cfg, 1); // decode-sized; prefill resizes
414
415        // Embedding: plain tensor (no projection math, just lookup).
416        let embed = loader.load_tensor("model.embed_tokens.weight")?;
417
418        // Per-layer weights.
419        let mut layers = Vec::with_capacity(cfg.num_layers);
420        for li in 0..cfg.num_layers {
421            let prefix = format!("model.layers.{li}");
422            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
423            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
424            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
425            let post_ln_w =
426                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
427            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
428            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
429
430            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
431                let q = loader
432                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
433                    .ok();
434                let k = loader
435                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
436                    .ok();
437                (q, k)
438            } else {
439                (None, None)
440            };
441
442            layers.push(LlamaFamilyLayer {
443                input_ln_w,
444                qkv_proj,
445                q_norm_w,
446                k_norm_w,
447                o_proj,
448                post_ln_w,
449                gate_up_proj,
450                down_proj,
451            });
452        }
453
454        let final_norm_w = loader.load_tensor("model.norm.weight")?;
455
456        // LM head: either dedicated `lm_head.weight` or tied to embedding.
457        // Many models (Qwen3-4B, Llama-3.2-1B, some Qwen2.5) use TIED
458        // embeddings — lm_head shares weights with model.embed_tokens. When
459        // no dedicated lm_head tensor exists, re-load the embed tensor as a
460        // DenseLinear. This duplicates the buffer (memory cost = vocab*h*2
461        // bytes, e.g. ~770MB for Qwen3-4B) but keeps the Linear trait's
462        // owned-weights invariant. Sharing via Arc is a future optimisation.
463        let lm_head = if loader.has_tensor("lm_head.weight") {
464            loader.load_linear("lm_head")?
465        } else {
466            tracing::info!(
467                "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
468            );
469            let as_linear = loader.load_linear("model.embed_tokens")?;
470            // Sanity check: shape must be [vocab, hidden].
471            if as_linear.out_features() != cfg.vocab_size
472                || as_linear.in_features() != cfg.hidden_size
473            {
474                return Err(ferrum_types::FerrumError::model(format!(
475                    "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
476                    as_linear.out_features(),
477                    as_linear.in_features(),
478                    cfg.vocab_size,
479                    cfg.hidden_size
480                )));
481            }
482            as_linear
483        };
484
485        let runtime_cfg = cfg.to_runtime();
486        Ok(Self {
487            cfg,
488            runtime_cfg,
489            embed: Some(embed),
490            layers,
491            final_norm_w,
492            lm_head: Some(lm_head),
493            rope,
494            scratch,
495            kv_caches: HashMap::new(),
496            kv_free_pool: Vec::new(),
497            paged_pools: None,
498            paged_block_alloc: None,
499            graph_warmup: 0,
500            graph_capture_failed: false,
501        })
502    }
503
504    /// Build a backbone-only Qwen3 transformer stack (no embed, no lm_head).
505    ///
506    /// Intended for composing the transformer inside a larger model where
507    /// embedding and output-head logic differs from the standard LLM path —
508    /// e.g. Qwen3-TTS Talker uses dual text/codec embeddings with a projection
509    /// MLP, and a codec_head output. The caller drives forward via
510    /// `prefill_from_embeds` / `decode_from_embed`.
511    ///
512    /// Loader must provide: per-layer weights under `model.layers.{i}.*` and
513    /// the final `model.norm.weight`. `model.embed_tokens` and `lm_head`
514    /// are NOT read.
515    pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
516        // See `new` — invalidate stale graph referring to prior model's scratch.
517        {
518            let mut ctx = B::new_context();
519            B::reset_graph(&mut ctx);
520        }
521        let rope = build_rope_cache::<B>(&cfg);
522        let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
523
524        let mut layers = Vec::with_capacity(cfg.num_layers);
525        for li in 0..cfg.num_layers {
526            let prefix = format!("model.layers.{li}");
527            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
528            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
529            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
530            let post_ln_w =
531                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
532            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
533            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
534
535            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
536                let q = loader
537                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
538                    .ok();
539                let k = loader
540                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
541                    .ok();
542                (q, k)
543            } else {
544                (None, None)
545            };
546
547            layers.push(LlamaFamilyLayer {
548                input_ln_w,
549                qkv_proj,
550                q_norm_w,
551                k_norm_w,
552                o_proj,
553                post_ln_w,
554                gate_up_proj,
555                down_proj,
556            });
557        }
558
559        let final_norm_w = loader.load_tensor("model.norm.weight")?;
560
561        let runtime_cfg = cfg.to_runtime();
562        Ok(Self {
563            cfg,
564            runtime_cfg,
565            embed: None,
566            layers,
567            final_norm_w,
568            lm_head: None,
569            rope,
570            scratch,
571            kv_caches: HashMap::new(),
572            kv_free_pool: Vec::new(),
573            paged_pools: None,
574            paged_block_alloc: None,
575            graph_warmup: 0,
576            graph_capture_failed: false,
577        })
578    }
579
580    /// Grow scratch buffers if `tokens` exceeds the current sizing.
581    pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
582        if self.scratch.max_tokens < tokens {
583            // Any captured decode graph holds pointers to the old scratch
584            // buffers; those are about to be freed. Invalidate first so the
585            // next decode falls back to eager + re-captures with fresh ptrs.
586            // Critical for multi-turn chat (turn N+1's prefill may grow scratch).
587            {
588                let mut ctx = B::new_context();
589                B::reset_graph(&mut ctx);
590            }
591            self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
592            self.graph_warmup = 0;
593            self.graph_capture_failed = false;
594        }
595    }
596
597    /// Ensure per-layer KV caches exist for `cache_id`, pre-allocated to
598    /// `max_seq_len` slots per head. Enables the in-place
599    /// `kv_cache_append_head_major` path — no realloc per layer.
600    pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
601        if self.kv_caches.contains_key(cache_id) {
602            return;
603        }
604        let nkv = self.cfg.num_kv_heads;
605        let hd = self.cfg.head_dim;
606        // KV capacity defaults to a chat-friendly 4096 to keep the working
607        // set sane on a 32 GB Mac (a 48-layer / 4-kv-head / 128-head_dim
608        // model spends ~786 MB on 4096 KV slots, vs ~6 GB on the model's
609        // declared 32K which would push the 17 GB Qwen3-30B-A3B model into
610        // swap). `FERRUM_KV_CAPACITY=N` overrides; clamp to the model's
611        // declared max so we never lie to the model about its window.
612        let model_max = self.cfg.max_seq_len;
613        // 512 in 0.7.2 — matches the value used in
614        // docs/bench/macos-2026-05-02 to get the published numbers.
615        // pre-0.7.2 default of 4096 was safe only because paged-KV was
616        // opt-in (pool wasn't allocated). With paged-KV now on by
617        // default + MAX_SEQS=32, the pool occupies physical memory:
618        // ~3 GB on Qwen3-30B-A3B Q4_K_M leaves 18 GB weights + 3 GB pool
619        // = 21 GB, fits comfortably on a 32 GB Mac. Long-context users
620        // can `FERRUM_KV_CAPACITY=4096` and accept lower max_seqs.
621        const DEFAULT_KV_CAPACITY: usize = 512;
622        let max = std::env::var("FERRUM_KV_CAPACITY")
623            .ok()
624            .and_then(|s| s.parse::<usize>().ok())
625            .map(|cap| cap.min(model_max))
626            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY));
627
628        // Paged-KV mode: `FERRUM_METAL_PAGED_KV=1` switches every cache
629        // for this model into block-table-indirect layout. Kernels from
630        // PR #68 (decode read) + PR #69 (decode write) handle the
631        // indirect addressing; the LlamaFamily decode path below picks
632        // them up automatically by checking `cache.block_size > 0`.
633        //
634        // Pool sizing: round capacity up to a multiple of block_size,
635        // identity-assign logical→physical block. Memory footprint is
636        // the same as contiguous (within block_size rounding); the
637        // benefit only shows up under multi-seq sharing in Phase 4+.
638        // Default ON when the backend supports paged-KV (Metal). Users
639        // can force off with `FERRUM_METAL_PAGED_KV=0`. The flag was
640        // opt-in pre-0.7.2; flipping the default so default `ferrum
641        // serve` matches the bench-quality numbers without requiring
642        // env-var knowledge.
643        let paged = std::env::var("FERRUM_METAL_PAGED_KV")
644            .map(|v| v != "0")
645            .unwrap_or_else(|_| B::supports_paged_kv());
646        const PAGED_BLOCK_SIZE: usize = 16;
647
648        // Phase 4 shared-pool sizing. The pool sees ALL concurrent
649        // sequences; per-cache_id state just owns indices into it.
650        // Default 32: covers c=16 burst with 2× headroom for the
651        // fresh-cache-id-per-request pattern that bench/server harnesses
652        // use. Pool memory is `max_seqs × max_blocks_per_seq` total
653        // blocks — we lowered DEFAULT_KV_CAPACITY to 2048 so this 2× max_seqs
654        // bump keeps the pool footprint identical to the pre-0.7.2 default.
655        let max_seqs = std::env::var("FERRUM_PAGED_MAX_SEQS")
656            .ok()
657            .and_then(|s| s.parse::<usize>().ok())
658            .unwrap_or(32);
659        let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
660        let total_pool_blocks = max_seqs * max_blocks_per_seq;
661
662        // Lazy-allocate the shared paged pools on the FIRST paged
663        // ensure_kv call. Pools are big — for Llama-8B (8 kv_heads,
664        // head_dim=128) at 16 seqs × 256 blocks × 16 slots = 65536 KV
665        // slots: 65536 * 8 * 128 * 4 = 256 MB per layer × 32 layers
666        // = 8 GB total. Sized this large only because `max_seqs=16`
667        // is the default; lower it via env to shrink.
668        if paged && self.paged_pools.is_none() {
669            let mut pools = Vec::with_capacity(self.cfg.num_layers);
670            for _ in 0..self.cfg.num_layers {
671                let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
672                pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
673            }
674            self.paged_pools = Some(pools);
675            self.paged_block_alloc = Some(std::sync::Mutex::new(
676                crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
677            ));
678        }
679        // Phase 4b: ensure batched-dispatch scratch is allocated whenever
680        // paged is on. Idempotent — re-init is a no-op if already
681        // sized. Has to live outside the `paged_pools.is_none()` branch
682        // because `ensure_scratch` may have replaced the scratch struct
683        // since the pools were first allocated.
684        if paged {
685            self.scratch
686                .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
687        }
688
689        // Try pool first — reused buffers have stable device pointers,
690        // so a captured decode graph can be replayed for this request too.
691        let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
692            (0..self.cfg.num_layers)
693                .map(|_| {
694                    if paged {
695                        // Paged mode: cache holds metadata only. K/V
696                        // are 1-element placeholders (allocated cheaply
697                        // since Backend::alloc requires a non-zero
698                        // size on most backends). The real data lives
699                        // in `self.paged_pools[li].{k,v}`.
700                        let mut block_table = B::alloc_u32(max_blocks_per_seq);
701                        let mut context_lens = B::alloc_u32(1);
702                        let mut bt_ctx = B::new_context();
703                        B::write_u32(&mut bt_ctx, &mut context_lens, &[0u32]);
704                        B::sync(&mut bt_ctx);
705                        KvCache {
706                            k: B::alloc(1),
707                            v: B::alloc(1),
708                            len: 0,
709                            capacity: max_blocks_per_seq * PAGED_BLOCK_SIZE,
710                            num_kv_heads: nkv,
711                            head_dim: hd,
712                            block_size: PAGED_BLOCK_SIZE,
713                            block_table: Some(block_table),
714                            context_lens: Some(context_lens),
715                            paged_block_indices: Vec::new(),
716                        }
717                    } else {
718                        KvCache {
719                            k: B::alloc(nkv * max * hd),
720                            v: B::alloc(nkv * max * hd),
721                            len: 0,
722                            capacity: max,
723                            num_kv_heads: nkv,
724                            head_dim: hd,
725                            block_size: 0,
726                            block_table: None,
727                            context_lens: None,
728                            paged_block_indices: Vec::new(),
729                        }
730                    }
731                })
732                .collect()
733        });
734
735        // Allocate physical blocks for THIS cache_id from the shared
736        // pool. We allocate all `max_blocks_per_seq` upfront for
737        // simplicity (matches contig's "pre-alloc to capacity"
738        // semantics); a smarter Phase 4b can grow on demand to save
739        // pool occupancy.
740        if paged {
741            let alloc_arc = self
742                .paged_block_alloc
743                .as_ref()
744                .expect("paged_block_alloc must be initialised when paged=true");
745            // Recover from a previously-poisoned mutex instead of panicking
746            // (poison just means a prior holder panicked; the BlockAllocator
747            // state is still intact since allocate_n is fail-safe).
748            let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
749            let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
750                Ok(idx) => idx,
751                Err(e) => {
752                    // Pool exhaustion is a back-pressure signal, not a crash.
753                    // Drop the lock, return the cache to the free pool, and
754                    // bail before inserting it into kv_caches. The downstream
755                    // call will then fail with a clean per-request error
756                    // ("ensure_kv must be called before ...") instead of
757                    // dragging every other in-flight request down with it.
758                    drop(alloc);
759                    self.kv_free_pool.push(caches);
760                    eprintln!(
761                        "[ferrum] paged KV pool exhausted on ensure_kv for \
762                         cache_id={cache_id:?}: {e}. Increase \
763                         FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
764                         throttle concurrent requests.",
765                    );
766                    return;
767                }
768            };
769            // Write the block table to each layer's cache. All layers
770            // share the same logical→physical mapping for this seq.
771            // Also stash the host-side index list so release_kv can
772            // return them to the allocator without a device readback.
773            let mut padded = block_indices.clone();
774            padded.resize(max_blocks_per_seq, 0);
775            let mut ctx_tmp = B::new_context();
776            for c in caches.iter_mut() {
777                if let Some(bt) = c.block_table.as_mut() {
778                    B::write_u32(&mut ctx_tmp, bt, &padded);
779                }
780                c.paged_block_indices = block_indices.clone();
781            }
782            B::sync(&mut ctx_tmp);
783        }
784
785        // Reset logical length; buffers stay. No need to zero the memory —
786        // the kv_cache_append writes new K/V in place, and attention only
787        // reads up to `cache_len`.
788        for c in caches.iter_mut() {
789            c.len = 0;
790            if let Some(cl) = c.context_lens.as_mut() {
791                let mut ctx_tmp = B::new_context();
792                B::write_u32(&mut ctx_tmp, cl, &[0u32]);
793                B::sync(&mut ctx_tmp);
794            }
795        }
796        self.kv_caches.insert(cache_id.to_string(), caches);
797    }
798
799    /// Run one transformer layer. Mutates `residual` in place.
800    ///
801    /// `pos_offset` is the absolute position of token 0 in this batch
802    /// (decode: `pos`; prefill: 0). `tokens` is the batch size.
803    #[allow(clippy::too_many_arguments)]
804    pub(crate) fn forward_layer(
805        &mut self,
806        ctx: &mut B::Context,
807        li: usize,
808        cache_id: &str,
809        residual: &mut B::Buffer,
810        pos_offset: usize,
811        tokens: usize,
812    ) {
813        let layer = &self.layers[li];
814        let cfg = &self.cfg;
815        let h = cfg.hidden_size;
816        let nh = cfg.num_heads;
817        let nkv = cfg.num_kv_heads;
818        let hd = cfg.head_dim;
819        let im = cfg.intermediate_size;
820        let eps = cfg.rms_norm_eps;
821        let q_dim = nh * hd;
822        let kv_dim = nkv * hd;
823
824        // 1. Input RMSNorm
825        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
826            B::sync(ctx);
827            Some(std::time::Instant::now())
828        } else {
829            None
830        };
831        B::rms_norm(
832            ctx,
833            residual,
834            &layer.input_ln_w,
835            eps,
836            &mut self.scratch.norm_out,
837            tokens,
838            h,
839        );
840        if let Some(t0) = _t0 {
841            B::sync(ctx);
842            NORM_TIME_US.fetch_add(
843                t0.elapsed().as_micros() as u64,
844                std::sync::atomic::Ordering::Relaxed,
845            );
846            NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
847        }
848
849        // 2. Fused QKV projection (Linear dispatches to Dense/GPTQ/AWQ/GGUF)
850        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
851            B::sync(ctx);
852            Some(std::time::Instant::now())
853        } else {
854            None
855        };
856        layer.qkv_proj.forward(
857            ctx,
858            &self.scratch.norm_out,
859            &mut self.scratch.qkv_out,
860            tokens,
861        );
862        if let Some(t0) = _t0 {
863            B::sync(ctx);
864            MATMUL_TIME_US.fetch_add(
865                t0.elapsed().as_micros() as u64,
866                std::sync::atomic::Ordering::Relaxed,
867            );
868            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
869        }
870
871        // 3-5. Fused split-QKV + QK-norm + RoPE + cache-write.
872        //
873        // Single Metal dispatch replaces the (split_qkv → 3× qk_norm_rope
874        // → kv_cache_append_head_major) five-launch chain on the decode
875        // hot path. Reads qkv_out once, writes Q to head-major scratch
876        // and K/V straight into the pre-allocated KV cache slot at
877        // `cache_len + tok`. Saves 4 dispatches per layer when the
878        // backend implements the fused kernel; CPU and other backends
879        // keep using the unfused chain via the Unsupported fallbacks.
880        //
881        // qk_mode: 1 = norm + RoPE (Qwen3); 2 = RoPE only (Llama).
882        // V always passes apply_norm=0.
883        let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
884        let dummy = &layer.input_ln_w;
885        let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
886        let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);
887
888        // Grab the per-layer KV cache up front so the deepest fusion can
889        // write K/V straight into it.
890        //
891        // Paged mode: also need this layer's shared pool buffers
892        // (self.paged_pools[li]). The pool is a separate field from
893        // kv_caches, so we take a raw pointer to its (k, v) here while
894        // we still hold &mut self, then deref via unsafe inside the
895        // paged dispatch below. Safety: paged_pools is allocated once
896        // and never resized; we don't touch self.paged_pools while the
897        // pointer is in use.
898        let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
899            if let Some(pools) = self.paged_pools.as_mut() {
900                let pool = &mut pools[li];
901                Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
902            } else {
903                None
904            };
905        let caches = self
906            .kv_caches
907            .get_mut(cache_id)
908            .expect("ensure_kv must be called before forward_layer");
909        let cache = &mut caches[li];
910        let cache_len_before = cache.len;
911        let cache_capacity = cache.capacity;
912
913        // Defense in depth: refuse to write past the KV buffer. The
914        // graceful path is the caller pre-checking via `kv_capacity()`
915        // and either compacting or refusing the request; this panic only
916        // fires when that contract is broken (and silent overflow would
917        // otherwise corrupt the cache + adjacent allocations).
918        if cache_len_before + tokens > cache_capacity {
919            panic!(
920                "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
921                cache_len_before + tokens
922            );
923        }
924
925        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
926            B::sync(ctx);
927            Some(std::time::Instant::now())
928        } else {
929            None
930        };
931        // Paged-KV path: when the cache was allocated with paged
932        // metadata (`block_size > 0`), use the paged write kernel
933        // which fans out into the block pool via `block_table`.
934        // Falls back to contiguous if Backend doesn't implement it.
935        let used_qkv_into_cache = if cache.block_size > 0 {
936            let bt = cache
937                .block_table
938                .as_ref()
939                .expect("paged cache missing block_table");
940            let num_blocks_per_seq = cache.capacity / cache.block_size;
941            // Paged mode: K/V live in the shared pool, not cache.k/.v.
942            let (pool_k_ptr, pool_v_ptr) =
943                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
944            // SAFETY: paged_pools is allocated once and never resized;
945            // we do not touch self.paged_pools concurrently.
946            let pool_k = unsafe { &mut *pool_k_ptr };
947            let pool_v = unsafe { &mut *pool_v_ptr };
948            B::split_qkv_norm_rope_into_paged_cache(
949                ctx,
950                &self.scratch.qkv_out,
951                0, // qkv_byte_offset: single-seq dispatch reads from start
952                q_norm_w,
953                k_norm_w,
954                &self.rope.cos,
955                &self.rope.sin,
956                &mut self.scratch.q_head_major,
957                0, // q_out_byte_offset: writes to start of head-major scratch
958                pool_k,
959                pool_v,
960                bt,
961                tokens,
962                nh,
963                nkv,
964                hd,
965                pos_offset,
966                eps,
967                qk_mode,
968                cache_len_before,
969                cache.block_size,
970                num_blocks_per_seq,
971            )
972            .is_ok()
973        } else {
974            B::split_qkv_norm_rope_into_cache(
975                ctx,
976                &self.scratch.qkv_out,
977                q_norm_w,
978                k_norm_w,
979                &self.rope.cos,
980                &self.rope.sin,
981                &mut self.scratch.q_head_major,
982                &mut cache.k,
983                &mut cache.v,
984                tokens,
985                nh,
986                nkv,
987                hd,
988                pos_offset,
989                eps,
990                qk_mode,
991                cache_len_before,
992                cache_capacity,
993            )
994            .is_ok()
995        };
996        if !used_qkv_into_cache {
997            // Fallback 1: fused split-QKV-norm-rope to head-major scratch
998            // (PR #47 path).
999            let used_fused_qkv = B::split_qkv_norm_rope(
1000                ctx,
1001                &self.scratch.qkv_out,
1002                q_norm_w,
1003                k_norm_w,
1004                &self.rope.cos,
1005                &self.rope.sin,
1006                &mut self.scratch.q_head_major,
1007                &mut self.scratch.k_head_major,
1008                &mut self.scratch.v_head_major,
1009                tokens,
1010                nh,
1011                nkv,
1012                hd,
1013                pos_offset,
1014                eps,
1015                qk_mode,
1016            )
1017            .is_ok();
1018            if !used_fused_qkv {
1019                // Fallback 2: original four-launch chain.
1020                B::split_qkv(
1021                    ctx,
1022                    &self.scratch.qkv_out,
1023                    &mut self.scratch.q_buf,
1024                    &mut self.scratch.k_buf,
1025                    &mut self.scratch.v_buf,
1026                    tokens,
1027                    q_dim,
1028                    kv_dim,
1029                );
1030                B::qk_norm_rope(
1031                    ctx,
1032                    &self.scratch.q_buf,
1033                    q_norm_w,
1034                    &self.rope.cos,
1035                    &self.rope.sin,
1036                    &mut self.scratch.q_head_major,
1037                    tokens,
1038                    nh,
1039                    hd,
1040                    pos_offset,
1041                    eps,
1042                    qk_mode,
1043                );
1044                B::qk_norm_rope(
1045                    ctx,
1046                    &self.scratch.k_buf,
1047                    k_norm_w,
1048                    &self.rope.cos,
1049                    &self.rope.sin,
1050                    &mut self.scratch.k_head_major,
1051                    tokens,
1052                    nkv,
1053                    hd,
1054                    pos_offset,
1055                    eps,
1056                    qk_mode,
1057                );
1058                B::qk_norm_rope(
1059                    ctx,
1060                    &self.scratch.v_buf,
1061                    dummy,
1062                    &self.rope.cos,
1063                    &self.rope.sin,
1064                    &mut self.scratch.v_head_major,
1065                    tokens,
1066                    nkv,
1067                    hd,
1068                    pos_offset,
1069                    eps,
1070                    0,
1071                );
1072            }
1073            B::kv_cache_append_head_major(
1074                ctx,
1075                &mut cache.k,
1076                &mut cache.v,
1077                cache.len,
1078                cache.capacity,
1079                &self.scratch.k_head_major,
1080                &self.scratch.v_head_major,
1081                tokens,
1082                nkv,
1083                hd,
1084            );
1085        }
1086        if let Some(t0) = _t0 {
1087            B::sync(ctx);
1088            QKR_TIME_US.fetch_add(
1089                t0.elapsed().as_micros() as u64,
1090                std::sync::atomic::Ordering::Relaxed,
1091            );
1092            QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1093        }
1094        cache.len += tokens;
1095        let kv_len = cache.len;
1096        let kv_stride = cache.capacity;
1097
1098        // 6. Flash attention.
1099        //    Paged path: when the cache uses block layout, dispatch the
1100        //    paged_decode_attention kernel; for q_len > 1 (prefill),
1101        //    iterate token-by-token (kernel only handles q_len=1 right
1102        //    now — Phase 4 will add a paged Q-tiled path).
1103        //    Contiguous path: existing flash_attention dispatch.
1104        let _attn_t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1105            B::sync(ctx);
1106            Some(std::time::Instant::now())
1107        } else {
1108            None
1109        };
1110        if cache.block_size > 0 {
1111            let bt = cache
1112                .block_table
1113                .as_ref()
1114                .expect("paged cache missing block_table");
1115            let cl_buf = cache
1116                .context_lens
1117                .as_mut()
1118                .expect("paged cache missing context_lens");
1119            let num_blocks_per_seq = cache.capacity / cache.block_size;
1120            // Paged mode: K/V come from the shared pool.
1121            let (pool_k_ptr, pool_v_ptr) =
1122                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
1123            // SAFETY: same as the write-side above; pool buffers are
1124            // allocated-once and never moved while we hold the pointer.
1125            let pool_k = unsafe { &*pool_k_ptr };
1126            let pool_v = unsafe { &*pool_v_ptr };
1127            // Single dispatch handles both decode (q_len=1) and causal
1128            // prefill (q_len>1). The kernel computes per-token causal
1129            // limit as `context_lens[seq] - (q_len - 1 - q_token_idx)`,
1130            // so we set context_lens to the FINAL kv_len after this
1131            // batch's writes.
1132            let final_kv_len = cache.len as u32;
1133            B::write_u32(ctx, cl_buf, &[final_kv_len]);
1134            B::paged_decode_attention(
1135                ctx,
1136                &self.scratch.q_head_major,
1137                pool_k,
1138                pool_v,
1139                &mut self.scratch.attn_head_major_out,
1140                bt,
1141                cl_buf,
1142                1, // num_seqs (single-seq dispatch; multi-seq is fan-in via forward_layer_batched, Phase 4b)
1143                nh,
1144                nkv,
1145                hd,
1146                cache.block_size,
1147                num_blocks_per_seq,
1148                tokens, // q_len
1149            )
1150            .expect("paged_decode_attention");
1151        } else {
1152            //    `causal` is always true for decoder-only LLMs — every query must
1153            //    mask out future tokens. Sliding-window models (Mistral v0.1) narrow
1154            //    the lower bound via `sliding_window`.
1155            let attn_cfg = ferrum_kernels::backend::AttnConfig {
1156                num_heads: nh,
1157                num_kv_heads: nkv,
1158                head_dim: hd,
1159                causal: true,
1160                scale: 1.0 / (hd as f32).sqrt(),
1161                kv_seq_stride: kv_stride,
1162                sliding_window: cfg.sliding_window,
1163            };
1164            B::flash_attention(
1165                ctx,
1166                &self.scratch.q_head_major,
1167                &cache.k,
1168                &cache.v,
1169                &mut self.scratch.attn_head_major_out,
1170                1,
1171                tokens,
1172                kv_len,
1173                pos_offset,
1174                &attn_cfg,
1175            );
1176        }
1177        if let Some(t0) = _attn_t0 {
1178            B::sync(ctx);
1179            ATTN_TIME_US.fetch_add(
1180                t0.elapsed().as_micros() as u64,
1181                std::sync::atomic::Ordering::Relaxed,
1182            );
1183            ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1184        }
1185
1186        // 7. Untranspose head-major → token-major for O-proj input.
1187        //
1188        // For tokens=1 the head-major and token-major layouts collapse
1189        // to the same flat [heads * head_dim] vector, so the dispatch is
1190        // an identity memcpy — skip it and point o_proj at the
1191        // head-major buffer directly.
1192        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1193            B::sync(ctx);
1194            Some(std::time::Instant::now())
1195        } else {
1196            None
1197        };
1198        let attn_token_major = if tokens == 1 {
1199            &self.scratch.attn_head_major_out
1200        } else {
1201            B::transpose_head_to_token(
1202                ctx,
1203                &self.scratch.attn_head_major_out,
1204                &mut self.scratch.attn_flat,
1205                tokens,
1206                nh,
1207                hd,
1208            );
1209            &self.scratch.attn_flat
1210        };
1211        if let Some(t0) = _t0 {
1212            B::sync(ctx);
1213            OTHER_TIME_US.fetch_add(
1214                t0.elapsed().as_micros() as u64,
1215                std::sync::atomic::Ordering::Relaxed,
1216            );
1217            OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1218        }
1219
1220        // 8. O projection
1221        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1222            B::sync(ctx);
1223            Some(std::time::Instant::now())
1224        } else {
1225            None
1226        };
1227        layer
1228            .o_proj
1229            .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1230        if let Some(t0) = _t0 {
1231            B::sync(ctx);
1232            MATMUL_TIME_US.fetch_add(
1233                t0.elapsed().as_micros() as u64,
1234                std::sync::atomic::Ordering::Relaxed,
1235            );
1236            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1237        }
1238
1239        // 9. Fused residual-add + post-attention RMSNorm.
1240        //    Writes the new residual back into `residual` and the normed
1241        //    value into `norm_out`.
1242        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1243            B::sync(ctx);
1244            Some(std::time::Instant::now())
1245        } else {
1246            None
1247        };
1248        B::fused_add_rms_norm(
1249            ctx,
1250            residual,
1251            &self.scratch.o_proj_out,
1252            &layer.post_ln_w,
1253            eps,
1254            &mut self.scratch.norm_out,
1255            tokens,
1256            h,
1257        );
1258        if let Some(t0) = _t0 {
1259            B::sync(ctx);
1260            NORM_TIME_US.fetch_add(
1261                t0.elapsed().as_micros() as u64,
1262                std::sync::atomic::Ordering::Relaxed,
1263            );
1264            NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1265        }
1266
1267        // 10. Fused gate+up projection
1268        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1269            B::sync(ctx);
1270            Some(std::time::Instant::now())
1271        } else {
1272            None
1273        };
1274        layer.gate_up_proj.forward(
1275            ctx,
1276            &self.scratch.norm_out,
1277            &mut self.scratch.gate_up_out,
1278            tokens,
1279        );
1280        if let Some(t0) = _t0 {
1281            B::sync(ctx);
1282            MATMUL_TIME_US.fetch_add(
1283                t0.elapsed().as_micros() as u64,
1284                std::sync::atomic::Ordering::Relaxed,
1285            );
1286            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1287        }
1288
1289        // 11. SwiGLU: silu(gate) * up
1290        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1291            B::sync(ctx);
1292            Some(std::time::Instant::now())
1293        } else {
1294            None
1295        };
1296        B::fused_silu_mul_split(
1297            ctx,
1298            &self.scratch.gate_up_out,
1299            &mut self.scratch.silu_out,
1300            tokens,
1301            im,
1302        );
1303        if let Some(t0) = _t0 {
1304            B::sync(ctx);
1305            OTHER_TIME_US.fetch_add(
1306                t0.elapsed().as_micros() as u64,
1307                std::sync::atomic::Ordering::Relaxed,
1308            );
1309            OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1310        }
1311
1312        // 12. Down projection
1313        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1314            B::sync(ctx);
1315            Some(std::time::Instant::now())
1316        } else {
1317            None
1318        };
1319        layer.down_proj.forward(
1320            ctx,
1321            &self.scratch.silu_out,
1322            &mut self.scratch.mlp_out,
1323            tokens,
1324        );
1325        if let Some(t0) = _t0 {
1326            B::sync(ctx);
1327            MATMUL_TIME_US.fetch_add(
1328                t0.elapsed().as_micros() as u64,
1329                std::sync::atomic::Ordering::Relaxed,
1330            );
1331            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1332        }
1333
1334        // 13. Final residual add
1335        let _t0 = if std::env::var("FERRUM_DECODE_OP_PROFILE").is_ok() {
1336            B::sync(ctx);
1337            Some(std::time::Instant::now())
1338        } else {
1339            None
1340        };
1341        B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
1342        if let Some(t0) = _t0 {
1343            B::sync(ctx);
1344            OTHER_TIME_US.fetch_add(
1345                t0.elapsed().as_micros() as u64,
1346                std::sync::atomic::Ordering::Relaxed,
1347            );
1348            OTHER_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1349        }
1350    }
1351
1352    /// Multi-position decode-verify: run one forward pass over `tokens`
1353    /// starting at the cache's current end position, write their K/V
1354    /// into the KV cache, and return logits for ALL `tokens.len()`
1355    /// positions as a flat `Vec<f32>` of length `seq_len * vocab_size`.
1356    ///
1357    /// Used by speculative decoding: target receives
1358    /// `[last_token, draft_0, ..., draft_{N-1}]` (N+1 inputs) and produces
1359    /// N+1 logit rows in a single forward instead of N+1 sequential
1360    /// decode() calls. Positions are implicit — the model looks up
1361    /// `pos_offset = cache.len` the same way prefill_internal does, so
1362    /// chunked prefill semantics carry over for free.
1363    pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1364        let seq_len = tokens.len();
1365        assert!(seq_len > 0, "forward_verify called with empty tokens");
1366        self.ensure_scratch(seq_len);
1367        self.ensure_kv(cache_id);
1368
1369        let h = self.cfg.hidden_size;
1370        let vocab = self.cfg.vocab_size;
1371
1372        let pos_offset = self
1373            .kv_caches
1374            .get(cache_id)
1375            .and_then(|layers| layers.first())
1376            .map(|c| c.len)
1377            .unwrap_or(0);
1378
1379        let mut ctx = B::new_context();
1380        let mut residual = self
1381            .scratch
1382            .residual
1383            .take()
1384            .expect("scratch residual missing (previous call didn't restore)");
1385
1386        let embed = self
1387            .embed
1388            .as_ref()
1389            .expect("forward_verify called on backbone-only model (no embed)");
1390        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1391
1392        for li in 0..self.cfg.num_layers {
1393            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1394        }
1395
1396        // RMSNorm on ALL seq_len positions (prefill_internal only norms
1397        // the last one; verify needs the full grid).
1398        B::rms_norm(
1399            &mut ctx,
1400            &residual,
1401            &self.final_norm_w,
1402            self.cfg.rms_norm_eps,
1403            &mut self.scratch.norm_out,
1404            seq_len,
1405            h,
1406        );
1407
1408        // LM head applied to all positions → `seq_len * vocab` logits.
1409        // Reuses the existing `batch_logits` scratch (sized max_tokens *
1410        // vocab) so no extra allocation.
1411        let lm_head = self
1412            .lm_head
1413            .as_ref()
1414            .expect("forward_verify called on backbone-only model (no lm_head)");
1415        lm_head.forward(
1416            &mut ctx,
1417            &self.scratch.norm_out,
1418            &mut self.scratch.batch_logits,
1419            seq_len,
1420        );
1421
1422        B::sync(&mut ctx);
1423        self.scratch.residual = Some(residual);
1424        B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
1425    }
1426
1427    /// Prefill: process `tokens` prompt tokens in a single batch, return
1428    /// `[vocab_size]` logits for the last position.
1429    ///
1430    /// Supports incremental prefill: if the KV cache for `cache_id` already
1431    /// contains earlier tokens, the new chunk's positions are computed as
1432    /// `[kv_len, kv_len + tokens.len())` so RoPE and causal masking stay
1433    /// aligned. Used by the engine's chunked-prefill path.
1434    pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1435        let seq_len = tokens.len();
1436        assert!(seq_len > 0, "prefill called with empty token list");
1437        self.ensure_scratch(seq_len);
1438        self.ensure_kv(cache_id);
1439
1440        // Starting position for this chunk — 0 for a fresh prefill, kv_len
1441        // for the second+ chunk of a split prefill.
1442        let pos_offset = self
1443            .kv_caches
1444            .get(cache_id)
1445            .and_then(|layers| layers.first())
1446            .map(|c| c.len)
1447            .unwrap_or(0);
1448
1449        let h = self.cfg.hidden_size;
1450        let vocab = self.cfg.vocab_size;
1451        let mut ctx = B::new_context();
1452
1453        // Move `residual` out of `scratch` to work around the borrow checker:
1454        // `forward_layer` re-borrows `&mut self` to reach `self.layers` /
1455        // `self.kv_caches`, which would conflict with an outstanding
1456        // `&mut self.scratch.residual`. Use Option::take to move it out
1457        // (no placeholder alloc → no transient cuMemFreeAsync that could
1458        // corrupt stream pool state after graph ops on Blackwell).
1459        let mut residual = self
1460            .scratch
1461            .residual
1462            .take()
1463            .expect("scratch residual missing (previous call didn't restore)");
1464        let embed = self
1465            .embed
1466            .as_ref()
1467            .expect("prefill_internal called on backbone-only model (no embed)");
1468        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1469
1470        let prefill_profile = std::env::var("FERRUM_PREFILL_OP_PROFILE").is_ok();
1471        let prefill_t0 = if prefill_profile {
1472            B::sync(&mut ctx);
1473            Some(std::time::Instant::now())
1474        } else {
1475            None
1476        };
1477
1478        for li in 0..self.cfg.num_layers {
1479            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1480        }
1481
1482        if let Some(t0) = prefill_t0 {
1483            B::sync(&mut ctx);
1484            let total_us = t0.elapsed().as_micros() as u64;
1485            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1486            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1487            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1488            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1489            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1490            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1491            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1492            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1493            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1494            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1495            eprintln!(
1496                "[prefill-profile] tokens={} layers total={} ms",
1497                seq_len,
1498                total_us / 1000
1499            );
1500            let bucket = |label: &str, n: u64, us: u64| {
1501                if n > 0 {
1502                    eprintln!(
1503                        "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
1504                        n,
1505                        us / 1000,
1506                        us / n
1507                    );
1508                }
1509            };
1510            bucket("flash_attn", attn_n, attn_us);
1511            bucket("qk_norm_rope", qkr_n, qkr_us);
1512            bucket("matmuls", mm_n, mm_us);
1513            bucket("norms", norm_n, norm_us);
1514            bucket("other", other_n, other_us);
1515        }
1516
1517        // Take the last token's hidden state: residual[(seq_len-1)*h .. seq_len*h]
1518        B::copy_slice(
1519            &mut ctx,
1520            &residual,
1521            (seq_len - 1) * h,
1522            &mut self.scratch.last_hidden,
1523            0,
1524            h,
1525        );
1526
1527        // Final RMSNorm on the last hidden.
1528        B::rms_norm(
1529            &mut ctx,
1530            &self.scratch.last_hidden,
1531            &self.final_norm_w,
1532            self.cfg.rms_norm_eps,
1533            &mut self.scratch.last_normed,
1534            1,
1535            h,
1536        );
1537
1538        // LM head (m=1 — triggers GEMV on MetalBackend).
1539        let lm_head = self
1540            .lm_head
1541            .as_ref()
1542            .expect("prefill_internal called on backbone-only model (no lm_head)");
1543        lm_head.forward(
1544            &mut ctx,
1545            &self.scratch.last_normed,
1546            &mut self.scratch.logits,
1547            1,
1548        );
1549
1550        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
1551        // buffer's CPU pointer without flushing the command buffer, so the
1552        // GPU must complete all pending work first or we read stale/random
1553        // data. CUDA's to_vec does an internal stream.synchronize, making
1554        // the call redundant there (~50µs/step cost), but correctness on
1555        // Metal requires the explicit flush here.
1556        B::sync(&mut ctx);
1557
1558        // Restore residual into scratch for reuse on the next call.
1559        self.scratch.residual = Some(residual);
1560
1561        B::to_vec(&self.scratch.logits, vocab)
1562    }
1563
1564    /// Decode: process 1 token at position `pos`, return `[vocab_size]` logits.
1565    pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1566        self.ensure_scratch(1);
1567        self.ensure_kv(cache_id);
1568
1569        let h = self.cfg.hidden_size;
1570        let vocab = self.cfg.vocab_size;
1571
1572        // Context creation is cheap (CUDA reuses the process-global stream).
1573        // The captured graph lives in a process-global slot, not on ctx.
1574        let mut ctx = B::new_context();
1575
1576        // Graph capture is opt-in via FERRUM_CUDA_GRAPH=1. Replay is currently
1577        // single-request-only on Blackwell + CUDA 12.8 (see
1578        // docs/phase-e-cuda-status.md). In pure eager mode, we skip the
1579        // per-step device-state memcpy_htod trio entirely.
1580        const GRAPH_WARMUP: usize = 3;
1581        let graph_enabled = std::env::var("FERRUM_CUDA_GRAPH").is_ok();
1582
1583        if graph_enabled {
1584            // Refresh device-side dynamic state (token/pos/kv_len) before
1585            // replay — captured graph reads these from device buffers.
1586            B::set_decode_state(&mut ctx, token, pos);
1587
1588            // Fast path: graph replay (if available).
1589            match B::replay_last_graph(&mut ctx) {
1590                Ok(true) => {
1591                    B::sync(&mut ctx);
1592                    return B::to_vec(&self.scratch.logits, vocab);
1593                }
1594                Ok(false) => { /* no graph yet, fall through to eager */ }
1595                Err(_) => { /* backend error or unsupported, eager */ }
1596            }
1597        }
1598
1599        let should_capture =
1600            graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
1601
1602        if should_capture {
1603            B::set_dev_state_mode(&mut ctx, true);
1604            if B::begin_graph_capture(&mut ctx).is_err() {
1605                self.graph_capture_failed = true;
1606                B::set_dev_state_mode(&mut ctx, false);
1607            }
1608        }
1609
1610        // Eager forward (records into graph if capture is active).
1611        // mem::replace needs a placeholder. B::alloc(0) was our choice but
1612        // cuMemAllocFromPoolAsync(stream, 0) can return CUDA_ERROR_INVALID_VALUE
1613        // on Blackwell after graph replay corrupts the pool state. Size-1 is
1614        // always valid and costs 2 bytes of transient VRAM per decode step.
1615        let mut residual = self
1616            .scratch
1617            .residual
1618            .take()
1619            .expect("scratch residual missing (previous call didn't restore)");
1620        let embed = self
1621            .embed
1622            .as_ref()
1623            .expect("decode_internal called on backbone-only model (no embed)");
1624        B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
1625
1626        // Per-layer wall-time profile (env-gated, off by default — adds
1627        // a B::sync between layers which serializes the pipeline). Helps
1628        // localise non-matmul bottlenecks during perf work.
1629        let layer_profile = std::env::var("FERRUM_DECODE_LAYER_PROFILE").is_ok();
1630        let mut layer_times = if layer_profile {
1631            Some(Vec::with_capacity(self.cfg.num_layers))
1632        } else {
1633            None
1634        };
1635
1636        for li in 0..self.cfg.num_layers {
1637            if layer_profile {
1638                let t0 = std::time::Instant::now();
1639                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1640                B::sync(&mut ctx);
1641                let elapsed_us = t0.elapsed().as_micros() as u64;
1642                if let Some(v) = layer_times.as_mut() {
1643                    v.push(elapsed_us);
1644                }
1645            } else {
1646                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1647            }
1648        }
1649        if let Some(times) = layer_times.take() {
1650            let sum: u64 = times.iter().sum();
1651            let avg = sum / times.len() as u64;
1652            let mn = *times.iter().min().unwrap_or(&0);
1653            let mx = *times.iter().max().unwrap_or(&0);
1654            eprintln!(
1655                "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
1656                times.len(),
1657                sum / 1000,
1658                avg,
1659                mn,
1660                mx
1661            );
1662            for (i, t) in times.iter().enumerate() {
1663                eprint!("L{i}={}ms ", t / 1000);
1664                if (i + 1) % 6 == 0 {
1665                    eprintln!();
1666                }
1667            }
1668            eprintln!();
1669            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1670            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1671            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1672            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1673            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1674            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1675            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1676            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1677            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1678            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1679            eprintln!(
1680                "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
1681                attn_n,
1682                attn_us / 1000,
1683                if attn_n > 0 { attn_us / attn_n } else { 0 }
1684            );
1685            eprintln!(
1686                "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
1687                qkr_n,
1688                qkr_us / 1000,
1689                if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
1690            );
1691            eprintln!(
1692                "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
1693                mm_n,
1694                mm_us / 1000,
1695                if mm_n > 0 { mm_us / mm_n } else { 0 }
1696            );
1697            eprintln!(
1698                "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
1699                norm_n,
1700                norm_us / 1000,
1701                if norm_n > 0 { norm_us / norm_n } else { 0 }
1702            );
1703            eprintln!(
1704                "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
1705                other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
1706            );
1707        }
1708
1709        B::rms_norm(
1710            &mut ctx,
1711            &residual,
1712            &self.final_norm_w,
1713            self.cfg.rms_norm_eps,
1714            &mut self.scratch.last_normed,
1715            1,
1716            h,
1717        );
1718
1719        let lm_head = self
1720            .lm_head
1721            .as_ref()
1722            .expect("decode_internal called on backbone-only model (no lm_head)");
1723        lm_head.forward(
1724            &mut ctx,
1725            &self.scratch.last_normed,
1726            &mut self.scratch.logits,
1727            1,
1728        );
1729
1730        if should_capture && !self.graph_capture_failed {
1731            if B::end_graph_capture(&mut ctx).is_err() {
1732                self.graph_capture_failed = true;
1733            } else {
1734                // Stream capture mode RECORDS ops into the graph without
1735                // executing them. scratch.logits still holds the previous
1736                // step's value. Replay the just-captured graph once to
1737                // actually execute and produce this step's logits. Without
1738                // this, the capture step's to_vec returns stale logits,
1739                // yielding a 1-token offset in the generated sequence.
1740                if B::replay_last_graph(&mut ctx).is_err() {
1741                    self.graph_capture_failed = true;
1742                }
1743            }
1744            B::set_dev_state_mode(&mut ctx, false);
1745        } else {
1746            self.graph_warmup += 1;
1747        }
1748
1749        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
1750        // buffer's CPU pointer without flushing the command buffer, so the
1751        // GPU must complete all pending work first or we read stale/random
1752        // data. CUDA's to_vec does an internal stream.synchronize, making
1753        // the call redundant there (~50µs/step cost), but correctness on
1754        // Metal requires the explicit flush here.
1755        B::sync(&mut ctx);
1756        self.scratch.residual = Some(residual);
1757
1758        B::to_vec(&self.scratch.logits, vocab)
1759    }
1760
1761    /// Prefill with pre-computed embeddings instead of token IDs.
1762    ///
1763    /// Used by models that embed inputs outside the LLM (e.g. Qwen3-TTS
1764    /// mixes text-embedding + codec-embedding before feeding the LM).
1765    /// Skips `final_norm` + `lm_head`; returns the last position's pre-norm
1766    /// hidden state. Caller applies its own output head.
1767    ///
1768    /// `embeds` is row-major `[seq_len * hidden_size]`, f32.
1769    pub fn prefill_from_embeds(
1770        &mut self,
1771        cache_id: &str,
1772        embeds: &[f32],
1773        seq_len: usize,
1774    ) -> Vec<f32> {
1775        let h = self.cfg.hidden_size;
1776        assert_eq!(
1777            embeds.len(),
1778            seq_len * h,
1779            "embeds length {} != seq_len * hidden_size {}",
1780            embeds.len(),
1781            seq_len * h
1782        );
1783        assert!(seq_len > 0, "prefill_from_embeds called with zero length");
1784
1785        self.ensure_scratch(seq_len);
1786        self.ensure_kv(cache_id);
1787
1788        let mut ctx = B::new_context();
1789        let mut residual = self
1790            .scratch
1791            .residual
1792            .take()
1793            .expect("scratch residual missing (previous call didn't restore)");
1794
1795        // Upload embeds → residual[0 .. seq_len*h].
1796        let embed_buf = B::from_slice(embeds);
1797        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1798
1799        for li in 0..self.cfg.num_layers {
1800            self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
1801        }
1802
1803        B::copy_slice(
1804            &mut ctx,
1805            &residual,
1806            (seq_len - 1) * h,
1807            &mut self.scratch.last_hidden,
1808            0,
1809            h,
1810        );
1811        B::sync(&mut ctx);
1812        self.scratch.residual = Some(residual);
1813        B::to_vec(&self.scratch.last_hidden, h)
1814    }
1815
1816    /// Decode with a single pre-computed embedding (shape `[hidden]`).
1817    /// Returns the pre-norm hidden state for the position `pos`. Caller
1818    /// applies final norm + its own output head.
1819    pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
1820        let h = self.cfg.hidden_size;
1821        assert_eq!(
1822            embed.len(),
1823            h,
1824            "embed length {} != hidden_size {}",
1825            embed.len(),
1826            h
1827        );
1828
1829        self.ensure_scratch(1);
1830        self.ensure_kv(cache_id);
1831
1832        let mut ctx = B::new_context();
1833        let mut residual = self
1834            .scratch
1835            .residual
1836            .take()
1837            .expect("scratch residual missing (previous call didn't restore)");
1838
1839        let embed_buf = B::from_slice(embed);
1840        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1841
1842        for li in 0..self.cfg.num_layers {
1843            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1844        }
1845
1846        B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
1847        B::sync(&mut ctx);
1848        self.scratch.residual = Some(residual);
1849        B::to_vec(&self.scratch.last_hidden, h)
1850    }
1851
1852    /// Variant of `prefill_from_embeds` that applies `final_norm` to every
1853    /// position and returns the whole `[seq_len * hidden_size]` vector.
1854    /// Accepts `pos_offset` so callers can continue an existing sequence
1855    /// (e.g. Qwen3-TTS voice-clone: one prefill for the role prefix, a
1856    /// follow-up prefill for the reference-audio ICL block, then
1857    /// autoregressive decoding — all against the same KV cache).
1858    ///
1859    /// Used by TTS where `forward_step` in the candle-based wrapper is
1860    /// expected to return **post-norm all-positions** hidden state so
1861    /// `codec_head` can be applied on candle side.
1862    pub fn prefill_all_post_norm(
1863        &mut self,
1864        cache_id: &str,
1865        embeds: &[f32],
1866        seq_len: usize,
1867        pos_offset: usize,
1868    ) -> Vec<f32> {
1869        let h = self.cfg.hidden_size;
1870        assert_eq!(
1871            embeds.len(),
1872            seq_len * h,
1873            "embeds length {} != seq_len * hidden_size {}",
1874            embeds.len(),
1875            seq_len * h
1876        );
1877        assert!(seq_len > 0);
1878
1879        self.ensure_scratch(seq_len);
1880        self.ensure_kv(cache_id);
1881
1882        let mut ctx = B::new_context();
1883        let mut residual = self
1884            .scratch
1885            .residual
1886            .take()
1887            .expect("scratch residual missing (previous call didn't restore)");
1888
1889        let embed_buf = B::from_slice(embeds);
1890        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1891
1892        for li in 0..self.cfg.num_layers {
1893            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1894        }
1895
1896        // Apply final_norm over all seq_len positions → scratch.norm_out.
1897        B::rms_norm(
1898            &mut ctx,
1899            &residual,
1900            &self.final_norm_w,
1901            self.cfg.rms_norm_eps,
1902            &mut self.scratch.norm_out,
1903            seq_len,
1904            h,
1905        );
1906        B::sync(&mut ctx);
1907        self.scratch.residual = Some(residual);
1908        B::to_vec(&self.scratch.norm_out, seq_len * h)
1909    }
1910
1911    /// Decode-side companion to `prefill_all_post_norm`. Runs a single-token
1912    /// decode step at `pos`, applies `final_norm`, and returns the post-norm
1913    /// hidden state `[hidden_size]`.
1914    pub fn decode_post_norm_from_embed(
1915        &mut self,
1916        cache_id: &str,
1917        embed: &[f32],
1918        pos: u32,
1919    ) -> Vec<f32> {
1920        let h = self.cfg.hidden_size;
1921        assert_eq!(embed.len(), h);
1922
1923        self.ensure_scratch(1);
1924        self.ensure_kv(cache_id);
1925
1926        let mut ctx = B::new_context();
1927        let mut residual = self
1928            .scratch
1929            .residual
1930            .take()
1931            .expect("scratch residual missing (previous call didn't restore)");
1932
1933        let embed_buf = B::from_slice(embed);
1934        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1935
1936        for li in 0..self.cfg.num_layers {
1937            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1938        }
1939
1940        B::rms_norm(
1941            &mut ctx,
1942            &residual,
1943            &self.final_norm_w,
1944            self.cfg.rms_norm_eps,
1945            &mut self.scratch.last_normed,
1946            1,
1947            h,
1948        );
1949        B::sync(&mut ctx);
1950        self.scratch.residual = Some(residual);
1951        B::to_vec(&self.scratch.last_normed, h)
1952    }
1953
1954    /// Batched decode: process M concurrent requests at potentially different
1955    /// positions in one forward pass. GEMM-heavy ops (qkv_proj, o_proj,
1956    /// gate_up, down) run with m=M for natural batching; rope + KV append +
1957    /// attention loop per-item (each has its own KV cache at a different
1958    /// kv_len, and potentially different pos).
1959    ///
1960    /// Returns M logit vectors in the same order as `batch`.
1961    pub fn decode_batch_internal(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1962        let m = batch.len();
1963        if m == 0 {
1964            return Vec::new();
1965        }
1966        if m == 1 {
1967            let (cid, tok, pos) = &batch[0];
1968            return vec![self.decode_internal(cid, *tok, *pos)];
1969        }
1970
1971        // Ensure all caches exist and scratch is sized for M tokens.
1972        for (cid, _, _) in batch {
1973            self.ensure_kv(cid);
1974        }
1975        self.ensure_scratch(m);
1976        // Phase 4b: when paged mode is on, ensure_kv has already
1977        // populated the batched scratch buffers (paged_batch_q etc.).
1978        // The forward path branches on `paged_pools.is_some()` inside
1979        // each layer.
1980
1981        let h = self.cfg.hidden_size;
1982        let vocab = self.cfg.vocab_size;
1983        let mut ctx = B::new_context();
1984
1985        // 0. Embed all M tokens into residual [M, H]
1986        let tokens: Vec<u32> = batch.iter().map(|(_, t, _)| *t).collect();
1987        let mut residual = self
1988            .scratch
1989            .residual
1990            .take()
1991            .expect("scratch residual missing (previous call didn't restore)");
1992        let embed = self
1993            .embed
1994            .as_ref()
1995            .expect("decode_batch_internal called on backbone-only model (no embed)");
1996        B::embedding_lookup(&mut ctx, embed, &tokens, &mut residual, h);
1997
1998        // 1..num_layers: batched forward for each layer
1999        for li in 0..self.cfg.num_layers {
2000            self.forward_layer_batched_decode(&mut ctx, li, batch, &mut residual, m);
2001        }
2002
2003        // Final RMSNorm on [M, H] → norm_out [M, H]
2004        B::rms_norm(
2005            &mut ctx,
2006            &residual,
2007            &self.final_norm_w,
2008            self.cfg.rms_norm_eps,
2009            &mut self.scratch.norm_out,
2010            m,
2011            h,
2012        );
2013
2014        // LM head with m=M → batch_logits [M, vocab]
2015        let lm_head = self
2016            .lm_head
2017            .as_ref()
2018            .expect("decode_batch_internal called on backbone-only model (no lm_head)");
2019        lm_head.forward(
2020            &mut ctx,
2021            &self.scratch.norm_out,
2022            &mut self.scratch.batch_logits,
2023            m,
2024        );
2025
2026        // Sync before to_vec (Metal: no internal sync on buffer read).
2027        B::sync(&mut ctx);
2028        self.scratch.residual = Some(residual);
2029
2030        // Extract M logit vectors from the flat buffer.
2031        let all = B::to_vec(&self.scratch.batch_logits, m * vocab);
2032        (0..m)
2033            .map(|i| all[i * vocab..(i + 1) * vocab].to_vec())
2034            .collect()
2035    }
2036
2037    /// One transformer layer over M items, GEMMs batched + per-item attention.
2038    fn forward_layer_batched_decode(
2039        &mut self,
2040        ctx: &mut B::Context,
2041        li: usize,
2042        batch: &[(String, u32, u32)],
2043        residual: &mut B::Buffer,
2044        m: usize,
2045    ) {
2046        let cfg = &self.cfg;
2047        let h = cfg.hidden_size;
2048        let nh = cfg.num_heads;
2049        let nkv = cfg.num_kv_heads;
2050        let hd = cfg.head_dim;
2051        let im = cfg.intermediate_size;
2052        let eps = cfg.rms_norm_eps;
2053        let q_dim = nh * hd;
2054        let kv_dim = nkv * hd;
2055
2056        let layer = &self.layers[li];
2057        let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
2058        let dummy_w = &layer.input_ln_w;
2059        let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy_w);
2060        let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy_w);
2061
2062        // 1. rms_norm [M, H]  → norm_out
2063        B::rms_norm(
2064            ctx,
2065            residual,
2066            &layer.input_ln_w,
2067            eps,
2068            &mut self.scratch.norm_out,
2069            m,
2070            h,
2071        );
2072
2073        // 2. qkv_proj (GEMM m=M): norm_out [M, H] → qkv_out [M, QKV]
2074        layer
2075            .qkv_proj
2076            .forward(ctx, &self.scratch.norm_out, &mut self.scratch.qkv_out, m);
2077
2078        // ── Paged-KV batched path (Phase 4b) ──────────────────────────
2079        // When paged is on, we skip the contig split_qkv + per-item
2080        // qk_norm_rope + kv_append + flash_attention loop entirely.
2081        // Instead:
2082        //   1. Per item: split_qkv_norm_rope_into_paged_cache with
2083        //      qkv_byte_offset = i * qkv_stride * 4 reads item i's
2084        //      slice of qkv_out, writes K/V into the shared pool at
2085        //      its block_table-resolved position, and stores the
2086        //      RoPE'd Q at paged_batch_q[i * q_dim .. (i+1) * q_dim].
2087        //   2. Build batched block_tables [M, max_blocks_per_seq] +
2088        //      context_lens [M] host-side, write to scratch device
2089        //      buffers.
2090        //   3. Single paged_decode_attention(num_seqs=M) reads all M
2091        //      seqs' K/V via per-seq block_tables, writes to
2092        //      paged_batch_o.
2093        //   4. Per item: copy paged_batch_o[i] → attn_flat[i * q_dim].
2094        //
2095        // This is the "real" multi-seq decode — one heavy attention
2096        // dispatch covering all sequences instead of M sequential ones.
2097        if let Some(pools) = self.paged_pools.as_mut() {
2098            let pool_ptr = (
2099                &mut pools[li].0 as *mut B::Buffer,
2100                &mut pools[li].1 as *mut B::Buffer,
2101            );
2102            // SAFETY: pools allocated once; not concurrently mutated.
2103            let (pool_k, pool_v) = unsafe { (&mut *pool_ptr.0, &mut *pool_ptr.1) };
2104
2105            let qkv_stride = q_dim + 2 * kv_dim;
2106            let max_blocks_per_seq = self.scratch.paged_max_blocks_per_seq;
2107            let block_size = 16; // matches PAGED_BLOCK_SIZE in ensure_kv
2108
2109            // Step 1: per-item paged write. We collect cache_len + block_indices
2110            // up front for step 2. Note: this loop borrows self.kv_caches mutably
2111            // per iteration, so we extract the batched-write parameters first then
2112            // do the dispatches.
2113            let mut item_state: Vec<(u32, Vec<u32>)> = Vec::with_capacity(m);
2114            for (cache_id, _, _) in batch.iter() {
2115                let caches = self
2116                    .kv_caches
2117                    .get(cache_id)
2118                    .expect("ensure_kv must be called before forward_layer_batched");
2119                let cache = &caches[li];
2120                item_state.push((cache.len as u32, cache.paged_block_indices.clone()));
2121            }
2122
2123            // Take block_table buffer ptrs ahead of the dispatch loop —
2124            // we need both per-cache block_table (to write into) and
2125            // self.scratch.paged_batch_q (to write Q stacks into).
2126            let q_head_major_size_bytes = (q_dim * std::mem::size_of::<f32>()) as u64;
2127            let qkv_stride_bytes = (qkv_stride * std::mem::size_of::<f32>()) as u64;
2128            for (i, (cache_id, _, pos)) in batch.iter().enumerate() {
2129                let pos_i = *pos as usize;
2130                let caches = self
2131                    .kv_caches
2132                    .get(cache_id)
2133                    .expect("paged batched: cache not present");
2134                let cache = &caches[li];
2135                let bt = cache
2136                    .block_table
2137                    .as_ref()
2138                    .expect("paged batched: block_table missing");
2139                let cache_len_before = cache.len;
2140                let block_table_ref = bt as *const B::Buffer;
2141                // SAFETY: bt is read-only in the dispatch; we don't
2142                // mutate self.kv_caches between this raw deref and the
2143                // call.
2144                let bt_safe: &B::Buffer = unsafe { &*block_table_ref };
2145                B::split_qkv_norm_rope_into_paged_cache(
2146                    ctx,
2147                    &self.scratch.qkv_out,
2148                    (i as u64) * qkv_stride_bytes,
2149                    q_norm_w,
2150                    k_norm_w,
2151                    &self.rope.cos,
2152                    &self.rope.sin,
2153                    self.scratch
2154                        .paged_batch_q
2155                        .as_mut()
2156                        .expect("paged_batch_q missing"),
2157                    (i as u64) * q_head_major_size_bytes,
2158                    pool_k,
2159                    pool_v,
2160                    bt_safe,
2161                    1,
2162                    nh,
2163                    nkv,
2164                    hd,
2165                    pos_i,
2166                    eps,
2167                    qk_mode,
2168                    cache_len_before,
2169                    block_size,
2170                    max_blocks_per_seq,
2171                )
2172                .expect("paged batched write");
2173            }
2174
2175            // Step 2: bump cache.len and build the stacked block_tables +
2176            // context_lens host-side, then upload to device scratch.
2177            let mut stacked_bt: Vec<u32> = vec![0u32; m * max_blocks_per_seq];
2178            let mut stacked_cl: Vec<u32> = vec![0u32; m];
2179            for (i, (cache_id, _, _)) in batch.iter().enumerate() {
2180                let caches = self
2181                    .kv_caches
2182                    .get_mut(cache_id)
2183                    .expect("paged batched: cache not present");
2184                let cache = &mut caches[li];
2185                cache.len += 1;
2186                let len = cache.len as u32;
2187                stacked_cl[i] = len;
2188                let blocks = &cache.paged_block_indices;
2189                let n_to_copy = blocks.len().min(max_blocks_per_seq);
2190                stacked_bt[i * max_blocks_per_seq..i * max_blocks_per_seq + n_to_copy]
2191                    .copy_from_slice(&blocks[..n_to_copy]);
2192            }
2193            let bt_buf = self
2194                .scratch
2195                .paged_batch_block_tables
2196                .as_mut()
2197                .expect("paged_batch_block_tables missing");
2198            B::write_u32(ctx, bt_buf, &stacked_bt);
2199            let cl_buf = self
2200                .scratch
2201                .paged_batch_context_lens
2202                .as_mut()
2203                .expect("paged_batch_context_lens missing");
2204            B::write_u32(ctx, cl_buf, &stacked_cl);
2205
2206            // Step 3: one batched paged_decode_attention(num_seqs=m).
2207            let bt_ptr =
2208                self.scratch.paged_batch_block_tables.as_ref().unwrap() as *const B::Buffer;
2209            let cl_ptr =
2210                self.scratch.paged_batch_context_lens.as_ref().unwrap() as *const B::Buffer;
2211            let q_ptr = self.scratch.paged_batch_q.as_ref().unwrap() as *const B::Buffer;
2212            let o_ptr = self.scratch.paged_batch_o.as_mut().unwrap() as *mut B::Buffer;
2213            // SAFETY: the four scratch buffers above are not aliased
2214            // by anything else; we only deref while &mut self is held.
2215            let bt_safe = unsafe { &*bt_ptr };
2216            let cl_safe = unsafe { &*cl_ptr };
2217            let q_safe = unsafe { &*q_ptr };
2218            let o_safe = unsafe { &mut *o_ptr };
2219            B::paged_decode_attention(
2220                ctx,
2221                q_safe,
2222                pool_k,
2223                pool_v,
2224                o_safe,
2225                bt_safe,
2226                cl_safe,
2227                m,
2228                nh,
2229                nkv,
2230                hd,
2231                block_size,
2232                max_blocks_per_seq,
2233                1, // q_len
2234            )
2235            .expect("paged batched decode");
2236
2237            // Step 4: per-item copy paged_batch_o[i] → attn_flat[i * q_dim].
2238            // Both have q_dim floats per item; same head-major-equals-token-major
2239            // identity collapse used in the contig path.
2240            for i in 0..m {
2241                B::copy_slice(
2242                    ctx,
2243                    self.scratch.paged_batch_o.as_ref().unwrap(),
2244                    i * q_dim,
2245                    &mut self.scratch.attn_flat,
2246                    i * q_dim,
2247                    q_dim,
2248                );
2249            }
2250
2251            // Skip the contig split_qkv + per-item loop below.
2252            return self.forward_layer_batched_decode_post_attn(ctx, li, residual, m);
2253        }
2254
2255        // 3. split_qkv [M, QKV] → q_buf [M, Q], k_buf [M, KV], v_buf [M, KV]
2256        B::split_qkv(
2257            ctx,
2258            &self.scratch.qkv_out,
2259            &mut self.scratch.q_buf,
2260            &mut self.scratch.k_buf,
2261            &mut self.scratch.v_buf,
2262            m,
2263            q_dim,
2264            kv_dim,
2265        );
2266
2267        // 4-6. Per-item loop for rope + kv_append + attention.
2268        //      Each item has its own cache_id + pos + kv_len.
2269        for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
2270            let pos_i = *pos as usize;
2271
2272            // Extract item i's Q/K/V from batched buffers.
2273            B::copy_slice(
2274                ctx,
2275                &self.scratch.q_buf,
2276                i * q_dim,
2277                &mut self.scratch.q_single,
2278                0,
2279                q_dim,
2280            );
2281            B::copy_slice(
2282                ctx,
2283                &self.scratch.k_buf,
2284                i * kv_dim,
2285                &mut self.scratch.k_single,
2286                0,
2287                kv_dim,
2288            );
2289            B::copy_slice(
2290                ctx,
2291                &self.scratch.v_buf,
2292                i * kv_dim,
2293                &mut self.scratch.v_single,
2294                0,
2295                kv_dim,
2296            );
2297
2298            // qk_norm_rope with tokens=1, per-item pos.
2299            B::qk_norm_rope(
2300                ctx,
2301                &self.scratch.q_single,
2302                q_norm_w,
2303                &self.rope.cos,
2304                &self.rope.sin,
2305                &mut self.scratch.q_head_major_single,
2306                1,
2307                nh,
2308                hd,
2309                pos_i,
2310                eps,
2311                qk_mode,
2312            );
2313            B::qk_norm_rope(
2314                ctx,
2315                &self.scratch.k_single,
2316                k_norm_w,
2317                &self.rope.cos,
2318                &self.rope.sin,
2319                &mut self.scratch.k_head_major_single,
2320                1,
2321                nkv,
2322                hd,
2323                pos_i,
2324                eps,
2325                qk_mode,
2326            );
2327            B::qk_norm_rope(
2328                ctx,
2329                &self.scratch.v_single,
2330                dummy_w,
2331                &self.rope.cos,
2332                &self.rope.sin,
2333                &mut self.scratch.v_head_major_single,
2334                1,
2335                nkv,
2336                hd,
2337                pos_i,
2338                eps,
2339                0,
2340            );
2341
2342            // KV append + attention for item i's cache.
2343            let caches = self
2344                .kv_caches
2345                .get_mut(cache_id)
2346                .expect("ensure_kv must be called before forward_layer_batched");
2347            let cache = &mut caches[li];
2348            B::kv_cache_append_head_major(
2349                ctx,
2350                &mut cache.k,
2351                &mut cache.v,
2352                cache.len,
2353                cache.capacity,
2354                &self.scratch.k_head_major_single,
2355                &self.scratch.v_head_major_single,
2356                1,
2357                nkv,
2358                hd,
2359            );
2360            cache.len += 1;
2361            let kv_len = cache.len;
2362            let kv_stride = cache.capacity;
2363
2364            let attn_cfg = ferrum_kernels::backend::AttnConfig {
2365                num_heads: nh,
2366                num_kv_heads: nkv,
2367                head_dim: hd,
2368                causal: true,
2369                scale: 1.0 / (hd as f32).sqrt(),
2370                kv_seq_stride: kv_stride,
2371                sliding_window: cfg.sliding_window,
2372            };
2373            B::flash_attention(
2374                ctx,
2375                &self.scratch.q_head_major_single,
2376                &cache.k,
2377                &cache.v,
2378                &mut self.scratch.attn_head_major_single,
2379                1,
2380                1,
2381                kv_len,
2382                pos_i,
2383                &attn_cfg,
2384            );
2385
2386            // Untranspose head-major → token-major + inject into batched
2387            // attn_flat[M, Q]. For tokens=1 the head-major and
2388            // token-major layouts are byte-identical (both flat to
2389            // [heads * head_dim] = [q_dim] floats), so we skip the
2390            // transpose dispatch entirely and copy attn_head_major_single
2391            // straight into the per-item slot. Saves 1 dispatch per
2392            // batch-item per layer.
2393            B::copy_slice(
2394                ctx,
2395                &self.scratch.attn_head_major_single,
2396                0,
2397                &mut self.scratch.attn_flat,
2398                i * q_dim,
2399                q_dim,
2400            );
2401        }
2402
2403        self.forward_layer_batched_decode_post_attn(ctx, li, residual, m);
2404    }
2405
2406    fn forward_layer_batched_decode_post_attn(
2407        &mut self,
2408        ctx: &mut B::Context,
2409        li: usize,
2410        residual: &mut B::Buffer,
2411        m: usize,
2412    ) {
2413        let cfg = &self.cfg;
2414        let h = cfg.hidden_size;
2415        let im = cfg.intermediate_size;
2416        let eps = cfg.rms_norm_eps;
2417        let layer = &self.layers[li];
2418
2419        // 7. o_proj (GEMM m=M): attn_flat [M, Q] → o_proj_out [M, H]
2420        layer.o_proj.forward(
2421            ctx,
2422            &self.scratch.attn_flat,
2423            &mut self.scratch.o_proj_out,
2424            m,
2425        );
2426
2427        // 8. Fused residual add + post-attention RMSNorm.
2428        B::fused_add_rms_norm(
2429            ctx,
2430            residual,
2431            &self.scratch.o_proj_out,
2432            &layer.post_ln_w,
2433            eps,
2434            &mut self.scratch.norm_out,
2435            m,
2436            h,
2437        );
2438
2439        // 9. gate_up_proj (GEMM m=M)
2440        layer.gate_up_proj.forward(
2441            ctx,
2442            &self.scratch.norm_out,
2443            &mut self.scratch.gate_up_out,
2444            m,
2445        );
2446
2447        // 10. SwiGLU
2448        B::fused_silu_mul_split(
2449            ctx,
2450            &self.scratch.gate_up_out,
2451            &mut self.scratch.silu_out,
2452            m,
2453            im,
2454        );
2455
2456        // 11. down_proj (GEMM m=M)
2457        layer
2458            .down_proj
2459            .forward(ctx, &self.scratch.silu_out, &mut self.scratch.mlp_out, m);
2460
2461        // 12. Residual add
2462        B::add_inplace(ctx, residual, &self.scratch.mlp_out, m * h);
2463    }
2464}
2465
2466impl<B: Backend> DecoderOnlyLLM for LlamaFamilyModel<B> {
2467    fn config(&self) -> &LlmRuntimeConfig {
2468        &self.runtime_cfg
2469    }
2470
2471    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2472        // Eager scratch + KV cache grow + a 1-token forward warmup —
2473        // see the Qwen3MoeModel::prepare comment for the rationale.
2474        // Without the warmup forward, the first real prefill pays
2475        // Metal pipeline first-bind costs inside the timer window.
2476        self.ensure_scratch(max_tokens);
2477        self.ensure_kv(cache_id);
2478
2479        const WARMUP_CACHE: &str = "__ferrum_warmup__";
2480        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2481        // Release via the same path as `release` so paged blocks
2482        // return to the shared allocator. Otherwise warmup leaks
2483        // 256 blocks (the full per-seq quota) into the pool.
2484        if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2485            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2486                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2487                if let Some(c0) = caches.first() {
2488                    if !c0.paged_block_indices.is_empty() {
2489                        alloc.free(&c0.paged_block_indices);
2490                    }
2491                }
2492                for c in caches.iter_mut() {
2493                    c.paged_block_indices.clear();
2494                }
2495            }
2496            self.kv_free_pool.push(caches);
2497        }
2498    }
2499
2500    fn kv_capacity(&self) -> usize {
2501        // Mirror the bound `ensure_kv` will use when allocating the cache.
2502        let model_max = self.cfg.max_seq_len;
2503        const DEFAULT_KV_CAPACITY: usize = 512;
2504        std::env::var("FERRUM_KV_CAPACITY")
2505            .ok()
2506            .and_then(|s| s.parse::<usize>().ok())
2507            .map(|cap| cap.min(model_max))
2508            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
2509    }
2510
2511    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2512        self.prefill_internal(cache_id, tokens)
2513    }
2514
2515    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2516        self.decode_internal(cache_id, token, pos)
2517    }
2518
2519    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2520        self.decode_batch_internal(batch)
2521    }
2522
2523    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2524        // Delegate to the inherent implementation on LlamaFamilyModel.
2525        LlamaFamilyModel::<B>::forward_verify(self, cache_id, tokens)
2526    }
2527
2528    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2529        if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2530            for c in caches.iter_mut() {
2531                if new_len < c.len {
2532                    c.len = new_len;
2533                }
2534            }
2535        }
2536        // Captured graph expects a specific cache layout; roll it back too.
2537        let mut ctx = B::new_context();
2538        B::reset_graph(&mut ctx);
2539        self.graph_warmup = 0;
2540        self.graph_capture_failed = false;
2541    }
2542
2543    fn release(&mut self, cache_id: &str) {
2544        // Sync + drop graph BEFORE touching cache buffers. The graph was
2545        // actively running replays up to this point; destroying the graph
2546        // while the allocator pool still has in-flight references from the
2547        // graph's kernels corrupts stream state. Sync first to drain, then
2548        // destroy graph, then sync again to ensure cleanup completes.
2549        let mut ctx = B::new_context();
2550        B::sync(&mut ctx);
2551        B::reset_graph(&mut ctx);
2552        B::sync(&mut ctx);
2553        self.graph_warmup = 0;
2554        self.graph_capture_failed = false;
2555
2556        // Return the cache's buffers to the free pool instead of dropping.
2557        // Pointers stay stable for the next request's captured graph.
2558        // Paged mode: also free the cache's blocks back to the shared
2559        // allocator so other sequences can reuse them.
2560        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2561            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2562                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2563                // All caches share the same block_indices set (one per
2564                // cache_id), so freeing once via the first layer's cache
2565                // is enough.
2566                if let Some(c0) = caches.first() {
2567                    if !c0.paged_block_indices.is_empty() {
2568                        alloc.free(&c0.paged_block_indices);
2569                    }
2570                }
2571                // Clear the host-side mirror on every layer so a
2572                // free-pool reuse re-allocates fresh blocks.
2573                for c in caches.iter_mut() {
2574                    c.paged_block_indices.clear();
2575                }
2576            }
2577            self.kv_free_pool.push(caches);
2578        }
2579    }
2580
2581    fn reset(&mut self) {
2582        // Hard reset: drop all caches AND the pool, invalidate graph.
2583        let mut ctx = B::new_context();
2584        B::sync(&mut ctx);
2585        B::reset_graph(&mut ctx);
2586        B::sync(&mut ctx);
2587        self.graph_warmup = 0;
2588        self.graph_capture_failed = false;
2589        self.kv_caches.clear();
2590        self.kv_free_pool.clear();
2591    }
2592}
2593
2594fn build_rope_cache<B: Backend>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
2595    let hd = cfg.head_dim;
2596    let half = hd / 2;
2597    let max = cfg.max_seq_len;
2598    let mut cos = vec![0.0f32; max * half];
2599    let mut sin = vec![0.0f32; max * half];
2600    for pos in 0..max {
2601        for i in 0..half {
2602            let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
2603            let angle = pos as f64 * freq;
2604            cos[pos * half + i] = angle.cos() as f32;
2605            sin[pos * half + i] = angle.sin() as f32;
2606        }
2607    }
2608    RopeCache {
2609        cos: B::from_slice(&cos),
2610        sin: B::from_slice(&sin),
2611    }
2612}