Skip to main content

ferrum_models/models/
llama_family.rs

1//! Llama-family decoder model as explicit code.
2//!
3//! Covers all "standard Llama-style GQA + SwiGLU + RoPE" decoders:
4//!   - Llama / Llama-2 / Llama-3  (no QK-norm)
5//!   - Qwen2 / Qwen2.5            (no QK-norm, structurally Llama)
6//!   - Qwen3                      (QK-norm per head, larger rope_theta)
7//!   - Mistral                    (sliding-window attention — not yet
8//!                                 supported on the forward path; will
9//!                                 require an `AttnConfig.sliding_window`
10//!                                 field + shader support in Phase D)
11//!
12//! Variant differences are controlled by `LlamaFamilyConfig::has_qk_norm`
13//! and RoPE theta. Weight loading accepts both fused (`qkv_proj`,
14//! `gate_up_proj`) and split (`q_proj`+`k_proj`+`v_proj`,
15//! `gate_proj`+`up_proj`) projection layouts — the loader (e.g.
16//! `CandleVarBuilderLoader`) fuses split weights on load so model code
17//! sees a uniform `qkv_proj` / `gate_up_proj` Linear.
18
19use std::collections::HashMap;
20use std::sync::{atomic::AtomicU64, OnceLock};
21
22use ferrum_interfaces::kv_dtype::{KvFp16, KvInt8};
23use ferrum_kernels::backend::{
24    Backend, BackendGraph, BackendInt8KvOps, BackendMoeFused, BackendPagedKv, BackendQuantGguf,
25    BackendQuantMarlin, KvCache, KvLayer, LlmBackend, MoeLlmBackend, QuantLlmBackend,
26    MAX_LAYERS_FOR_GRAPH,
27};
28
29/// Graph cache key for the single-item decode path (`decode_internal`).
30/// Distinct from any `m_padded`-based key used by the batched path.
31pub(crate) const SINGLE_ITEM_GRAPH_KEY: u64 = 0;
32
33/// Diag counters for the batched graph dispatcher (replay vs eager).
34pub(crate) static BATCHED_GRAPH_REPLAY_COUNT: AtomicU64 = AtomicU64::new(0);
35pub(crate) static BATCHED_GRAPH_EAGER_COUNT: AtomicU64 = AtomicU64::new(0);
36
37pub(crate) static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
38pub(crate) static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
39pub(crate) static QKR_TIME_US: AtomicU64 = AtomicU64::new(0);
40pub(crate) static QKR_CALLS: AtomicU64 = AtomicU64::new(0);
41pub(crate) static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0);
42pub(crate) static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0);
43pub(crate) static NORM_TIME_US: AtomicU64 = AtomicU64::new(0);
44pub(crate) static NORM_CALLS: AtomicU64 = AtomicU64::new(0);
45pub(crate) static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0);
46pub(crate) static OTHER_CALLS: AtomicU64 = AtomicU64::new(0);
47use ferrum_quantization::{Linear, WeightLoader};
48use ferrum_types::Result;
49
50use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
51
52const DEFAULT_KV_CAPACITY: usize = 512;
53
54#[derive(Debug, Clone, PartialEq, Eq)]
55struct LlamaFamilyRuntimeEnv {
56    kv_capacity: Option<usize>,
57    metal_paged_kv: Option<bool>,
58    paged_max_seqs: usize,
59    decode_op_profile: bool,
60    prefill_op_profile: bool,
61    cuda_graph: bool,
62    decode_layer_profile: bool,
63}
64
65impl LlamaFamilyRuntimeEnv {
66    fn from_env() -> Self {
67        Self::from_env_vars(std::env::vars())
68    }
69
70    fn from_env_vars<I, K, V>(vars: I) -> Self
71    where
72        I: IntoIterator<Item = (K, V)>,
73        K: AsRef<str>,
74        V: AsRef<str>,
75    {
76        let mut config = Self {
77            kv_capacity: None,
78            metal_paged_kv: None,
79            paged_max_seqs: 32,
80            decode_op_profile: false,
81            prefill_op_profile: false,
82            cuda_graph: false,
83            decode_layer_profile: false,
84        };
85        for (name, value) in vars {
86            let value = value.as_ref();
87            match name.as_ref() {
88                "FERRUM_KV_CAPACITY" => config.kv_capacity = value.parse::<usize>().ok(),
89                "FERRUM_METAL_PAGED_KV" => config.metal_paged_kv = Some(value != "0"),
90                "FERRUM_PAGED_MAX_SEQS" => {
91                    if let Ok(max_seqs) = value.parse::<usize>() {
92                        config.paged_max_seqs = max_seqs;
93                    }
94                }
95                "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
96                "FERRUM_PREFILL_OP_PROFILE" => config.prefill_op_profile = true,
97                "FERRUM_CUDA_GRAPH" => config.cuda_graph = true,
98                "FERRUM_DECODE_LAYER_PROFILE" => config.decode_layer_profile = true,
99                _ => {}
100            }
101        }
102        config
103    }
104
105    fn kv_capacity_for_model(&self, model_max: usize) -> usize {
106        self.kv_capacity
107            .map(|cap| cap.min(model_max))
108            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
109    }
110
111    fn paged_kv_enabled<B: BackendPagedKv>(&self) -> bool {
112        self.metal_paged_kv
113            .unwrap_or_else(|| B::supports_paged_kv())
114    }
115}
116
117fn llama_family_runtime_env() -> &'static LlamaFamilyRuntimeEnv {
118    static CONFIG: OnceLock<LlamaFamilyRuntimeEnv> = OnceLock::new();
119    CONFIG.get_or_init(LlamaFamilyRuntimeEnv::from_env)
120}
121
122/// Full Qwen3 architecture config (everything the model code needs, not just
123/// the engine-facing subset in `LlmRuntimeConfig`).
124#[derive(Clone, Debug, PartialEq)]
125pub struct LlamaFamilyConfig {
126    pub hidden_size: usize,
127    pub intermediate_size: usize,
128    pub num_heads: usize,
129    pub num_kv_heads: usize,
130    pub head_dim: usize,
131    pub num_layers: usize,
132    pub vocab_size: usize,
133    pub max_seq_len: usize,
134    pub rms_norm_eps: f32,
135    pub rope_theta: f64,
136    /// Whether the checkpoint has `q_norm` / `k_norm` per layer. All known
137    /// Qwen3 checkpoints do; some derivatives may strip them.
138    pub has_qk_norm: bool,
139    /// Sliding-window attention size. `0` disables (full causal).
140    /// Mistral v0.1 sets 4096; Mistral v0.2+ removed the limitation (0).
141    pub sliding_window: usize,
142}
143
144impl LlamaFamilyConfig {
145    pub fn to_runtime(&self) -> LlmRuntimeConfig {
146        LlmRuntimeConfig {
147            hidden_size: self.hidden_size,
148            num_layers: self.num_layers,
149            num_kv_heads: self.num_kv_heads,
150            head_dim: self.head_dim,
151            vocab_size: self.vocab_size,
152            max_seq_len: self.max_seq_len,
153        }
154    }
155
156    /// Build config from a `ModelDefinition`, shared field extraction.
157    /// Variant-specific constructors below set `has_qk_norm` and fall back
158    /// to different `rope_theta` defaults when the checkpoint doesn't set one.
159    fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
160        let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
161        let head_dim = def
162            .extra_params
163            .get("head_dim")
164            .and_then(|v| v.as_u64())
165            .map(|v| v as usize)
166            .unwrap_or(def.hidden_size / def.num_attention_heads);
167        // Mistral / Gemma: "sliding_window" may be null (v0.2+) or a positive
168        // integer (v0.1). Non-null value passes through; missing/null → 0.
169        let sliding_window = def
170            .extra_params
171            .get("sliding_window")
172            .and_then(|v| v.as_u64())
173            .map(|v| v as usize)
174            .unwrap_or(0);
175
176        LlamaFamilyConfigBase {
177            hidden_size: def.hidden_size,
178            intermediate_size: def.intermediate_size,
179            num_heads: def.num_attention_heads,
180            num_kv_heads,
181            head_dim,
182            num_layers: def.num_hidden_layers,
183            vocab_size: def.vocab_size,
184            max_seq_len: def.max_position_embeddings,
185            rms_norm_eps: def.norm_eps as f32,
186            rope_theta_opt: def.rope_theta,
187            sliding_window,
188        }
189    }
190
191    fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
192        Self {
193            hidden_size: b.hidden_size,
194            intermediate_size: b.intermediate_size,
195            num_heads: b.num_heads,
196            num_kv_heads: b.num_kv_heads,
197            head_dim: b.head_dim,
198            num_layers: b.num_layers,
199            vocab_size: b.vocab_size,
200            max_seq_len: b.max_seq_len,
201            rms_norm_eps: b.rms_norm_eps,
202            rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
203            has_qk_norm,
204            sliding_window: b.sliding_window,
205        }
206    }
207
208    /// Qwen3: QK-norm on, rope_theta default 1e6.
209    pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
210        Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
211    }
212
213    /// Llama / Llama-2 / Llama-3: no QK-norm; rope_theta varies by version
214    /// (10k for Llama-2, 500k for Llama-3.1+) — use the checkpoint value or
215    /// fall back to the most common modern value.
216    pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
217        Self::from_base(Self::from_def_base(def), 500_000.0, false)
218    }
219
220    /// Qwen2 / Qwen2.5: structurally Llama; no QK-norm; rope_theta default 1e6.
221    pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
222        Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
223    }
224
225    /// Mistral: no QK-norm; `rope_theta` commonly 10_000 (v0.1 / v0.2).
226    /// Picks up `sliding_window` from the checkpoint's config.json
227    /// (Mistral v0.1: 4096; Mistral v0.2+: 0 / null).
228    pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
229        Self::from_base(Self::from_def_base(def), 10_000.0, false)
230    }
231}
232
233struct LlamaFamilyConfigBase {
234    hidden_size: usize,
235    intermediate_size: usize,
236    num_heads: usize,
237    num_kv_heads: usize,
238    head_dim: usize,
239    num_layers: usize,
240    vocab_size: usize,
241    max_seq_len: usize,
242    rms_norm_eps: f32,
243    rope_theta_opt: Option<f64>,
244    sliding_window: usize,
245}
246
247/// Per-layer weights. `Box<dyn Linear<B>>` means each projection can be
248/// Dense / GPTQ / AWQ / GGUF without the surrounding code caring.
249pub struct LlamaFamilyLayer<B: QuantLlmBackend + BackendMoeFused> {
250    pub input_ln_w: B::Buffer,
251    pub qkv_proj: Box<dyn Linear<B>>,
252    /// QK-norm weight per head: `[head_dim]`. Optional for non-Qwen3 derivatives.
253    pub q_norm_w: Option<B::Buffer>,
254    pub k_norm_w: Option<B::Buffer>,
255    pub o_proj: Box<dyn Linear<B>>,
256    pub post_ln_w: B::Buffer,
257    pub gate_up_proj: Box<dyn Linear<B>>,
258    pub down_proj: Box<dyn Linear<B>>,
259}
260
261/// Precomputed RoPE cos/sin tables (shape `[max_seq, head_dim / 2]` each).
262pub struct RopeCache<B: QuantLlmBackend + BackendMoeFused> {
263    pub cos: B::Buffer,
264    pub sin: B::Buffer,
265}
266
267/// Reusable per-layer scratch buffers sized for `max_tokens` tokens of a
268/// single forward pass (prefill or decode step).
269///
270/// Sized lazily on first use so tiny decode steps don't pay for prefill-sized
271/// buffers. Grows monotonically when a larger prefill arrives.
272pub struct LlamaFamilyScratch<B: QuantLlmBackend + BackendMoeFused> {
273    /// Residual stream — wrapped in Option so decode_internal can
274    /// `.take()` it without needing an alloc placeholder.
275    ///
276    /// Why this matters for graph capture: the old pattern was
277    /// `mem::replace(&mut scratch.residual, B::alloc(1))` which creates a
278    /// 1-element buffer at every decode step. When graph capture is on,
279    /// that alloc-during-capture + drop-after-capture pair surfaces as
280    /// cuMemFreeAsync(INVALID_VALUE) because the free tries to release a
281    /// pointer the captured graph may still reference. Option::take leaves
282    /// None and moves the real buffer into a local — no spurious alloc.
283    pub residual: Option<B::Buffer>,
284    pub norm_out: B::Buffer,
285    pub qkv_out: B::Buffer,
286    // ── Per-item scratch for batched decode path ──────────────────────
287    // decode_batch_internal runs tokens=M batched ops for the GEMM-heavy
288    // half (norm, qkv_proj, split_qkv, o_proj, post_norm, gate_up, silu,
289    // down, residual_add) but must loop per-item for rope + KV append +
290    // attention (each item has its own KV cache at a different kv_len).
291    // These single-item buffers hold item i's slice during that loop.
292    /// Item-scope q_buf slice, sized `q_dim`.
293    pub q_single: B::Buffer,
294    pub k_single: B::Buffer,
295    pub v_single: B::Buffer,
296    pub q_head_major_single: B::Buffer,
297    pub k_head_major_single: B::Buffer,
298    pub v_head_major_single: B::Buffer,
299    pub attn_head_major_single: B::Buffer,
300    pub attn_flat_single: B::Buffer,
301    /// Batched logits output, sized `max_tokens * vocab_size`. Used only
302    /// in decode_batch; prefill/single-decode use the regular `logits`.
303    pub batch_logits: B::Buffer,
304    /// Token-major Q/K/V right after `split_qkv`. Stride: heads * hd per row.
305    pub q_buf: B::Buffer,
306    pub k_buf: B::Buffer,
307    pub v_buf: B::Buffer,
308    /// Head-major Q produced by `qk_norm_rope` — fed into `flash_attention`.
309    pub q_head_major: B::Buffer,
310    /// Head-major K/V staging — produced by `qk_norm_rope`, consumed by
311    /// `kv_cache_append_head_major` (no reuse after append).
312    pub k_head_major: B::Buffer,
313    pub v_head_major: B::Buffer,
314    /// Head-major attention output from `flash_attention`.
315    pub attn_head_major_out: B::Buffer,
316    /// Token-major attention output after `transpose_head_to_token`.
317    pub attn_flat: B::Buffer,
318    pub o_proj_out: B::Buffer,
319    pub gate_up_out: B::Buffer,
320    pub silu_out: B::Buffer,
321    pub mlp_out: B::Buffer,
322    /// Paged batched dispatch scratch (Phase 4b). Sized for
323    /// `FERRUM_PAGED_MAX_SEQS × q_dim` so multi-seq decode can fan
324    /// in M items' Q into a single buffer for one batched
325    /// `paged_decode_attention(num_seqs=M)` call. `None` when paged
326    /// mode is off.
327    pub paged_batch_q: Option<B::Buffer>,
328    pub paged_batch_o: Option<B::Buffer>,
329    /// Stacked per-seq block tables for batched paged dispatch.
330    /// Layout: `[max_M, max_blocks_per_seq]` u32. Written
331    /// host-side per decode_batch step.
332    pub paged_batch_block_tables: Option<B::Buffer>,
333    /// Stacked per-seq context lengths for batched paged dispatch
334    /// (`[max_M]` u32).
335    pub paged_batch_context_lens: Option<B::Buffer>,
336    /// `max_blocks_per_seq` value baked into the stacked block_tables
337    /// stride. Set when `paged_batch_block_tables` is allocated.
338    pub paged_max_blocks_per_seq: usize,
339    /// Engine-side max concurrent sequences (= `FERRUM_PAGED_MAX_SEQS`).
340    /// Pinned at the first `enable_paged_batch` so the unified-forward
341    /// scratch sizes (`unified_cu_seqlens_q`, `unified_pos_offsets`,
342    /// `unified_block_tables`, `unified_packed_*`) are big enough for
343    /// any subsequent batch up to that bound.
344    pub paged_max_seqs: usize,
345    /// Per-item RoPE positions for the batched-decode path (`[max_M]`
346    /// i32 / u32). Written host-side once per batched-decode step from
347    /// each request's `pos` field, read by the batched
348    /// `qk_norm_rope_batched_per_item` CUDA kernel.
349    pub batch_positions: B::Buffer,
350    /// Per-item input token ids for the batched-decode path (`[max_M]`
351    /// u32). Written once per call before forward; read by the
352    /// graph-capture-friendly `embedding_lookup_batched_dyn` variant.
353    pub batch_tokens: B::Buffer,
354    /// Per-item KV-cache length BEFORE this step's kv_append
355    /// (`[max_M]` u32). Used by `kv_cache_append_batched_per_cache_dyn`
356    /// to write at the right slot for graph replay.
357    pub batch_kv_lens_pre: B::Buffer,
358    /// Per-item KV-cache length AFTER this step's kv_append
359    /// (`[max_M]` u32 = pre + 1). Used by
360    /// `flash_attention_batched_per_cache_dyn` for the attention
361    /// reduce window length.
362    pub batch_kv_lens_post: B::Buffer,
363    /// Output buffers for the batched per-item qk_norm_rope kernel.
364    /// Same shape as q_buf / k_buf / v_buf — separate so the kernel
365    /// API can take `&input` and `&mut output` without aliasing.
366    pub q_normed_batched: B::Buffer,
367    pub k_normed_batched: B::Buffer,
368    pub v_normed_batched: B::Buffer,
369
370    // ── Unified mixed-batch scratch (chunked-prefill path; Step 5b) ─────
371    // Buffers sized for `M_total = sum(items[i].q_tokens.len())`. Grown
372    // on demand by `ensure_unified_scratch(M_total_max)`. Used only by
373    // the new `unified_forward_internal`; legacy `forward_layer_batched_decode`
374    // continues to use the per-item-stride scratch above.
375    pub unified_capacity: usize, // current allocated M_total slots
376    pub unified_residual: Option<B::Buffer>,
377    pub unified_norm_out: Option<B::Buffer>,
378    pub unified_qkv_out: Option<B::Buffer>,
379    pub unified_packed_q: Option<B::Buffer>,
380    pub unified_attn_out: Option<B::Buffer>,
381    pub unified_o_proj_out: Option<B::Buffer>,
382    pub unified_gate_up_out: Option<B::Buffer>,
383    pub unified_silu_out: Option<B::Buffer>,
384    pub unified_mlp_out: Option<B::Buffer>,
385    /// Per-item index buffers (i32-stored-as-f16): cu_seqlens_q is
386    /// length `max_seqs+1`, pos_offsets is `max_seqs`, block_tables is
387    /// `max_seqs * max_blocks_per_seq`. Sized once at first use to
388    /// match `paged_batch_*` capacity since they share `max_seqs`.
389    pub unified_cu_seqlens_q: Option<B::Buffer>,
390    pub unified_pos_offsets: Option<B::Buffer>,
391    pub unified_block_tables: Option<B::Buffer>,
392    /// Packed last-token hidden states for is_final_chunk items
393    /// (`[num_sampled, h]`). Used as input to lm_head.
394    pub unified_packed_normed: Option<B::Buffer>,
395    /// Packed logits output (`[num_sampled, vocab]`).
396    pub unified_packed_logits: Option<B::Buffer>,
397    /// Last token's hidden state (`[h]`). For prefill this is populated via
398    /// `copy_slice(residual, (seq_len-1)*h, ..)`; for decode `residual` already
399    /// holds only 1 row so `last_hidden` is unused on that path.
400    pub last_hidden: B::Buffer,
401    /// Final-norm output for the last token (`[h]`).
402    pub last_normed: B::Buffer,
403    /// lm_head logits (`[vocab]`).
404    pub logits: B::Buffer,
405    /// The max tokens-per-step this scratch has been sized for.
406    pub max_tokens: usize,
407}
408
409impl<B: QuantLlmBackend + BackendMoeFused> LlamaFamilyScratch<B> {
410    fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
411        let h = cfg.hidden_size;
412        let im = cfg.intermediate_size;
413        let q_dim = cfg.num_heads * cfg.head_dim;
414        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
415        let qkv_dim = q_dim + 2 * kv_dim;
416        let t = max_tokens;
417        Self {
418            residual: Some(B::alloc(t * h)),
419            norm_out: B::alloc(t * h),
420            qkv_out: B::alloc(t * qkv_dim),
421            q_buf: B::alloc(t * q_dim),
422            k_buf: B::alloc(t * kv_dim),
423            v_buf: B::alloc(t * kv_dim),
424            q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
425            k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
426            v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
427            attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
428            attn_flat: B::alloc(t * q_dim),
429            o_proj_out: B::alloc(t * h),
430            gate_up_out: B::alloc(t * 2 * im),
431            silu_out: B::alloc(t * im),
432            mlp_out: B::alloc(t * h),
433            last_hidden: B::alloc(h),
434            last_normed: B::alloc(h),
435            logits: B::alloc(cfg.vocab_size),
436            q_single: B::alloc(q_dim),
437            k_single: B::alloc(kv_dim),
438            v_single: B::alloc(kv_dim),
439            q_head_major_single: B::alloc(q_dim),
440            k_head_major_single: B::alloc(kv_dim),
441            v_head_major_single: B::alloc(kv_dim),
442            attn_head_major_single: B::alloc(q_dim),
443            attn_flat_single: B::alloc(q_dim),
444            batch_logits: B::alloc(t * cfg.vocab_size),
445            // Paged batched dispatch scratch. None until `enable_paged_batch`
446            // is called from `ensure_kv` once the model knows max_seqs +
447            // max_blocks_per_seq. This avoids paying the alloc cost when
448            // paged mode is off.
449            paged_batch_q: None,
450            paged_batch_o: None,
451            paged_batch_block_tables: None,
452            paged_batch_context_lens: None,
453            paged_max_blocks_per_seq: 0,
454            paged_max_seqs: 0,
455            batch_positions: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
456            batch_tokens: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
457            batch_kv_lens_pre: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
458            batch_kv_lens_post: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
459            q_normed_batched: B::alloc(t * q_dim),
460            k_normed_batched: B::alloc(t * kv_dim),
461            v_normed_batched: B::alloc(t * kv_dim),
462            unified_capacity: 0,
463            unified_residual: None,
464            unified_norm_out: None,
465            unified_qkv_out: None,
466            unified_packed_q: None,
467            unified_attn_out: None,
468            unified_o_proj_out: None,
469            unified_gate_up_out: None,
470            unified_silu_out: None,
471            unified_mlp_out: None,
472            unified_cu_seqlens_q: None,
473            unified_pos_offsets: None,
474            unified_block_tables: None,
475            unified_packed_normed: None,
476            unified_packed_logits: None,
477            max_tokens: t,
478        }
479    }
480
481    /// Grow unified-path scratch buffers to accommodate `m_total` query
482    /// tokens. Called lazily from `unified_forward_internal` so single-
483    /// path workloads (no chunked prefill) don't pay the alloc cost.
484    pub(crate) fn ensure_unified_scratch(
485        &mut self,
486        cfg: &LlamaFamilyConfig,
487        m_total: usize,
488        max_seqs: usize,
489        max_blocks_per_seq: usize,
490    ) {
491        if m_total <= self.unified_capacity
492            && self.unified_residual.is_some()
493            && self.unified_cu_seqlens_q.is_some()
494        {
495            return;
496        }
497        let cap = m_total.max(self.unified_capacity).max(1);
498        let h = cfg.hidden_size;
499        let q_dim = cfg.num_heads * cfg.head_dim;
500        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
501        let qkv_dim = q_dim + 2 * kv_dim;
502        let im = cfg.intermediate_size;
503        let v = cfg.vocab_size;
504        self.unified_residual = Some(B::alloc(cap * h));
505        self.unified_norm_out = Some(B::alloc(cap * h));
506        self.unified_qkv_out = Some(B::alloc(cap * qkv_dim));
507        self.unified_packed_q = Some(B::alloc(cap * q_dim));
508        self.unified_attn_out = Some(B::alloc(cap * q_dim));
509        self.unified_o_proj_out = Some(B::alloc(cap * h));
510        self.unified_gate_up_out = Some(B::alloc(cap * 2 * im));
511        self.unified_silu_out = Some(B::alloc(cap * im));
512        self.unified_mlp_out = Some(B::alloc(cap * h));
513        if self.unified_cu_seqlens_q.is_none() {
514            self.unified_cu_seqlens_q = Some(B::alloc_typed(
515                ferrum_kernels::backend::Dtype::U32,
516                max_seqs + 1,
517            ));
518            self.unified_pos_offsets = Some(B::alloc_typed(
519                ferrum_kernels::backend::Dtype::U32,
520                max_seqs,
521            ));
522            self.unified_block_tables = Some(B::alloc_typed(
523                ferrum_kernels::backend::Dtype::U32,
524                max_seqs * max_blocks_per_seq,
525            ));
526            self.unified_packed_normed = Some(B::alloc(max_seqs * h));
527            self.unified_packed_logits = Some(B::alloc(max_seqs * v));
528        }
529        self.unified_capacity = cap;
530    }
531
532    /// Allocate scratch for batched paged dispatch (Phase 4b). Called
533    /// lazily from `ensure_kv` once paged mode is enabled and we know
534    /// the pool dimensions. Idempotent.
535    fn enable_paged_batch(
536        &mut self,
537        cfg: &LlamaFamilyConfig,
538        max_seqs: usize,
539        max_blocks_per_seq: usize,
540    ) {
541        if self.paged_batch_q.is_some() {
542            return;
543        }
544        let q_dim = cfg.num_heads * cfg.head_dim;
545        self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
546        self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
547        self.paged_batch_block_tables = Some(B::alloc_typed(
548            ferrum_kernels::backend::Dtype::U32,
549            max_seqs * max_blocks_per_seq,
550        ));
551        self.paged_batch_context_lens = Some(B::alloc_typed(
552            ferrum_kernels::backend::Dtype::U32,
553            max_seqs,
554        ));
555        self.paged_max_blocks_per_seq = max_blocks_per_seq;
556        self.paged_max_seqs = max_seqs;
557    }
558}
559
560/// Qwen3 model — decoder-only LLM, one per (backend, weights) combination.
561///
562/// Holds all parameters, scratch space, RoPE cache, and per-sequence KV caches.
563///
564/// `B: BackendGraph + BackendQuantMarlin + BackendQuantGguf` because the decode hot path uses CUDA Graph capture/replay
565/// when the backend supports it; non-graph backends (Metal/CPU) inherit no-op
566/// defaults, so this bound is satisfied by every concrete `Backend`.
567///
568/// `K: KvDtypeKind = KvFp16` (Dim 5): selects the KV cache element type.
569/// `K = KvFp16` constructs only `LayerKvCache::Fp16(...)` variants;
570/// `K = KvInt8` constructs only `LayerKvCache::Int8(...)` variants. The
571/// `B: BackendKvDtype<KvInt8>` bound is structurally required by the
572/// enum and is satisfied by all backends via stub impls (Cpu/Metal) or
573/// the real impl (CUDA).
574pub struct LlamaFamilyModel<B: MoeLlmBackend, K: KvLayer<B> = KvFp16> {
575    pub cfg: LlamaFamilyConfig,
576    pub runtime_cfg: LlmRuntimeConfig,
577
578    /// Token embedding table. `None` for backbone-only models (e.g. the
579    /// Qwen3-TTS Talker, which embeds inputs externally and feeds via
580    /// `prefill_from_embeds`).
581    pub embed: Option<B::Buffer>,
582    pub layers: Vec<LlamaFamilyLayer<B>>,
583    pub final_norm_w: B::Buffer,
584    /// LM output head. `None` for backbone-only models.
585    pub lm_head: Option<Box<dyn Linear<B>>>,
586
587    pub rope: RopeCache<B>,
588    pub scratch: LlamaFamilyScratch<B>,
589
590    /// Per-sequence KV caches, one `Vec<KvCache<B>>` of length `num_layers`.
591    ///
592    /// Two layouts overlay this same map:
593    /// - **Contiguous mode** (default): each cache holds its own
594    ///   `[num_kv_heads, capacity, head_dim]` k/v buffers.
595    /// - **Paged mode** (`FERRUM_METAL_PAGED_KV=1`): k/v are unused
596    ///   placeholders; the real K/V live in [`Self::paged_pools`] and
597    ///   the cache's `block_table` + `context_lens` index into them.
598    pub kv_caches: HashMap<String, Vec<K::Layer>>,
599    /// Free pool of pre-allocated KV cache slots. Released caches return
600    /// here instead of being dropped, so their device pointers stay valid
601    /// across requests — critical for graph capture (pointers baked into
602    /// the captured graph would otherwise dangle).
603    kv_free_pool: Vec<Vec<K::Layer>>,
604
605    // ── Paged-KV multi-seq state (Phase 4) ─────────────────────────────
606    //
607    // Only populated when `FERRUM_METAL_PAGED_KV=1`. When set, every
608    // `kv_caches` entry becomes a "view" into the shared pool: its
609    // `k` / `v` buffers are placeholders; reads / writes go through
610    // `paged_pools[layer].k` / `.v` indexed via the cache's
611    // `block_table`. Multiple cache_ids share the same pool, with
612    // disjoint physical block sets owned by `paged_block_alloc`.
613    //
614    /// Shared K/V pools, one per layer. Sized at model load time for the
615    /// configured `MAX_RUNNING_REQUESTS × max_blocks_per_seq` blocks.
616    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
617    /// Block allocator hands out physical block indices from the pool.
618    /// `Mutex` because the engine can call `ensure_kv` / `release_kv`
619    /// from multiple threads in concurrent serving.
620    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
621    /// Paged-batch dispatch dimensions `(max_seqs, max_blocks_per_seq)`,
622    /// pinned at the first `ensure_kv` when paged-KV is on. Stored on
623    /// the model (not on scratch) so `ensure_scratch`'s realloc can
624    /// re-call `enable_paged_batch` after wiping scratch's
625    /// `paged_batch_block_tables` / `paged_batch_q` etc.
626    pub paged_dims: Option<(usize, usize)>,
627
628    // ── Graph capture state (CUDA only; harmless no-op on other backends) ──
629    /// Count of eager decode steps run so far. After `GRAPH_WARMUP`, the
630    /// next step captures the decode flow as a graph.
631    pub(crate) graph_warmup: usize,
632    /// True if capture was attempted but failed (e.g. backend doesn't
633    /// support graph capture). Stops further attempts, falls back to eager.
634    pub(crate) graph_capture_failed: bool,
635    /// Same warmup counter for the batched-decode path.
636    pub(crate) batched_graph_warmup: usize,
637    /// True if batched graph capture failed.
638    pub(crate) batched_graph_failed: bool,
639    /// Set of `m_padded` values (as u64 graph keys) for which a batched
640    /// graph has been captured. Multi-slot via cuda.rs's HashMap-keyed
641    /// graph cache — different batch shapes don't thrash a single slot.
642    pub(crate) batched_graph_keys_seen: std::collections::HashSet<u64>,
643    /// Cache IDs for which device-pointer scratch is currently populated.
644    /// Populate only re-runs when the batch composition changes (new
645    /// requests joined / requests finished). Hot-path optimization:
646    /// avoids 3 sync cuMemcpyHtoD_v2's per decode token (~5% TPOT).
647    pub(crate) batched_pointers_for: Option<Vec<String>>,
648    /// CUDA-graph state for the unified_forward path. Mirrors the
649    /// `batched_graph_*` triple but keyed on `(m_total, num_seqs)`
650    /// so different concurrency levels each get their own cached
651    /// graph instead of thrashing a single slot.
652    pub(crate) unified_graph_warmup: usize,
653    pub(crate) unified_graph_failed: bool,
654    pub(crate) unified_graph_keys_seen: std::collections::HashSet<u64>,
655}
656
657impl<B: MoeLlmBackend, K: KvLayer<B>> LlamaFamilyModel<B, K> {
658    /// Build a Qwen3 model from weights provided by the loader.
659    ///
660    /// The loader decides per-projection whether to instantiate DenseLinear,
661    /// GptqLinear, etc. — this code doesn't care.
662    pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
663        // Invalidate any graph from a previously-loaded model. The captured
664        // graph references the old model's scratch buffers; a fresh model
665        // gets fresh scratch, so reusing the graph would read/write freed
666        // pointers. Matters for test suites where multiple models coexist.
667        {
668            let mut ctx = B::new_context();
669            B::reset_all_graphs(&mut ctx);
670        }
671        let rope = build_rope_cache::<B>(&cfg);
672        let scratch = LlamaFamilyScratch::alloc(&cfg, 1); // decode-sized; prefill resizes
673
674        // Embedding: plain tensor (no projection math, just lookup).
675        let embed = loader.load_tensor("model.embed_tokens.weight")?;
676
677        // Per-layer weights.
678        let mut layers = Vec::with_capacity(cfg.num_layers);
679        for li in 0..cfg.num_layers {
680            let prefix = format!("model.layers.{li}");
681            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
682            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
683            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
684            let post_ln_w =
685                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
686            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
687            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
688
689            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
690                let q = loader
691                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
692                    .ok();
693                let k = loader
694                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
695                    .ok();
696                (q, k)
697            } else {
698                (None, None)
699            };
700
701            layers.push(LlamaFamilyLayer {
702                input_ln_w,
703                qkv_proj,
704                q_norm_w,
705                k_norm_w,
706                o_proj,
707                post_ln_w,
708                gate_up_proj,
709                down_proj,
710            });
711        }
712
713        let final_norm_w = loader.load_tensor("model.norm.weight")?;
714
715        // LM head: either dedicated `lm_head.weight` or tied to embedding.
716        // Many models (Qwen3-4B, Llama-3.2-1B, some Qwen2.5) use TIED
717        // embeddings — lm_head shares weights with model.embed_tokens. When
718        // no dedicated lm_head tensor exists, re-load the embed tensor as a
719        // DenseLinear. This duplicates the buffer (memory cost = vocab*h*2
720        // bytes, e.g. ~770MB for Qwen3-4B) but keeps the Linear trait's
721        // owned-weights invariant. Sharing via Arc is a future optimisation.
722        let lm_head = if loader.has_tensor("lm_head.weight") {
723            loader.load_linear("lm_head")?
724        } else {
725            tracing::info!(
726                "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
727            );
728            let as_linear = loader.load_linear("model.embed_tokens")?;
729            // Sanity check: shape must be [vocab, hidden].
730            if as_linear.out_features() != cfg.vocab_size
731                || as_linear.in_features() != cfg.hidden_size
732            {
733                return Err(ferrum_types::FerrumError::model(format!(
734                    "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
735                    as_linear.out_features(),
736                    as_linear.in_features(),
737                    cfg.vocab_size,
738                    cfg.hidden_size
739                )));
740            }
741            as_linear
742        };
743
744        let runtime_cfg = cfg.to_runtime();
745        Ok(Self {
746            cfg,
747            runtime_cfg,
748            embed: Some(embed),
749            layers,
750            final_norm_w,
751            lm_head: Some(lm_head),
752            rope,
753            scratch,
754            kv_caches: HashMap::new(),
755            kv_free_pool: Vec::new(),
756            paged_pools: None,
757            paged_block_alloc: None,
758            paged_dims: None,
759            graph_warmup: 0,
760            graph_capture_failed: false,
761            batched_graph_warmup: 0,
762            batched_graph_failed: false,
763            batched_graph_keys_seen: std::collections::HashSet::new(),
764            batched_pointers_for: None,
765            unified_graph_warmup: 0,
766            unified_graph_failed: false,
767            unified_graph_keys_seen: std::collections::HashSet::new(),
768        })
769    }
770
771    /// Build a backbone-only Qwen3 transformer stack (no embed, no lm_head).
772    ///
773    /// Intended for composing the transformer inside a larger model where
774    /// embedding and output-head logic differs from the standard LLM path —
775    /// e.g. Qwen3-TTS Talker uses dual text/codec embeddings with a projection
776    /// MLP, and a codec_head output. The caller drives forward via
777    /// `prefill_from_embeds` / `decode_from_embed`.
778    ///
779    /// Loader must provide: per-layer weights under `model.layers.{i}.*` and
780    /// the final `model.norm.weight`. `model.embed_tokens` and `lm_head`
781    /// are NOT read.
782    pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
783        // See `new` — invalidate stale graph referring to prior model's scratch.
784        {
785            let mut ctx = B::new_context();
786            B::reset_all_graphs(&mut ctx);
787        }
788        let rope = build_rope_cache::<B>(&cfg);
789        let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
790
791        let mut layers = Vec::with_capacity(cfg.num_layers);
792        for li in 0..cfg.num_layers {
793            let prefix = format!("model.layers.{li}");
794            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
795            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
796            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
797            let post_ln_w =
798                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
799            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
800            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
801
802            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
803                let q = loader
804                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
805                    .ok();
806                let k = loader
807                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
808                    .ok();
809                (q, k)
810            } else {
811                (None, None)
812            };
813
814            layers.push(LlamaFamilyLayer {
815                input_ln_w,
816                qkv_proj,
817                q_norm_w,
818                k_norm_w,
819                o_proj,
820                post_ln_w,
821                gate_up_proj,
822                down_proj,
823            });
824        }
825
826        let final_norm_w = loader.load_tensor("model.norm.weight")?;
827
828        let runtime_cfg = cfg.to_runtime();
829        Ok(Self {
830            cfg,
831            runtime_cfg,
832            embed: None,
833            layers,
834            final_norm_w,
835            lm_head: None,
836            rope,
837            scratch,
838            kv_caches: HashMap::new(),
839            kv_free_pool: Vec::new(),
840            paged_pools: None,
841            paged_block_alloc: None,
842            paged_dims: None,
843            graph_warmup: 0,
844            graph_capture_failed: false,
845            batched_graph_warmup: 0,
846            batched_graph_failed: false,
847            batched_graph_keys_seen: std::collections::HashSet::new(),
848            batched_pointers_for: None,
849            unified_graph_warmup: 0,
850            unified_graph_failed: false,
851            unified_graph_keys_seen: std::collections::HashSet::new(),
852        })
853    }
854
855    /// Grow scratch buffers if `tokens` exceeds the current sizing.
856    pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
857        if self.scratch.max_tokens < tokens {
858            // Any captured decode graph holds pointers to the old scratch
859            // buffers; those are about to be freed. Invalidate ALL captured
860            // graphs (both single-item and per-m_padded batched) — every
861            // captured kernel-arg pointer into scratch is stale.
862            {
863                let mut ctx = B::new_context();
864                B::reset_all_graphs(&mut ctx);
865            }
866            self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
867            self.graph_warmup = 0;
868            self.graph_capture_failed = false;
869            self.batched_graph_keys_seen.clear();
870            self.batched_graph_warmup = 0;
871            self.batched_graph_failed = false;
872            self.unified_graph_keys_seen.clear();
873            self.unified_graph_warmup = 0;
874            self.unified_graph_failed = false;
875            // Realloc wiped paged_batch_*. Re-enable using the dims
876            // pinned at first ensure_kv. Without this, the next
877            // `forward_layer_batched_decode` panics on
878            // `paged_batch_block_tables missing`.
879            if let Some((max_seqs, max_blocks_per_seq)) = self.paged_dims {
880                self.scratch
881                    .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
882            }
883        }
884    }
885
886    /// Ensure per-layer KV caches exist for `cache_id`, pre-allocated to
887    /// `max_seq_len` slots per head. Enables the in-place
888    /// `kv_cache_append_head_major` path — no realloc per layer.
889    pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
890        if self.kv_caches.contains_key(cache_id) {
891            return;
892        }
893        let nkv = self.cfg.num_kv_heads;
894        let hd = self.cfg.head_dim;
895        // KV capacity defaults to a chat-friendly 4096 to keep the working
896        // set sane on a 32 GB Mac (a 48-layer / 4-kv-head / 128-head_dim
897        // model spends ~786 MB on 4096 KV slots, vs ~6 GB on the model's
898        // declared 32K which would push the 17 GB Qwen3-30B-A3B model into
899        // swap). `FERRUM_KV_CAPACITY=N` overrides; clamp to the model's
900        // declared max so we never lie to the model about its window.
901        let model_max = self.cfg.max_seq_len;
902        // 512 in 0.7.2 — matches the value used in
903        // docs/bench/macos-2026-05-02 to get the published numbers.
904        // pre-0.7.2 default of 4096 was safe only because paged-KV was
905        // opt-in (pool wasn't allocated). With paged-KV now on by
906        // default + MAX_SEQS=32, the pool occupies physical memory:
907        // ~3 GB on Qwen3-30B-A3B Q4_K_M leaves 18 GB weights + 3 GB pool
908        // = 21 GB, fits comfortably on a 32 GB Mac. Long-context users
909        // can `FERRUM_KV_CAPACITY=4096` and accept lower max_seqs.
910        let runtime_env = llama_family_runtime_env();
911        let max = runtime_env.kv_capacity_for_model(model_max);
912
913        // Paged-KV mode: `FERRUM_METAL_PAGED_KV=1` switches every cache
914        // for this model into block-table-indirect layout. Kernels from
915        // PR #68 (decode read) + PR #69 (decode write) handle the
916        // indirect addressing; the LlamaFamily decode path below picks
917        // them up automatically by checking `cache.block_size > 0`.
918        //
919        // Pool sizing: round capacity up to a multiple of block_size,
920        // identity-assign logical→physical block. Memory footprint is
921        // the same as contiguous (within block_size rounding); the
922        // benefit only shows up under multi-seq sharing in Phase 4+.
923        // Default ON when the backend supports paged-KV (Metal). Users
924        // can force off with `FERRUM_METAL_PAGED_KV=0`. The flag was
925        // opt-in pre-0.7.2; flipping the default so default `ferrum
926        // serve` matches the bench-quality numbers without requiring
927        // env-var knowledge.
928        let paged = runtime_env.paged_kv_enabled::<B>();
929        const PAGED_BLOCK_SIZE: usize = 16;
930
931        // Phase 4 shared-pool sizing. The pool sees ALL concurrent
932        // sequences; per-cache_id state just owns indices into it.
933        // Default 32: covers c=16 burst with 2× headroom for the
934        // fresh-cache-id-per-request pattern that bench/server harnesses
935        // use. Pool memory is `max_seqs × max_blocks_per_seq` total
936        // blocks — we lowered DEFAULT_KV_CAPACITY to 2048 so this 2× max_seqs
937        // bump keeps the pool footprint identical to the pre-0.7.2 default.
938        let max_seqs = runtime_env.paged_max_seqs;
939        let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
940        let total_pool_blocks = max_seqs * max_blocks_per_seq;
941
942        // Lazy-allocate the shared paged pools on the FIRST paged
943        // ensure_kv call. Pools are big — for Llama-8B (8 kv_heads,
944        // head_dim=128) at 16 seqs × 256 blocks × 16 slots = 65536 KV
945        // slots: 65536 * 8 * 128 * 4 = 256 MB per layer × 32 layers
946        // = 8 GB total. Sized this large only because `max_seqs=16`
947        // is the default; lower it via env to shrink.
948        if paged && self.paged_pools.is_none() {
949            let mut pools = Vec::with_capacity(self.cfg.num_layers);
950            for _ in 0..self.cfg.num_layers {
951                let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
952                pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
953            }
954            self.paged_pools = Some(pools);
955            self.paged_block_alloc = Some(std::sync::Mutex::new(
956                crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
957            ));
958        }
959        // Phase 4b: ensure batched-dispatch scratch is allocated whenever
960        // paged is on. Idempotent — re-init is a no-op if already
961        // sized. Has to live outside the `paged_pools.is_none()` branch
962        // because `ensure_scratch` may have replaced the scratch struct
963        // since the pools were first allocated.
964        if paged {
965            self.scratch
966                .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
967            // Pin dims on the model so `ensure_scratch`'s realloc can
968            // re-init paged_batch_* after wiping scratch.
969            self.paged_dims = Some((max_seqs, max_blocks_per_seq));
970        }
971
972        // Try pool first — reused buffers have stable device pointers,
973        // so a captured decode graph can be replayed for this request too.
974        // K::NAME selects which `LayerKvCache` variant to construct:
975        // K-aware allocation: K::alloc_paged / K::alloc_contig pick the
976        // right cache layout (FP16 → KvCache, INT8 → KvCacheQuant) per the
977        // model's K marker. INT8 supports paged mode only — KvInt8::alloc_contig
978        // panics, surfacing the misconfiguration here at first ensure_kv.
979        let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
980            (0..self.cfg.num_layers)
981                .map(|_| {
982                    if paged {
983                        K::alloc_paged(max_blocks_per_seq, PAGED_BLOCK_SIZE, nkv, hd)
984                    } else {
985                        K::alloc_contig(max, nkv, hd)
986                    }
987                })
988                .collect()
989        });
990
991        // Allocate physical blocks for THIS cache_id from the shared
992        // pool. We allocate all `max_blocks_per_seq` upfront for
993        // simplicity (matches contig's "pre-alloc to capacity"
994        // semantics); a smarter Phase 4b can grow on demand to save
995        // pool occupancy.
996        if paged {
997            let alloc_arc = self
998                .paged_block_alloc
999                .as_ref()
1000                .expect("paged_block_alloc must be initialised when paged=true");
1001            // Recover from a previously-poisoned mutex instead of panicking
1002            // (poison just means a prior holder panicked; the BlockAllocator
1003            // state is still intact since allocate_n is fail-safe).
1004            let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1005            let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
1006                Ok(idx) => idx,
1007                Err(e) => {
1008                    // Pool exhaustion is a back-pressure signal, not a crash.
1009                    // Drop the lock, return the cache to the free pool, and
1010                    // bail before inserting it into kv_caches. The downstream
1011                    // call will then fail with a clean per-request error
1012                    // ("ensure_kv must be called before ...") instead of
1013                    // dragging every other in-flight request down with it.
1014                    drop(alloc);
1015                    self.kv_free_pool.push(caches);
1016                    eprintln!(
1017                        "[ferrum] paged KV pool exhausted on ensure_kv for \
1018                         cache_id={cache_id:?}: {e}. Increase \
1019                         FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
1020                         throttle concurrent requests.",
1021                    );
1022                    return;
1023                }
1024            };
1025            // Write the block table to each layer's cache. All layers
1026            // share the same logical→physical mapping for this seq.
1027            // Also stash the host-side index list so release_kv can
1028            // return them to the allocator without a device readback.
1029            let mut padded = block_indices.clone();
1030            padded.resize(max_blocks_per_seq, 0);
1031            let mut ctx_tmp = B::new_context();
1032            for c in caches.iter_mut() {
1033                if let Some(bt) = K::block_table_mut(c) {
1034                    B::write_typed::<u32>(&mut ctx_tmp, bt, &padded);
1035                }
1036                *K::paged_block_indices_mut(c) = block_indices.clone();
1037            }
1038            B::sync(&mut ctx_tmp);
1039        }
1040
1041        // Reset logical length; buffers stay. No need to zero the memory —
1042        // the kv_cache_append writes new K/V in place, and attention only
1043        // reads up to `cache_len`.
1044        for c in caches.iter_mut() {
1045            K::set_len(c, 0);
1046            if let Some(cl) = K::context_lens_mut(c) {
1047                let mut ctx_tmp = B::new_context();
1048                B::write_typed::<u32>(&mut ctx_tmp, cl, &[0u32]);
1049                B::sync(&mut ctx_tmp);
1050            }
1051        }
1052        self.kv_caches.insert(cache_id.to_string(), caches);
1053    }
1054
1055    /// Run one transformer layer. Mutates `residual` in place.
1056    ///
1057    /// `pos_offset` is the absolute position of token 0 in this batch
1058    /// (decode: `pos`; prefill: 0). `tokens` is the batch size.
1059    #[allow(clippy::too_many_arguments)]
1060    pub(crate) fn forward_layer(
1061        &mut self,
1062        ctx: &mut B::Context,
1063        li: usize,
1064        cache_id: &str,
1065        residual: &mut B::Buffer,
1066        pos_offset: usize,
1067        tokens: usize,
1068    ) {
1069        let layer = &self.layers[li];
1070        let cfg = &self.cfg;
1071        let h = cfg.hidden_size;
1072        let nh = cfg.num_heads;
1073        let nkv = cfg.num_kv_heads;
1074        let hd = cfg.head_dim;
1075        let im = cfg.intermediate_size;
1076        let eps = cfg.rms_norm_eps;
1077        let q_dim = nh * hd;
1078        let kv_dim = nkv * hd;
1079
1080        // 1. Input RMSNorm
1081        let _t0 = if llama_family_runtime_env().decode_op_profile {
1082            B::sync(ctx);
1083            Some(std::time::Instant::now())
1084        } else {
1085            None
1086        };
1087        B::rms_norm(
1088            ctx,
1089            residual,
1090            &layer.input_ln_w,
1091            eps,
1092            &mut self.scratch.norm_out,
1093            tokens,
1094            h,
1095        );
1096        if let Some(t0) = _t0 {
1097            B::sync(ctx);
1098            NORM_TIME_US.fetch_add(
1099                t0.elapsed().as_micros() as u64,
1100                std::sync::atomic::Ordering::Relaxed,
1101            );
1102            NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1103        }
1104
1105        // 2. Fused QKV projection (Linear dispatches to Dense/GPTQ/AWQ/GGUF)
1106        let _t0 = if llama_family_runtime_env().decode_op_profile {
1107            B::sync(ctx);
1108            Some(std::time::Instant::now())
1109        } else {
1110            None
1111        };
1112        layer.qkv_proj.forward(
1113            ctx,
1114            &self.scratch.norm_out,
1115            &mut self.scratch.qkv_out,
1116            tokens,
1117        );
1118        if let Some(t0) = _t0 {
1119            B::sync(ctx);
1120            MATMUL_TIME_US.fetch_add(
1121                t0.elapsed().as_micros() as u64,
1122                std::sync::atomic::Ordering::Relaxed,
1123            );
1124            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1125        }
1126
1127        // 3-5. Fused split-QKV + QK-norm + RoPE + cache-write.
1128        //
1129        // Single Metal dispatch replaces the (split_qkv → 3× qk_norm_rope
1130        // → kv_cache_append_head_major) five-launch chain on the decode
1131        // hot path. Reads qkv_out once, writes Q to head-major scratch
1132        // and K/V straight into the pre-allocated KV cache slot at
1133        // `cache_len + tok`. Saves 4 dispatches per layer when the
1134        // backend implements the fused kernel; CPU and other backends
1135        // keep using the unfused chain via the Unsupported fallbacks.
1136        //
1137        // qk_mode: 1 = norm + RoPE (Qwen3); 2 = RoPE only (Llama).
1138        // V always passes apply_norm=0.
1139        let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
1140        let dummy = &layer.input_ln_w;
1141        let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
1142        let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);
1143
1144        // Grab the per-layer KV cache up front so the deepest fusion can
1145        // write K/V straight into it.
1146        //
1147        // Paged mode: also need this layer's shared pool buffers
1148        // (self.paged_pools[li]). The pool is a separate field from
1149        // kv_caches, so we take a raw pointer to its (k, v) here while
1150        // we still hold &mut self, then deref via unsafe inside the
1151        // paged dispatch below. Safety: paged_pools is allocated once
1152        // and never resized; we don't touch self.paged_pools while the
1153        // pointer is in use.
1154        let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
1155            if let Some(pools) = self.paged_pools.as_mut() {
1156                let pool = &mut pools[li];
1157                Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
1158            } else {
1159                None
1160            };
1161        let caches = self
1162            .kv_caches
1163            .get_mut(cache_id)
1164            .expect("ensure_kv must be called before forward_layer");
1165        // Read shared metadata (variant-agnostic) once. The K-aware
1166        // attn section below re-borrows the right enum variant.
1167        let cache_len_before = K::len(&caches[li]);
1168        let cache_capacity = K::capacity(&caches[li]);
1169        let cache_block_size = K::block_size(&caches[li]);
1170
1171        // Defense in depth: refuse to write past the KV buffer. The
1172        // graceful path is the caller pre-checking via `kv_capacity()`
1173        // and either compacting or refusing the request; this panic only
1174        // fires when that contract is broken (and silent overflow would
1175        // otherwise corrupt the cache + adjacent allocations).
1176        if cache_len_before + tokens > cache_capacity {
1177            panic!(
1178                "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
1179                cache_len_before + tokens
1180            );
1181        }
1182
1183        // Paged path: K::paged_write fuses split_qkv_norm_rope + cache append
1184        // (FP16: into_paged_cache; INT8: split_qkv_norm_rope + int8_kv_append_paged).
1185        // K::paged_decode_attention reads from layer-local INT8 buffers or the
1186        // shared FP16 pool depending on K. Then K-agnostic post-attn tail.
1187        if cache_block_size > 0 {
1188            let (pool_k_ptr, pool_v_ptr) =
1189                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
1190            // SAFETY: paged_pools is allocated once and never resized; the
1191            // raw pointers don't outlive this method scope.
1192            let pool_k = unsafe { &mut *pool_k_ptr };
1193            let pool_v = unsafe { &mut *pool_v_ptr };
1194
1195            K::paged_write(
1196                ctx,
1197                &mut caches[li],
1198                &self.scratch.qkv_out,
1199                q_norm_w,
1200                k_norm_w,
1201                &self.rope.cos,
1202                &self.rope.sin,
1203                &mut self.scratch.q_head_major,
1204                &mut self.scratch.k_head_major,
1205                &mut self.scratch.v_head_major,
1206                pool_k,
1207                pool_v,
1208                tokens,
1209                nh,
1210                nkv,
1211                hd,
1212                pos_offset,
1213                eps,
1214                qk_mode,
1215            )
1216            .expect("K::paged_write");
1217
1218            let new_len = cache_len_before + tokens;
1219            K::set_len(&mut caches[li], new_len);
1220
1221            let pool_k_imm = unsafe { &*pool_k_ptr };
1222            let pool_v_imm = unsafe { &*pool_v_ptr };
1223            K::paged_decode_attention(
1224                ctx,
1225                &mut caches[li],
1226                &self.scratch.q_head_major,
1227                pool_k_imm,
1228                pool_v_imm,
1229                &mut self.scratch.attn_head_major_out,
1230                nh,
1231                nkv,
1232                hd,
1233                new_len,
1234                tokens,
1235            )
1236            .expect("K::paged_decode_attention");
1237
1238            return self.forward_layer_post_attn(ctx, li, residual, tokens);
1239        }
1240
1241        // Non-paged (contig) path. INT8 path doesn't reach here:
1242        // KvInt8::alloc_contig panics in ensure_kv.
1243        let _qkr_t0 = if llama_family_runtime_env().decode_op_profile {
1244            B::sync(ctx);
1245            Some(std::time::Instant::now())
1246        } else {
1247            None
1248        };
1249        K::contig_write(
1250            ctx,
1251            &mut caches[li],
1252            &self.scratch.qkv_out,
1253            q_norm_w,
1254            k_norm_w,
1255            &self.rope.cos,
1256            &self.rope.sin,
1257            &mut self.scratch.q_head_major,
1258            &mut self.scratch.k_head_major,
1259            &mut self.scratch.v_head_major,
1260            &mut self.scratch.q_buf,
1261            &mut self.scratch.k_buf,
1262            &mut self.scratch.v_buf,
1263            tokens,
1264            nh,
1265            nkv,
1266            hd,
1267            pos_offset,
1268            eps,
1269            qk_mode,
1270        )
1271        .expect("K::contig_write");
1272        if let Some(t0) = _qkr_t0 {
1273            B::sync(ctx);
1274            QKR_TIME_US.fetch_add(
1275                t0.elapsed().as_micros() as u64,
1276                std::sync::atomic::Ordering::Relaxed,
1277            );
1278            QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1279        }
1280        let new_len = cache_len_before + tokens;
1281        K::set_len(&mut caches[li], new_len);
1282        let kv_stride = cache_capacity;
1283
1284        let _attn_t0 = if llama_family_runtime_env().decode_op_profile {
1285            B::sync(ctx);
1286            Some(std::time::Instant::now())
1287        } else {
1288            None
1289        };
1290        let attn_cfg = ferrum_kernels::backend::AttnConfig {
1291            num_heads: nh,
1292            num_kv_heads: nkv,
1293            head_dim: hd,
1294            causal: true,
1295            scale: 1.0 / (hd as f32).sqrt(),
1296            kv_seq_stride: kv_stride,
1297            sliding_window: cfg.sliding_window,
1298        };
1299        K::contig_decode_attention(
1300            ctx,
1301            &caches[li],
1302            &self.scratch.q_head_major,
1303            &mut self.scratch.attn_head_major_out,
1304            attn_cfg,
1305            tokens,
1306            pos_offset,
1307        )
1308        .expect("K::contig_decode_attention");
1309        let _ = q_dim;
1310        let _ = kv_dim;
1311        let _ = dummy;
1312        if let Some(t0) = _attn_t0 {
1313            B::sync(ctx);
1314            ATTN_TIME_US.fetch_add(
1315                t0.elapsed().as_micros() as u64,
1316                std::sync::atomic::Ordering::Relaxed,
1317            );
1318            ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1319        }
1320
1321        self.forward_layer_post_attn(ctx, li, residual, tokens);
1322    }
1323
1324    /// Post-attention tail of `forward_layer`: untranspose Q (if needed),
1325    /// O-proj, fused residual+post_norm, gate_up_proj, SwiGLU, down_proj,
1326    /// final residual add. K-agnostic — reads `self.scratch.attn_head_major_out`
1327    /// which both the FP16 and INT8 attn paths populate.
1328    pub(crate) fn forward_layer_post_attn(
1329        &mut self,
1330        ctx: &mut B::Context,
1331        li: usize,
1332        residual: &mut B::Buffer,
1333        tokens: usize,
1334    ) {
1335        let layer = &self.layers[li];
1336        let cfg = &self.cfg;
1337        let h = cfg.hidden_size;
1338        let nh = cfg.num_heads;
1339        let hd = cfg.head_dim;
1340        let im = cfg.intermediate_size;
1341        let eps = cfg.rms_norm_eps;
1342
1343        // 7. Untranspose head-major → token-major for O-proj input.
1344        let attn_token_major = if tokens == 1 {
1345            &self.scratch.attn_head_major_out
1346        } else {
1347            B::transpose_head_to_token(
1348                ctx,
1349                &self.scratch.attn_head_major_out,
1350                &mut self.scratch.attn_flat,
1351                tokens,
1352                nh,
1353                hd,
1354            );
1355            &self.scratch.attn_flat
1356        };
1357
1358        // 8. O projection.
1359        layer
1360            .o_proj
1361            .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1362
1363        // 9. Fused residual-add + post-attention RMSNorm.
1364        B::fused_add_rms_norm(
1365            ctx,
1366            residual,
1367            &self.scratch.o_proj_out,
1368            &layer.post_ln_w,
1369            eps,
1370            &mut self.scratch.norm_out,
1371            tokens,
1372            h,
1373        );
1374
1375        // 10. Fused gate+up projection.
1376        layer.gate_up_proj.forward(
1377            ctx,
1378            &self.scratch.norm_out,
1379            &mut self.scratch.gate_up_out,
1380            tokens,
1381        );
1382
1383        // 11. SwiGLU: silu(gate) * up.
1384        B::fused_silu_mul_split(
1385            ctx,
1386            &self.scratch.gate_up_out,
1387            &mut self.scratch.silu_out,
1388            tokens,
1389            im,
1390        );
1391
1392        // 12. Down projection.
1393        layer.down_proj.forward(
1394            ctx,
1395            &self.scratch.silu_out,
1396            &mut self.scratch.mlp_out,
1397            tokens,
1398        );
1399
1400        // 13. Final residual add.
1401        B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
1402    }
1403
1404    /// Multi-position decode-verify: run one forward pass over `tokens`
1405    /// starting at the cache's current end position, write their K/V
1406    /// into the KV cache, and return logits for ALL `tokens.len()`
1407    /// positions as a flat `Vec<f32>` of length `seq_len * vocab_size`.
1408    ///
1409    /// Used by speculative decoding: target receives
1410    /// `[last_token, draft_0, ..., draft_{N-1}]` (N+1 inputs) and produces
1411    /// N+1 logit rows in a single forward instead of N+1 sequential
1412    /// decode() calls. Positions are implicit — the model looks up
1413    /// `pos_offset = cache.len` the same way prefill_internal does, so
1414    /// chunked prefill semantics carry over for free.
1415    pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1416        let seq_len = tokens.len();
1417        assert!(seq_len > 0, "forward_verify called with empty tokens");
1418        self.ensure_scratch(seq_len);
1419        self.ensure_kv(cache_id);
1420
1421        let h = self.cfg.hidden_size;
1422        let vocab = self.cfg.vocab_size;
1423
1424        let pos_offset = self
1425            .kv_caches
1426            .get(cache_id)
1427            .and_then(|layers| layers.first())
1428            .map(|c| K::len(c))
1429            .unwrap_or(0);
1430
1431        let mut ctx = B::new_context();
1432        let mut residual = self
1433            .scratch
1434            .residual
1435            .take()
1436            .expect("scratch residual missing (previous call didn't restore)");
1437
1438        let embed = self
1439            .embed
1440            .as_ref()
1441            .expect("forward_verify called on backbone-only model (no embed)");
1442        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1443
1444        for li in 0..self.cfg.num_layers {
1445            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1446        }
1447
1448        // RMSNorm on ALL seq_len positions (prefill_internal only norms
1449        // the last one; verify needs the full grid).
1450        B::rms_norm(
1451            &mut ctx,
1452            &residual,
1453            &self.final_norm_w,
1454            self.cfg.rms_norm_eps,
1455            &mut self.scratch.norm_out,
1456            seq_len,
1457            h,
1458        );
1459
1460        // LM head applied to all positions → `seq_len * vocab` logits.
1461        // Reuses the existing `batch_logits` scratch (sized max_tokens *
1462        // vocab) so no extra allocation.
1463        let lm_head = self
1464            .lm_head
1465            .as_ref()
1466            .expect("forward_verify called on backbone-only model (no lm_head)");
1467        lm_head.forward(
1468            &mut ctx,
1469            &self.scratch.norm_out,
1470            &mut self.scratch.batch_logits,
1471            seq_len,
1472        );
1473
1474        B::sync(&mut ctx);
1475        self.scratch.residual = Some(residual);
1476        B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
1477    }
1478
1479    /// Prefill: process `tokens` prompt tokens in a single batch, return
1480    /// `[vocab_size]` logits for the last position.
1481    ///
1482    /// Supports incremental prefill: if the KV cache for `cache_id` already
1483    /// contains earlier tokens, the new chunk's positions are computed as
1484    /// `[kv_len, kv_len + tokens.len())` so RoPE and causal masking stay
1485    /// aligned. Used by the engine's chunked-prefill path.
1486    pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1487        let seq_len = tokens.len();
1488        assert!(seq_len > 0, "prefill called with empty token list");
1489        self.ensure_scratch(seq_len);
1490        self.ensure_kv(cache_id);
1491
1492        // Starting position for this chunk — 0 for a fresh prefill, kv_len
1493        // for the second+ chunk of a split prefill.
1494        let pos_offset = self
1495            .kv_caches
1496            .get(cache_id)
1497            .and_then(|layers| layers.first())
1498            .map(|c| K::len(c))
1499            .unwrap_or(0);
1500
1501        let h = self.cfg.hidden_size;
1502        let vocab = self.cfg.vocab_size;
1503        let mut ctx = B::new_context();
1504
1505        // Move `residual` out of `scratch` to work around the borrow checker:
1506        // `forward_layer` re-borrows `&mut self` to reach `self.layers` /
1507        // `self.kv_caches`, which would conflict with an outstanding
1508        // `&mut self.scratch.residual`. Use Option::take to move it out
1509        // (no placeholder alloc → no transient cuMemFreeAsync that could
1510        // corrupt stream pool state after graph ops on Blackwell).
1511        let mut residual = self
1512            .scratch
1513            .residual
1514            .take()
1515            .expect("scratch residual missing (previous call didn't restore)");
1516        let embed = self
1517            .embed
1518            .as_ref()
1519            .expect("prefill_internal called on backbone-only model (no embed)");
1520        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1521
1522        let prefill_profile = llama_family_runtime_env().prefill_op_profile;
1523        let prefill_t0 = if prefill_profile {
1524            B::sync(&mut ctx);
1525            Some(std::time::Instant::now())
1526        } else {
1527            None
1528        };
1529
1530        for li in 0..self.cfg.num_layers {
1531            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1532        }
1533
1534        if let Some(t0) = prefill_t0 {
1535            B::sync(&mut ctx);
1536            let total_us = t0.elapsed().as_micros() as u64;
1537            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1538            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1539            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1540            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1541            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1542            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1543            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1544            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1545            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1546            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1547            eprintln!(
1548                "[prefill-profile] tokens={} layers total={} ms",
1549                seq_len,
1550                total_us / 1000
1551            );
1552            let bucket = |label: &str, n: u64, us: u64| {
1553                if n > 0 {
1554                    eprintln!(
1555                        "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
1556                        n,
1557                        us / 1000,
1558                        us / n
1559                    );
1560                }
1561            };
1562            bucket("flash_attn", attn_n, attn_us);
1563            bucket("qk_norm_rope", qkr_n, qkr_us);
1564            bucket("matmuls", mm_n, mm_us);
1565            bucket("norms", norm_n, norm_us);
1566            bucket("other", other_n, other_us);
1567        }
1568
1569        // Take the last token's hidden state: residual[(seq_len-1)*h .. seq_len*h]
1570        B::copy_slice(
1571            &mut ctx,
1572            &residual,
1573            (seq_len - 1) * h,
1574            &mut self.scratch.last_hidden,
1575            0,
1576            h,
1577        );
1578
1579        // Final RMSNorm on the last hidden.
1580        B::rms_norm(
1581            &mut ctx,
1582            &self.scratch.last_hidden,
1583            &self.final_norm_w,
1584            self.cfg.rms_norm_eps,
1585            &mut self.scratch.last_normed,
1586            1,
1587            h,
1588        );
1589
1590        // LM head (m=1 — triggers GEMV on MetalBackend).
1591        let lm_head = self
1592            .lm_head
1593            .as_ref()
1594            .expect("prefill_internal called on backbone-only model (no lm_head)");
1595        lm_head.forward(
1596            &mut ctx,
1597            &self.scratch.last_normed,
1598            &mut self.scratch.logits,
1599            1,
1600        );
1601
1602        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
1603        // buffer's CPU pointer without flushing the command buffer, so the
1604        // GPU must complete all pending work first or we read stale/random
1605        // data. CUDA's to_vec does an internal stream.synchronize, making
1606        // the call redundant there (~50µs/step cost), but correctness on
1607        // Metal requires the explicit flush here.
1608        B::sync(&mut ctx);
1609
1610        // Restore residual into scratch for reuse on the next call.
1611        self.scratch.residual = Some(residual);
1612
1613        B::to_vec(&self.scratch.logits, vocab)
1614    }
1615
1616    /// Decode: process 1 token at position `pos`, return `[vocab_size]` logits.
1617    pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1618        self.ensure_scratch(1);
1619        self.ensure_kv(cache_id);
1620
1621        let h = self.cfg.hidden_size;
1622        let vocab = self.cfg.vocab_size;
1623
1624        // Context creation is cheap (CUDA reuses the process-global stream).
1625        // The captured graph lives in a process-global slot, not on ctx.
1626        let mut ctx = B::new_context();
1627
1628        // Graph capture is opt-in via FERRUM_CUDA_GRAPH=1. Replay is currently
1629        // single-request-only on Blackwell + CUDA 12.8 (see
1630        // docs/phase-e-cuda-status.md). In pure eager mode, we skip the
1631        // per-step device-state memcpy_htod trio entirely.
1632        const GRAPH_WARMUP: usize = 3;
1633        let graph_enabled = llama_family_runtime_env().cuda_graph;
1634
1635        if graph_enabled {
1636            // Refresh device-side dynamic state (token/pos/kv_len) before
1637            // replay — captured graph reads these from device buffers.
1638            B::set_decode_state(&mut ctx, token, pos);
1639
1640            // Fast path: graph replay (if available). Single-item path
1641            // uses key=SINGLE_ITEM_GRAPH_KEY (0) — separate from the
1642            // batched path's m_padded keys.
1643            match B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY) {
1644                Ok(true) => {
1645                    B::sync(&mut ctx);
1646                    return B::to_vec(&self.scratch.logits, vocab);
1647                }
1648                Ok(false) => { /* no graph yet, fall through to eager */ }
1649                Err(_) => { /* backend error or unsupported, eager */ }
1650            }
1651        }
1652
1653        let should_capture =
1654            graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
1655
1656        if should_capture {
1657            B::set_dev_state_mode(&mut ctx, true);
1658            if B::begin_graph_capture(&mut ctx).is_err() {
1659                self.graph_capture_failed = true;
1660                B::set_dev_state_mode(&mut ctx, false);
1661            }
1662        }
1663
1664        // Eager forward (records into graph if capture is active).
1665        // mem::replace needs a placeholder. B::alloc(0) was our choice but
1666        // cuMemAllocFromPoolAsync(stream, 0) can return CUDA_ERROR_INVALID_VALUE
1667        // on Blackwell after graph replay corrupts the pool state. Size-1 is
1668        // always valid and costs 2 bytes of transient VRAM per decode step.
1669        let mut residual = self
1670            .scratch
1671            .residual
1672            .take()
1673            .expect("scratch residual missing (previous call didn't restore)");
1674        let embed = self
1675            .embed
1676            .as_ref()
1677            .expect("decode_internal called on backbone-only model (no embed)");
1678        B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
1679
1680        // Per-layer wall-time profile (env-gated, off by default — adds
1681        // a B::sync between layers which serializes the pipeline). Helps
1682        // localise non-matmul bottlenecks during perf work.
1683        let layer_profile = llama_family_runtime_env().decode_layer_profile;
1684        let mut layer_times = if layer_profile {
1685            Some(Vec::with_capacity(self.cfg.num_layers))
1686        } else {
1687            None
1688        };
1689
1690        for li in 0..self.cfg.num_layers {
1691            if layer_profile {
1692                let t0 = std::time::Instant::now();
1693                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1694                B::sync(&mut ctx);
1695                let elapsed_us = t0.elapsed().as_micros() as u64;
1696                if let Some(v) = layer_times.as_mut() {
1697                    v.push(elapsed_us);
1698                }
1699            } else {
1700                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1701            }
1702        }
1703        if let Some(times) = layer_times.take() {
1704            let sum: u64 = times.iter().sum();
1705            let avg = sum / times.len() as u64;
1706            let mn = *times.iter().min().unwrap_or(&0);
1707            let mx = *times.iter().max().unwrap_or(&0);
1708            eprintln!(
1709                "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
1710                times.len(),
1711                sum / 1000,
1712                avg,
1713                mn,
1714                mx
1715            );
1716            for (i, t) in times.iter().enumerate() {
1717                eprint!("L{i}={}ms ", t / 1000);
1718                if (i + 1) % 6 == 0 {
1719                    eprintln!();
1720                }
1721            }
1722            eprintln!();
1723            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1724            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1725            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1726            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1727            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1728            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1729            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1730            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1731            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1732            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1733            eprintln!(
1734                "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
1735                attn_n,
1736                attn_us / 1000,
1737                if attn_n > 0 { attn_us / attn_n } else { 0 }
1738            );
1739            eprintln!(
1740                "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
1741                qkr_n,
1742                qkr_us / 1000,
1743                if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
1744            );
1745            eprintln!(
1746                "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
1747                mm_n,
1748                mm_us / 1000,
1749                if mm_n > 0 { mm_us / mm_n } else { 0 }
1750            );
1751            eprintln!(
1752                "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
1753                norm_n,
1754                norm_us / 1000,
1755                if norm_n > 0 { norm_us / norm_n } else { 0 }
1756            );
1757            eprintln!(
1758                "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
1759                other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
1760            );
1761        }
1762
1763        B::rms_norm(
1764            &mut ctx,
1765            &residual,
1766            &self.final_norm_w,
1767            self.cfg.rms_norm_eps,
1768            &mut self.scratch.last_normed,
1769            1,
1770            h,
1771        );
1772
1773        let lm_head = self
1774            .lm_head
1775            .as_ref()
1776            .expect("decode_internal called on backbone-only model (no lm_head)");
1777        lm_head.forward(
1778            &mut ctx,
1779            &self.scratch.last_normed,
1780            &mut self.scratch.logits,
1781            1,
1782        );
1783
1784        if should_capture && !self.graph_capture_failed {
1785            if B::end_graph_capture(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
1786                self.graph_capture_failed = true;
1787            } else {
1788                // Stream capture mode RECORDS ops into the graph without
1789                // executing them. scratch.logits still holds the previous
1790                // step's value. Replay the just-captured graph once to
1791                // actually execute and produce this step's logits. Without
1792                // this, the capture step's to_vec returns stale logits,
1793                // yielding a 1-token offset in the generated sequence.
1794                if B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
1795                    self.graph_capture_failed = true;
1796                }
1797            }
1798            B::set_dev_state_mode(&mut ctx, false);
1799        } else {
1800            self.graph_warmup += 1;
1801        }
1802
1803        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
1804        // buffer's CPU pointer without flushing the command buffer, so the
1805        // GPU must complete all pending work first or we read stale/random
1806        // data. CUDA's to_vec does an internal stream.synchronize, making
1807        // the call redundant there (~50µs/step cost), but correctness on
1808        // Metal requires the explicit flush here.
1809        B::sync(&mut ctx);
1810        self.scratch.residual = Some(residual);
1811
1812        B::to_vec(&self.scratch.logits, vocab)
1813    }
1814
1815    /// Prefill with pre-computed embeddings instead of token IDs.
1816    ///
1817    /// Used by models that embed inputs outside the LLM (e.g. Qwen3-TTS
1818    /// mixes text-embedding + codec-embedding before feeding the LM).
1819    /// Skips `final_norm` + `lm_head`; returns the last position's pre-norm
1820    /// hidden state. Caller applies its own output head.
1821    ///
1822    /// `embeds` is row-major `[seq_len * hidden_size]`, f32.
1823    pub fn prefill_from_embeds(
1824        &mut self,
1825        cache_id: &str,
1826        embeds: &[f32],
1827        seq_len: usize,
1828    ) -> Vec<f32> {
1829        let h = self.cfg.hidden_size;
1830        assert_eq!(
1831            embeds.len(),
1832            seq_len * h,
1833            "embeds length {} != seq_len * hidden_size {}",
1834            embeds.len(),
1835            seq_len * h
1836        );
1837        assert!(seq_len > 0, "prefill_from_embeds called with zero length");
1838
1839        self.ensure_scratch(seq_len);
1840        self.ensure_kv(cache_id);
1841
1842        let mut ctx = B::new_context();
1843        let mut residual = self
1844            .scratch
1845            .residual
1846            .take()
1847            .expect("scratch residual missing (previous call didn't restore)");
1848
1849        // Upload embeds → residual[0 .. seq_len*h].
1850        let embed_buf = B::from_slice(embeds);
1851        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1852
1853        for li in 0..self.cfg.num_layers {
1854            self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
1855        }
1856
1857        B::copy_slice(
1858            &mut ctx,
1859            &residual,
1860            (seq_len - 1) * h,
1861            &mut self.scratch.last_hidden,
1862            0,
1863            h,
1864        );
1865        B::sync(&mut ctx);
1866        self.scratch.residual = Some(residual);
1867        B::to_vec(&self.scratch.last_hidden, h)
1868    }
1869
1870    /// Decode with a single pre-computed embedding (shape `[hidden]`).
1871    /// Returns the pre-norm hidden state for the position `pos`. Caller
1872    /// applies final norm + its own output head.
1873    pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
1874        let h = self.cfg.hidden_size;
1875        assert_eq!(
1876            embed.len(),
1877            h,
1878            "embed length {} != hidden_size {}",
1879            embed.len(),
1880            h
1881        );
1882
1883        self.ensure_scratch(1);
1884        self.ensure_kv(cache_id);
1885
1886        let mut ctx = B::new_context();
1887        let mut residual = self
1888            .scratch
1889            .residual
1890            .take()
1891            .expect("scratch residual missing (previous call didn't restore)");
1892
1893        let embed_buf = B::from_slice(embed);
1894        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1895
1896        for li in 0..self.cfg.num_layers {
1897            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1898        }
1899
1900        B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
1901        B::sync(&mut ctx);
1902        self.scratch.residual = Some(residual);
1903        B::to_vec(&self.scratch.last_hidden, h)
1904    }
1905
1906    /// Variant of `prefill_from_embeds` that applies `final_norm` to every
1907    /// position and returns the whole `[seq_len * hidden_size]` vector.
1908    /// Accepts `pos_offset` so callers can continue an existing sequence
1909    /// (e.g. Qwen3-TTS voice-clone: one prefill for the role prefix, a
1910    /// follow-up prefill for the reference-audio ICL block, then
1911    /// autoregressive decoding — all against the same KV cache).
1912    ///
1913    /// Used by TTS where `forward_step` in the candle-based wrapper is
1914    /// expected to return **post-norm all-positions** hidden state so
1915    /// `codec_head` can be applied on candle side.
1916    pub fn prefill_all_post_norm(
1917        &mut self,
1918        cache_id: &str,
1919        embeds: &[f32],
1920        seq_len: usize,
1921        pos_offset: usize,
1922    ) -> Vec<f32> {
1923        let h = self.cfg.hidden_size;
1924        assert_eq!(
1925            embeds.len(),
1926            seq_len * h,
1927            "embeds length {} != seq_len * hidden_size {}",
1928            embeds.len(),
1929            seq_len * h
1930        );
1931        assert!(seq_len > 0);
1932
1933        self.ensure_scratch(seq_len);
1934        self.ensure_kv(cache_id);
1935
1936        let mut ctx = B::new_context();
1937        let mut residual = self
1938            .scratch
1939            .residual
1940            .take()
1941            .expect("scratch residual missing (previous call didn't restore)");
1942
1943        let embed_buf = B::from_slice(embeds);
1944        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1945
1946        for li in 0..self.cfg.num_layers {
1947            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1948        }
1949
1950        // Apply final_norm over all seq_len positions → scratch.norm_out.
1951        B::rms_norm(
1952            &mut ctx,
1953            &residual,
1954            &self.final_norm_w,
1955            self.cfg.rms_norm_eps,
1956            &mut self.scratch.norm_out,
1957            seq_len,
1958            h,
1959        );
1960        B::sync(&mut ctx);
1961        self.scratch.residual = Some(residual);
1962        B::to_vec(&self.scratch.norm_out, seq_len * h)
1963    }
1964
1965    /// Decode-side companion to `prefill_all_post_norm`. Runs a single-token
1966    /// decode step at `pos`, applies `final_norm`, and returns the post-norm
1967    /// hidden state `[hidden_size]`.
1968    pub fn decode_post_norm_from_embed(
1969        &mut self,
1970        cache_id: &str,
1971        embed: &[f32],
1972        pos: u32,
1973    ) -> Vec<f32> {
1974        let h = self.cfg.hidden_size;
1975        assert_eq!(embed.len(), h);
1976
1977        self.ensure_scratch(1);
1978        self.ensure_kv(cache_id);
1979
1980        let mut ctx = B::new_context();
1981        let mut residual = self
1982            .scratch
1983            .residual
1984            .take()
1985            .expect("scratch residual missing (previous call didn't restore)");
1986
1987        let embed_buf = B::from_slice(embed);
1988        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1989
1990        for li in 0..self.cfg.num_layers {
1991            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1992        }
1993
1994        B::rms_norm(
1995            &mut ctx,
1996            &residual,
1997            &self.final_norm_w,
1998            self.cfg.rms_norm_eps,
1999            &mut self.scratch.last_normed,
2000            1,
2001            h,
2002        );
2003        B::sync(&mut ctx);
2004        self.scratch.residual = Some(residual);
2005        B::to_vec(&self.scratch.last_normed, h)
2006    }
2007}
2008
2009// FP16 DecoderOnlyLLM impl — full path with batched + unified-forward overrides.
2010impl<B: MoeLlmBackend> DecoderOnlyLLM for LlamaFamilyModel<B, KvFp16> {
2011    fn config(&self) -> &LlmRuntimeConfig {
2012        &self.runtime_cfg
2013    }
2014
2015    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2016        self.ensure_scratch(max_tokens);
2017        self.ensure_kv(cache_id);
2018
2019        const WARMUP_CACHE: &str = "__ferrum_warmup__";
2020        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2021        if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2022            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2023                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2024                if let Some(c0) = caches.first() {
2025                    if !c0.paged_block_indices.is_empty() {
2026                        alloc.free(&c0.paged_block_indices);
2027                    }
2028                }
2029                for c in caches.iter_mut() {
2030                    c.paged_block_indices.clear();
2031                }
2032            }
2033            self.kv_free_pool.push(caches);
2034        }
2035    }
2036
2037    fn kv_capacity(&self) -> usize {
2038        llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
2039    }
2040
2041    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2042        self.prefill_internal(cache_id, tokens)
2043    }
2044
2045    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2046        self.decode_internal(cache_id, token, pos)
2047    }
2048
2049    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2050        self.decode_batch_internal(batch)
2051    }
2052
2053    fn unified_forward(
2054        &mut self,
2055        items: &[(String, Vec<u32>, usize, bool)],
2056    ) -> std::result::Result<Vec<Option<Vec<f32>>>, ferrum_types::FerrumError> {
2057        if items.is_empty() {
2058            return Ok(Vec::new());
2059        }
2060        if !B::supports_varlen_qkv() {
2061            return Err(ferrum_types::FerrumError::unsupported(
2062                "LlamaFamilyModel::unified_forward: backend lacks varlen \
2063                 QKV kernels. Engine will fall back to per-item dispatch.",
2064            ));
2065        }
2066        self.ensure_kv(&items[0].0);
2067        if self.paged_pools.is_none() {
2068            return Err(ferrum_types::FerrumError::unsupported(
2069                "LlamaFamilyModel::unified_forward: paged KV required; \
2070                 enable via FERRUM_METAL_PAGED_KV=1 (cross-backend env). \
2071                 Engine will fall back to per-item dispatch.",
2072            ));
2073        }
2074        for (cid, _, _, _) in items {
2075            self.ensure_kv(cid);
2076            if !self.kv_caches.contains_key(cid) {
2077                return Err(ferrum_types::FerrumError::resource_exhausted(format!(
2078                    "paged KV pool exhausted for cache_id={cid:?}; back off"
2079                )));
2080            }
2081        }
2082        Ok(self.unified_forward_internal(items))
2083    }
2084
2085    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2086        LlamaFamilyModel::<B, KvFp16>::forward_verify(self, cache_id, tokens)
2087    }
2088
2089    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2090        if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2091            for c in caches.iter_mut() {
2092                if new_len < c.len {
2093                    c.len = new_len;
2094                }
2095            }
2096        }
2097        let mut ctx = B::new_context();
2098        B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2099        self.graph_warmup = 0;
2100        self.graph_capture_failed = false;
2101    }
2102
2103    fn release(&mut self, cache_id: &str) {
2104        let mut ctx = B::new_context();
2105        B::sync(&mut ctx);
2106        self.graph_warmup = 0;
2107        self.graph_capture_failed = false;
2108        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2109            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2110                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2111                if let Some(c0) = caches.first() {
2112                    if !c0.paged_block_indices.is_empty() {
2113                        alloc.free(&c0.paged_block_indices);
2114                    }
2115                }
2116                for c in caches.iter_mut() {
2117                    c.paged_block_indices.clear();
2118                }
2119            }
2120            self.kv_free_pool.push(caches);
2121        }
2122    }
2123
2124    fn reset(&mut self) {
2125        let mut ctx = B::new_context();
2126        B::sync(&mut ctx);
2127        B::reset_all_graphs(&mut ctx);
2128        B::sync(&mut ctx);
2129        self.graph_warmup = 0;
2130        self.graph_capture_failed = false;
2131        self.batched_graph_keys_seen.clear();
2132        self.batched_graph_warmup = 0;
2133        self.batched_graph_failed = false;
2134        self.kv_caches.clear();
2135        self.kv_free_pool.clear();
2136    }
2137}
2138
2139// INT8 DecoderOnlyLLM impl — minimal: no batched / unified-forward overrides
2140// (default trait impl falls back to per-item decode). PR D will add INT8
2141// batched paths once the kernels stabilize.
2142impl<B: MoeLlmBackend + BackendInt8KvOps> DecoderOnlyLLM for LlamaFamilyModel<B, KvInt8> {
2143    fn config(&self) -> &LlmRuntimeConfig {
2144        &self.runtime_cfg
2145    }
2146
2147    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2148        self.ensure_scratch(max_tokens);
2149        self.ensure_kv(cache_id);
2150
2151        const WARMUP_CACHE: &str = "__ferrum_warmup__";
2152        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2153        if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2154            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2155                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2156                if let Some(c0) = caches.first() {
2157                    if !c0.paged_block_indices.is_empty() {
2158                        alloc.free(&c0.paged_block_indices);
2159                    }
2160                }
2161                for c in caches.iter_mut() {
2162                    c.paged_block_indices.clear();
2163                }
2164            }
2165            self.kv_free_pool.push(caches);
2166        }
2167    }
2168
2169    fn kv_capacity(&self) -> usize {
2170        llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
2171    }
2172
2173    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2174        self.prefill_internal(cache_id, tokens)
2175    }
2176
2177    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2178        self.decode_internal(cache_id, token, pos)
2179    }
2180
2181    // decode_batch + unified_forward use trait defaults (per-item fallback).
2182
2183    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2184        LlamaFamilyModel::<B, KvInt8>::forward_verify(self, cache_id, tokens)
2185    }
2186
2187    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2188        if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2189            for c in caches.iter_mut() {
2190                if new_len < c.len {
2191                    c.len = new_len;
2192                }
2193            }
2194        }
2195        let mut ctx = B::new_context();
2196        B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2197        self.graph_warmup = 0;
2198        self.graph_capture_failed = false;
2199    }
2200
2201    fn release(&mut self, cache_id: &str) {
2202        let mut ctx = B::new_context();
2203        B::sync(&mut ctx);
2204        self.graph_warmup = 0;
2205        self.graph_capture_failed = false;
2206        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2207            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2208                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2209                if let Some(c0) = caches.first() {
2210                    if !c0.paged_block_indices.is_empty() {
2211                        alloc.free(&c0.paged_block_indices);
2212                    }
2213                }
2214                for c in caches.iter_mut() {
2215                    c.paged_block_indices.clear();
2216                }
2217            }
2218            self.kv_free_pool.push(caches);
2219        }
2220    }
2221
2222    fn reset(&mut self) {
2223        let mut ctx = B::new_context();
2224        B::sync(&mut ctx);
2225        B::reset_all_graphs(&mut ctx);
2226        B::sync(&mut ctx);
2227        self.graph_warmup = 0;
2228        self.graph_capture_failed = false;
2229        self.kv_caches.clear();
2230        self.kv_free_pool.clear();
2231    }
2232}
2233
2234fn build_rope_cache<B: QuantLlmBackend + BackendMoeFused>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
2235    let hd = cfg.head_dim;
2236    let half = hd / 2;
2237    let max = cfg.max_seq_len;
2238    let mut cos = vec![0.0f32; max * half];
2239    let mut sin = vec![0.0f32; max * half];
2240    for pos in 0..max {
2241        for i in 0..half {
2242            let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
2243            let angle = pos as f64 * freq;
2244            cos[pos * half + i] = angle.cos() as f32;
2245            sin[pos * half + i] = angle.sin() as f32;
2246        }
2247    }
2248    RopeCache {
2249        cos: B::from_slice(&cos),
2250        sin: B::from_slice(&sin),
2251    }
2252}
2253
#[cfg(test)]
mod tests {
    use super::{LlamaFamilyRuntimeEnv, DEFAULT_KV_CAPACITY};

    /// Well-formed values are picked up.
    ///
    /// The four profiling / graph switches asserted below come out `true` even
    /// for "0", "" and "false", i.e. setting the variable at all enables them.
    #[test]
    fn llama_family_runtime_env_parses_startup_knobs() {
        let vars = [
            ("FERRUM_KV_CAPACITY", "4096"),
            ("FERRUM_METAL_PAGED_KV", "0"),
            ("FERRUM_PAGED_MAX_SEQS", "64"),
            ("FERRUM_DECODE_OP_PROFILE", "0"),
            ("FERRUM_PREFILL_OP_PROFILE", ""),
            ("FERRUM_CUDA_GRAPH", ""),
            ("FERRUM_DECODE_LAYER_PROFILE", "false"),
        ];
        let runtime = LlamaFamilyRuntimeEnv::from_env_vars(vars);

        // Numeric and tri-state knobs.
        assert_eq!(runtime.kv_capacity, Some(4096));
        assert_eq!(runtime.paged_max_seqs, 64);
        assert_eq!(runtime.metal_paged_kv, Some(false));

        // Presence-only switches.
        assert!(runtime.decode_op_profile);
        assert!(runtime.prefill_op_profile);
        assert!(runtime.cuda_graph);
        assert!(runtime.decode_layer_profile);

        // An explicit capacity above the model's length yields the model's length.
        assert_eq!(runtime.kv_capacity_for_model(2048), 2048);
    }

    /// Unparseable numeric values fall back to defaults instead of erroring.
    #[test]
    fn llama_family_runtime_env_uses_defaults_for_invalid_values() {
        let vars = [
            ("FERRUM_KV_CAPACITY", "bad"),
            ("FERRUM_PAGED_MAX_SEQS", "bad"),
            ("FERRUM_METAL_PAGED_KV", "1"),
        ];
        let runtime = LlamaFamilyRuntimeEnv::from_env_vars(vars);

        assert_eq!(runtime.kv_capacity, None);
        assert_eq!(runtime.paged_max_seqs, 32);
        assert_eq!(runtime.metal_paged_kv, Some(true));

        // With no usable explicit capacity, a large model is capped at the default.
        let oversized_model_len = DEFAULT_KV_CAPACITY * 2;
        assert_eq!(
            runtime.kv_capacity_for_model(oversized_model_len),
            DEFAULT_KV_CAPACITY
        );
    }
}