// Source: ferrum_models/models/llama_family.rs

//! Llama-family decoder model as explicit code.
//!
//! Covers all "standard Llama-style GQA + SwiGLU + RoPE" decoders:
//!   - Llama / Llama-2 / Llama-3  (no QK-norm)
//!   - Qwen2 / Qwen2.5            (no QK-norm, structurally Llama)
//!   - Qwen3                      (QK-norm per head, larger rope_theta)
//!   - Mistral                    (sliding-window attention — not yet
//!                                 supported on the forward path; will
//!                                 require an `AttnConfig.sliding_window`
//!                                 field + shader support in Phase D)
//!
//! Variant differences are controlled by `LlamaFamilyConfig::has_qk_norm`
//! and RoPE theta. Weight loading accepts both fused (`qkv_proj`,
//! `gate_up_proj`) and split (`q_proj`+`k_proj`+`v_proj`,
//! `gate_proj`+`up_proj`) projection layouts — the loader (e.g.
//! `CandleVarBuilderLoader`) fuses split weights on load so model code
//! sees a uniform `qkv_proj` / `gate_up_proj` Linear.
18
19use std::collections::HashMap;
20use std::sync::{atomic::AtomicU64, OnceLock};
21
22use ferrum_interfaces::kv_dtype::{KvFp16, KvInt8};
23use ferrum_kernels::backend::{
24    Backend, BackendGraph, BackendInt8KvOps, BackendMoeFused, BackendPagedKv, BackendQuantGguf,
25    BackendQuantMarlin, KvCache, KvLayer, LlmBackend, MoeLlmBackend, QuantLlmBackend,
26    MAX_LAYERS_FOR_GRAPH,
27};
28
/// Graph cache key for the single-item decode path (`decode_internal`).
/// Distinct from any `m_padded`-based key used by the batched path.
// NOTE(review): distinctness relies on the batched path never using key 0
// (e.g. `m_padded >= 1`) — confirm where the batched keys are produced.
pub(crate) const SINGLE_ITEM_GRAPH_KEY: u64 = 0;
32
/// Diag counters for the batched graph dispatcher (replay vs eager).
pub(crate) static BATCHED_GRAPH_REPLAY_COUNT: AtomicU64 = AtomicU64::new(0);
pub(crate) static BATCHED_GRAPH_EAGER_COUNT: AtomicU64 = AtomicU64::new(0);

// Process-wide profiling accumulators, paired per op category
// (attention, QK-norm/RoPE, matmul, norm, other). All start at zero.
// NOTE(review): judging by the names, each `*_TIME_US` presumably
// accumulates elapsed microseconds and each `*_CALLS` the matching
// invocation count; the code that updates them is not visible in this
// part of the file — confirm at the call sites.
pub(crate) static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_CALLS: AtomicU64 = AtomicU64::new(0);
47use ferrum_quantization::{Linear, WeightLoader};
48use ferrum_types::Result;
49
50use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
51
/// KV capacity used when `FERRUM_KV_CAPACITY` is not set (or did not parse).
/// `LlamaFamilyRuntimeEnv::kv_capacity_for_model` clamps it to the model's
/// own maximum.
const DEFAULT_KV_CAPACITY: usize = 512;
53
/// Runtime knobs read from `FERRUM_*` environment variables.
///
/// Populated by `from_env_vars`; the process-wide instance is cached by
/// `llama_family_runtime_env`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct LlamaFamilyRuntimeEnv {
    /// `FERRUM_KV_CAPACITY` parsed as `usize`; `None` when the variable is
    /// absent or its value does not parse.
    kv_capacity: Option<usize>,
    /// `FERRUM_METAL_PAGED_KV`: `Some(false)` only for the exact value `"0"`,
    /// `Some(true)` for any other value, `None` when the variable is absent.
    metal_paged_kv: Option<bool>,
    /// `FERRUM_PAGED_MAX_SEQS` parsed as `usize`; stays at the default of 32
    /// when absent or unparsable.
    paged_max_seqs: usize,
    /// `true` when `FERRUM_DECODE_OP_PROFILE` is present (any value).
    decode_op_profile: bool,
    /// `true` when `FERRUM_PREFILL_OP_PROFILE` is present (any value).
    prefill_op_profile: bool,
    /// `true` when `FERRUM_CUDA_GRAPH` is present (any value).
    cuda_graph: bool,
    /// `true` when `FERRUM_DECODE_LAYER_PROFILE` is present (any value).
    decode_layer_profile: bool,
}
64
65impl LlamaFamilyRuntimeEnv {
66    fn from_env() -> Self {
67        Self::from_env_vars(std::env::vars())
68    }
69
70    fn from_env_vars<I, K, V>(vars: I) -> Self
71    where
72        I: IntoIterator<Item = (K, V)>,
73        K: AsRef<str>,
74        V: AsRef<str>,
75    {
76        let mut config = Self {
77            kv_capacity: None,
78            metal_paged_kv: None,
79            paged_max_seqs: 32,
80            decode_op_profile: false,
81            prefill_op_profile: false,
82            cuda_graph: false,
83            decode_layer_profile: false,
84        };
85        for (name, value) in vars {
86            let value = value.as_ref();
87            match name.as_ref() {
88                "FERRUM_KV_CAPACITY" => config.kv_capacity = value.parse::<usize>().ok(),
89                "FERRUM_METAL_PAGED_KV" => config.metal_paged_kv = Some(value != "0"),
90                "FERRUM_PAGED_MAX_SEQS" => {
91                    if let Ok(max_seqs) = value.parse::<usize>() {
92                        config.paged_max_seqs = max_seqs;
93                    }
94                }
95                "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
96                "FERRUM_PREFILL_OP_PROFILE" => config.prefill_op_profile = true,
97                "FERRUM_CUDA_GRAPH" => config.cuda_graph = true,
98                "FERRUM_DECODE_LAYER_PROFILE" => config.decode_layer_profile = true,
99                _ => {}
100            }
101        }
102        config
103    }
104
105    fn kv_capacity_for_model(&self, model_max: usize) -> usize {
106        self.kv_capacity
107            .map(|cap| cap.min(model_max))
108            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
109    }
110
111    fn paged_kv_enabled<B: BackendPagedKv>(&self) -> bool {
112        self.metal_paged_kv
113            .unwrap_or_else(|| B::supports_paged_kv())
114    }
115}
116
117fn llama_family_runtime_env() -> &'static LlamaFamilyRuntimeEnv {
118    static CONFIG: OnceLock<LlamaFamilyRuntimeEnv> = OnceLock::new();
119    CONFIG.get_or_init(LlamaFamilyRuntimeEnv::from_env)
120}
121
/// RoPE frequency-scaling schemes understood by the Llama-family loader.
#[derive(Clone, Debug, PartialEq)]
pub enum RopeScalingConfig {
    /// Meta Llama 3.1/3.2/3.3 long-context RoPE scaling.
    Llama3 {
        factor: f64,
        low_freq_factor: f64,
        high_freq_factor: f64,
        original_max_position_embeddings: f64,
    },
}

impl RopeScalingConfig {
    /// `Llama3` scaling with factor 8, low/high frequency factors 1 and 4,
    /// and an original context length of 8192 positions.
    pub fn llama3_default() -> Self {
        let (factor, low_freq_factor, high_freq_factor) = (8.0, 1.0, 4.0);
        let original_max_position_embeddings = 8192.0;
        Self::Llama3 {
            factor,
            low_freq_factor,
            high_freq_factor,
            original_max_position_embeddings,
        }
    }
}
143
/// Full Llama-family architecture config (everything the model code needs,
/// not just the engine-facing subset in `LlmRuntimeConfig`).
///
/// Shared by the Qwen3, Llama, Qwen2 and Mistral constructors
/// (`qwen3_from_def`, `llama_from_def`, `qwen2_from_def`, `mistral_from_def`).
#[derive(Clone, Debug, PartialEq)]
pub struct LlamaFamilyConfig {
    pub hidden_size: usize,
    pub intermediate_size: usize,
    /// Number of attention (query) heads.
    pub num_heads: usize,
    /// Number of key/value heads; falls back to `num_heads` when the model
    /// definition does not specify it.
    pub num_kv_heads: usize,
    /// Per-head dimension: the definition's `head_dim` extra parameter when
    /// present, otherwise `hidden_size / num_heads`.
    pub head_dim: usize,
    pub num_layers: usize,
    pub vocab_size: usize,
    /// Taken from the definition's `max_position_embeddings`.
    pub max_seq_len: usize,
    pub rms_norm_eps: f32,
    /// RoPE base; the checkpoint value when set, otherwise a per-variant
    /// default chosen by the `*_from_def` constructor.
    pub rope_theta: f64,
    /// Optional long-context RoPE scaling (only the `llama3` type is parsed).
    pub rope_scaling: Option<RopeScalingConfig>,
    /// GGUF LLaMA stores Q/K in the llama.cpp interleaved RoPE layout.
    /// HF safetensors Qwen/LLaMA definitions keep the default half-split
    /// GPT-NeoX layout used by existing Ferrum kernels.
    pub rope_interleaved: bool,
    /// Whether the checkpoint has `q_norm` / `k_norm` per layer. All known
    /// Qwen3 checkpoints do; some derivatives may strip them.
    pub has_qk_norm: bool,
    /// Sliding-window attention size. `0` disables (full causal).
    /// Mistral v0.1 sets 4096; Mistral v0.2+ removed the limitation (0).
    pub sliding_window: usize,
}
170
171impl LlamaFamilyConfig {
172    pub fn to_runtime(&self) -> LlmRuntimeConfig {
173        LlmRuntimeConfig {
174            hidden_size: self.hidden_size,
175            num_layers: self.num_layers,
176            num_kv_heads: self.num_kv_heads,
177            head_dim: self.head_dim,
178            vocab_size: self.vocab_size,
179            max_seq_len: self.max_seq_len,
180        }
181    }
182
183    /// Build config from a `ModelDefinition`, shared field extraction.
184    /// Variant-specific constructors below set `has_qk_norm` and fall back
185    /// to different `rope_theta` defaults when the checkpoint doesn't set one.
186    fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
187        let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
188        let head_dim = def
189            .extra_params
190            .get("head_dim")
191            .and_then(|v| v.as_u64())
192            .map(|v| v as usize)
193            .unwrap_or(def.hidden_size / def.num_attention_heads);
194        // Mistral / Gemma: "sliding_window" may be null (v0.2+) or a positive
195        // integer (v0.1). Non-null value passes through; missing/null → 0.
196        let sliding_window = def
197            .extra_params
198            .get("sliding_window")
199            .and_then(|v| v.as_u64())
200            .map(|v| v as usize)
201            .unwrap_or(0);
202
203        LlamaFamilyConfigBase {
204            hidden_size: def.hidden_size,
205            intermediate_size: def.intermediate_size,
206            num_heads: def.num_attention_heads,
207            num_kv_heads,
208            head_dim,
209            num_layers: def.num_hidden_layers,
210            vocab_size: def.vocab_size,
211            max_seq_len: def.max_position_embeddings,
212            rms_norm_eps: def.norm_eps as f32,
213            rope_theta_opt: def.rope_theta,
214            rope_scaling: rope_scaling_from_model_def(def),
215            sliding_window,
216        }
217    }
218
219    fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
220        Self {
221            hidden_size: b.hidden_size,
222            intermediate_size: b.intermediate_size,
223            num_heads: b.num_heads,
224            num_kv_heads: b.num_kv_heads,
225            head_dim: b.head_dim,
226            num_layers: b.num_layers,
227            vocab_size: b.vocab_size,
228            max_seq_len: b.max_seq_len,
229            rms_norm_eps: b.rms_norm_eps,
230            rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
231            rope_scaling: b.rope_scaling,
232            rope_interleaved: false,
233            has_qk_norm,
234            sliding_window: b.sliding_window,
235        }
236    }
237
238    /// Qwen3: QK-norm on, rope_theta default 1e6.
239    pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
240        Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
241    }
242
243    /// Llama / Llama-2 / Llama-3: no QK-norm; rope_theta varies by version
244    /// (10k for Llama-2, 500k for Llama-3.1+) — use the checkpoint value or
245    /// fall back to the most common modern value.
246    pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
247        Self::from_base(Self::from_def_base(def), 500_000.0, false)
248    }
249
250    /// Qwen2 / Qwen2.5: structurally Llama; no QK-norm; rope_theta default 1e6.
251    pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
252        Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
253    }
254
255    /// Mistral: no QK-norm; `rope_theta` commonly 10_000 (v0.1 / v0.2).
256    /// Picks up `sliding_window` from the checkpoint's config.json
257    /// (Mistral v0.1: 4096; Mistral v0.2+: 0 / null).
258    pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
259        Self::from_base(Self::from_def_base(def), 10_000.0, false)
260    }
261}
262
/// Variant-independent fields extracted from a `ModelDefinition` by
/// `LlamaFamilyConfig::from_def_base`, before `from_base` applies the
/// per-variant `rope_theta` default and `has_qk_norm` flag.
struct LlamaFamilyConfigBase {
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    num_layers: usize,
    vocab_size: usize,
    max_seq_len: usize,
    rms_norm_eps: f32,
    /// Checkpoint-provided `rope_theta`; `None` means "use the variant default".
    rope_theta_opt: Option<f64>,
    rope_scaling: Option<RopeScalingConfig>,
    /// `0` when the definition has no (or a null) `sliding_window`.
    sliding_window: usize,
}
277
278fn rope_scaling_from_model_def(
279    def: &crate::definition::ModelDefinition,
280) -> Option<RopeScalingConfig> {
281    let value = def.extra_params.get("rope_scaling")?;
282    let obj = value.as_object()?;
283    let rope_type = obj
284        .get("rope_type")
285        .or_else(|| obj.get("type"))
286        .and_then(|v| v.as_str())?;
287    if rope_type != "llama3" {
288        return None;
289    }
290    let factor = json_f64(obj.get("factor"))?;
291    let low_freq_factor = json_f64(obj.get("low_freq_factor"))?;
292    let high_freq_factor = json_f64(obj.get("high_freq_factor"))?;
293    let original_max_position_embeddings = json_f64(obj.get("original_max_position_embeddings"))
294        .or_else(|| {
295            def.extra_params
296                .get("original_max_position_embeddings")
297                .and_then(|v| json_f64(Some(v)))
298        })
299        .unwrap_or(8192.0);
300    if factor <= 0.0
301        || low_freq_factor <= 0.0
302        || high_freq_factor <= low_freq_factor
303        || original_max_position_embeddings <= 0.0
304    {
305        return None;
306    }
307    Some(RopeScalingConfig::Llama3 {
308        factor,
309        low_freq_factor,
310        high_freq_factor,
311        original_max_position_embeddings,
312    })
313}
314
315fn json_f64(value: Option<&serde_json::Value>) -> Option<f64> {
316    match value? {
317        serde_json::Value::Number(n) => n.as_f64(),
318        _ => None,
319    }
320}
321
/// Per-layer weights. `Box<dyn Linear<B>>` means each projection can be
/// Dense / GPTQ / AWQ / GGUF without the surrounding code caring.
pub struct LlamaFamilyLayer<B: QuantLlmBackend + BackendMoeFused> {
    /// Weight of the norm applied before attention.
    // NOTE(review): presumably an RMSNorm weight of `[hidden_size]` — confirm in the loader.
    pub input_ln_w: B::Buffer,
    /// Fused Q/K/V projection.
    pub qkv_proj: Box<dyn Linear<B>>,
    /// QK-norm weight per head: `[head_dim]`. Optional for non-Qwen3 derivatives.
    pub q_norm_w: Option<B::Buffer>,
    /// Key-side counterpart of `q_norm_w`.
    pub k_norm_w: Option<B::Buffer>,
    /// Attention output projection.
    pub o_proj: Box<dyn Linear<B>>,
    /// Weight of the norm applied after attention, before the MLP.
    pub post_ln_w: B::Buffer,
    /// Fused gate + up projection of the MLP.
    pub gate_up_proj: Box<dyn Linear<B>>,
    /// MLP down projection.
    pub down_proj: Box<dyn Linear<B>>,
}
335
/// Precomputed RoPE cos/sin tables (shape `[max_seq, head_dim / 2]` each).
pub struct RopeCache<B: QuantLlmBackend + BackendMoeFused> {
    /// Cosine table.
    pub cos: B::Buffer,
    /// Sine table.
    pub sin: B::Buffer,
}
341
/// Reusable per-layer scratch buffers sized for `max_tokens` tokens of a
/// single forward pass (prefill or decode step).
///
/// Sized lazily on first use so tiny decode steps don't pay for prefill-sized
/// buffers. Grows monotonically when a larger prefill arrives.
///
/// Element counts quoted below come from `alloc`, with
/// `q_dim = num_heads * head_dim` and `kv_dim = num_kv_heads * head_dim`.
pub struct LlamaFamilyScratch<B: QuantLlmBackend + BackendMoeFused> {
    /// Residual stream — wrapped in Option so decode_internal can
    /// `.take()` it without needing an alloc placeholder.
    ///
    /// Why this matters for graph capture: the old pattern was
    /// `mem::replace(&mut scratch.residual, B::alloc(1))` which creates a
    /// 1-element buffer at every decode step. When graph capture is on,
    /// that alloc-during-capture + drop-after-capture pair surfaces as
    /// cuMemFreeAsync(INVALID_VALUE) because the free tries to release a
    /// pointer the captured graph may still reference. Option::take leaves
    /// None and moves the real buffer into a local — no spurious alloc.
    pub residual: Option<B::Buffer>,
    /// Norm output, `max_tokens * hidden_size` elements.
    pub norm_out: B::Buffer,
    /// Fused QKV projection output, `max_tokens * (q_dim + 2 * kv_dim)` elements.
    pub qkv_out: B::Buffer,
    // ── Per-item scratch for batched decode path ──────────────────────
    // decode_batch_internal runs tokens=M batched ops for the GEMM-heavy
    // half (norm, qkv_proj, split_qkv, o_proj, post_norm, gate_up, silu,
    // down, residual_add) but must loop per-item for rope + KV append +
    // attention (each item has its own KV cache at a different kv_len).
    // These single-item buffers hold item i's slice during that loop.
    /// Item-scope q_buf slice, sized `q_dim`.
    pub q_single: B::Buffer,
    /// Item-scope k_buf slice, sized `kv_dim`.
    pub k_single: B::Buffer,
    /// Item-scope v_buf slice, sized `kv_dim`.
    pub v_single: B::Buffer,
    /// Single-item counterpart of `q_head_major`, sized `q_dim`.
    pub q_head_major_single: B::Buffer,
    /// Single-item counterpart of `k_head_major`, sized `kv_dim`.
    pub k_head_major_single: B::Buffer,
    /// Single-item counterpart of `v_head_major`, sized `kv_dim`.
    pub v_head_major_single: B::Buffer,
    /// Single-item counterpart of `attn_head_major_out`, sized `q_dim`.
    pub attn_head_major_single: B::Buffer,
    /// Single-item counterpart of `attn_flat`, sized `q_dim`.
    pub attn_flat_single: B::Buffer,
    /// Batched logits output, sized `max_tokens * vocab_size`. Used only
    /// in decode_batch; prefill/single-decode use the regular `logits`.
    pub batch_logits: B::Buffer,
    /// Token-major Q/K/V right after `split_qkv`. Stride: heads * hd per row.
    pub q_buf: B::Buffer,
    pub k_buf: B::Buffer,
    pub v_buf: B::Buffer,
    /// Head-major Q produced by `qk_norm_rope` — fed into `flash_attention`.
    pub q_head_major: B::Buffer,
    /// Head-major K/V staging — produced by `qk_norm_rope`, consumed by
    /// `kv_cache_append_head_major` (no reuse after append).
    pub k_head_major: B::Buffer,
    pub v_head_major: B::Buffer,
    /// Head-major attention output from `flash_attention`.
    pub attn_head_major_out: B::Buffer,
    /// Token-major attention output after `transpose_head_to_token`.
    pub attn_flat: B::Buffer,
    /// Attention output projection result, `max_tokens * hidden_size` elements.
    pub o_proj_out: B::Buffer,
    /// Fused gate/up projection result, `max_tokens * 2 * intermediate_size` elements.
    pub gate_up_out: B::Buffer,
    /// Activation output, `max_tokens * intermediate_size` elements.
    pub silu_out: B::Buffer,
    /// MLP output, `max_tokens * hidden_size` elements.
    pub mlp_out: B::Buffer,
    /// Paged batched dispatch scratch (Phase 4b). Sized for
    /// `FERRUM_PAGED_MAX_SEQS × q_dim` so multi-seq decode can fan
    /// in M items' Q into a single buffer for one batched
    /// `paged_decode_attention(num_seqs=M)` call. `None` when paged
    /// mode is off.
    pub paged_batch_q: Option<B::Buffer>,
    /// Output counterpart of `paged_batch_q`; allocated with the same size
    /// by `enable_paged_batch`.
    pub paged_batch_o: Option<B::Buffer>,
    /// Stacked per-seq block tables for batched paged dispatch.
    /// Layout: `[max_M, max_blocks_per_seq]` u32. Written
    /// host-side per decode_batch step.
    pub paged_batch_block_tables: Option<B::Buffer>,
    /// Stacked per-seq context lengths for batched paged dispatch
    /// (`[max_M]` u32).
    pub paged_batch_context_lens: Option<B::Buffer>,
    /// `max_blocks_per_seq` value baked into the stacked block_tables
    /// stride. Set when `paged_batch_block_tables` is allocated.
    pub paged_max_blocks_per_seq: usize,
    /// Engine-side max concurrent sequences (= `FERRUM_PAGED_MAX_SEQS`).
    /// Pinned at the first `enable_paged_batch` so the unified-forward
    /// scratch sizes (`unified_cu_seqlens_q`, `unified_pos_offsets`,
    /// `unified_block_tables`, `unified_packed_*`) are big enough for
    /// any subsequent batch up to that bound.
    pub paged_max_seqs: usize,
    /// Per-item RoPE positions for the batched-decode path (`[max_M]`,
    /// allocated as `Dtype::U32`). Written host-side once per
    /// batched-decode step from each request's `pos` field, read by the
    /// batched `qk_norm_rope_batched_per_item` CUDA kernel.
    pub batch_positions: B::Buffer,
    /// Per-item input token ids for the batched-decode path (`[max_M]`
    /// u32). Written once per call before forward; read by the
    /// graph-capture-friendly `embedding_lookup_batched_dyn` variant.
    pub batch_tokens: B::Buffer,
    /// Per-item KV-cache length BEFORE this step's kv_append
    /// (`[max_M]` u32). Used by `kv_cache_append_batched_per_cache_dyn`
    /// to write at the right slot for graph replay.
    pub batch_kv_lens_pre: B::Buffer,
    /// Per-item KV-cache length AFTER this step's kv_append
    /// (`[max_M]` u32 = pre + 1). Used by
    /// `flash_attention_batched_per_cache_dyn` for the attention
    /// reduce window length.
    pub batch_kv_lens_post: B::Buffer,
    /// Output buffers for the batched per-item qk_norm_rope kernel.
    /// Same shape as q_buf / k_buf / v_buf — separate so the kernel
    /// API can take `&input` and `&mut output` without aliasing.
    pub q_normed_batched: B::Buffer,
    pub k_normed_batched: B::Buffer,
    pub v_normed_batched: B::Buffer,

    // ── Unified mixed-batch scratch (chunked-prefill path; Step 5b) ─────
    // Buffers sized for `M_total = sum(items[i].q_tokens.len())`. Grown
    // on demand by `ensure_unified_scratch(M_total_max)`. Used only by
    // the new `unified_forward_internal`; legacy `forward_layer_batched_decode`
    // continues to use the per-item-stride scratch above.
    pub unified_capacity: usize, // current allocated M_total slots
    pub unified_residual: Option<B::Buffer>,
    pub unified_norm_out: Option<B::Buffer>,
    pub unified_qkv_out: Option<B::Buffer>,
    pub unified_packed_q: Option<B::Buffer>,
    pub unified_attn_out: Option<B::Buffer>,
    pub unified_o_proj_out: Option<B::Buffer>,
    pub unified_gate_up_out: Option<B::Buffer>,
    pub unified_silu_out: Option<B::Buffer>,
    pub unified_mlp_out: Option<B::Buffer>,
    /// Per-item index buffers (allocated as `Dtype::U32` by
    /// `ensure_unified_scratch`): cu_seqlens_q is length `max_seqs+1`,
    /// pos_offsets is `max_seqs`, block_tables is
    /// `max_seqs * max_blocks_per_seq`. Sized once at first use to
    /// match `paged_batch_*` capacity since they share `max_seqs`.
    pub unified_cu_seqlens_q: Option<B::Buffer>,
    pub unified_pos_offsets: Option<B::Buffer>,
    pub unified_block_tables: Option<B::Buffer>,
    /// Packed last-token hidden states for is_final_chunk items
    /// (`[num_sampled, h]`). Used as input to lm_head.
    pub unified_packed_normed: Option<B::Buffer>,
    /// Packed logits output (`[num_sampled, vocab]`).
    pub unified_packed_logits: Option<B::Buffer>,
    /// Last token's hidden state (`[h]`). For prefill this is populated via
    /// `copy_slice(residual, (seq_len-1)*h, ..)`; for decode `residual` already
    /// holds only 1 row so `last_hidden` is unused on that path.
    pub last_hidden: B::Buffer,
    /// Final-norm output for the last token (`[h]`).
    pub last_normed: B::Buffer,
    /// lm_head logits (`[vocab]`).
    pub logits: B::Buffer,
    /// The max tokens-per-step this scratch has been sized for.
    pub max_tokens: usize,
}
483
484impl<B: QuantLlmBackend + BackendMoeFused> LlamaFamilyScratch<B> {
485    fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
486        let h = cfg.hidden_size;
487        let im = cfg.intermediate_size;
488        let q_dim = cfg.num_heads * cfg.head_dim;
489        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
490        let qkv_dim = q_dim + 2 * kv_dim;
491        let t = max_tokens;
492        Self {
493            residual: Some(B::alloc(t * h)),
494            norm_out: B::alloc(t * h),
495            qkv_out: B::alloc(t * qkv_dim),
496            q_buf: B::alloc(t * q_dim),
497            k_buf: B::alloc(t * kv_dim),
498            v_buf: B::alloc(t * kv_dim),
499            q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
500            k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
501            v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
502            attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
503            attn_flat: B::alloc(t * q_dim),
504            o_proj_out: B::alloc(t * h),
505            gate_up_out: B::alloc(t * 2 * im),
506            silu_out: B::alloc(t * im),
507            mlp_out: B::alloc(t * h),
508            last_hidden: B::alloc(h),
509            last_normed: B::alloc(h),
510            logits: B::alloc(cfg.vocab_size),
511            q_single: B::alloc(q_dim),
512            k_single: B::alloc(kv_dim),
513            v_single: B::alloc(kv_dim),
514            q_head_major_single: B::alloc(q_dim),
515            k_head_major_single: B::alloc(kv_dim),
516            v_head_major_single: B::alloc(kv_dim),
517            attn_head_major_single: B::alloc(q_dim),
518            attn_flat_single: B::alloc(q_dim),
519            batch_logits: B::alloc(t * cfg.vocab_size),
520            // Paged batched dispatch scratch. None until `enable_paged_batch`
521            // is called from `ensure_kv` once the model knows max_seqs +
522            // max_blocks_per_seq. This avoids paying the alloc cost when
523            // paged mode is off.
524            paged_batch_q: None,
525            paged_batch_o: None,
526            paged_batch_block_tables: None,
527            paged_batch_context_lens: None,
528            paged_max_blocks_per_seq: 0,
529            paged_max_seqs: 0,
530            batch_positions: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
531            batch_tokens: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
532            batch_kv_lens_pre: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
533            batch_kv_lens_post: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
534            q_normed_batched: B::alloc(t * q_dim),
535            k_normed_batched: B::alloc(t * kv_dim),
536            v_normed_batched: B::alloc(t * kv_dim),
537            unified_capacity: 0,
538            unified_residual: None,
539            unified_norm_out: None,
540            unified_qkv_out: None,
541            unified_packed_q: None,
542            unified_attn_out: None,
543            unified_o_proj_out: None,
544            unified_gate_up_out: None,
545            unified_silu_out: None,
546            unified_mlp_out: None,
547            unified_cu_seqlens_q: None,
548            unified_pos_offsets: None,
549            unified_block_tables: None,
550            unified_packed_normed: None,
551            unified_packed_logits: None,
552            max_tokens: t,
553        }
554    }
555
556    /// Grow unified-path scratch buffers to accommodate `m_total` query
557    /// tokens. Called lazily from `unified_forward_internal` so single-
558    /// path workloads (no chunked prefill) don't pay the alloc cost.
559    pub(crate) fn ensure_unified_scratch(
560        &mut self,
561        cfg: &LlamaFamilyConfig,
562        m_total: usize,
563        max_seqs: usize,
564        max_blocks_per_seq: usize,
565    ) {
566        if m_total <= self.unified_capacity
567            && self.unified_residual.is_some()
568            && self.unified_cu_seqlens_q.is_some()
569        {
570            return;
571        }
572        let cap = m_total.max(self.unified_capacity).max(1);
573        let h = cfg.hidden_size;
574        let q_dim = cfg.num_heads * cfg.head_dim;
575        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
576        let qkv_dim = q_dim + 2 * kv_dim;
577        let im = cfg.intermediate_size;
578        let v = cfg.vocab_size;
579        self.unified_residual = Some(B::alloc(cap * h));
580        self.unified_norm_out = Some(B::alloc(cap * h));
581        self.unified_qkv_out = Some(B::alloc(cap * qkv_dim));
582        self.unified_packed_q = Some(B::alloc(cap * q_dim));
583        self.unified_attn_out = Some(B::alloc(cap * q_dim));
584        self.unified_o_proj_out = Some(B::alloc(cap * h));
585        self.unified_gate_up_out = Some(B::alloc(cap * 2 * im));
586        self.unified_silu_out = Some(B::alloc(cap * im));
587        self.unified_mlp_out = Some(B::alloc(cap * h));
588        if self.unified_cu_seqlens_q.is_none() {
589            self.unified_cu_seqlens_q = Some(B::alloc_typed(
590                ferrum_kernels::backend::Dtype::U32,
591                max_seqs + 1,
592            ));
593            self.unified_pos_offsets = Some(B::alloc_typed(
594                ferrum_kernels::backend::Dtype::U32,
595                max_seqs,
596            ));
597            self.unified_block_tables = Some(B::alloc_typed(
598                ferrum_kernels::backend::Dtype::U32,
599                max_seqs * max_blocks_per_seq,
600            ));
601            self.unified_packed_normed = Some(B::alloc(max_seqs * h));
602            self.unified_packed_logits = Some(B::alloc(max_seqs * v));
603        }
604        self.unified_capacity = cap;
605    }
606
607    /// Allocate scratch for batched paged dispatch (Phase 4b). Called
608    /// lazily from `ensure_kv` once paged mode is enabled and we know
609    /// the pool dimensions. Idempotent.
610    fn enable_paged_batch(
611        &mut self,
612        cfg: &LlamaFamilyConfig,
613        max_seqs: usize,
614        max_blocks_per_seq: usize,
615    ) {
616        if self.paged_batch_q.is_some() {
617            return;
618        }
619        let q_dim = cfg.num_heads * cfg.head_dim;
620        self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
621        self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
622        self.paged_batch_block_tables = Some(B::alloc_typed(
623            ferrum_kernels::backend::Dtype::U32,
624            max_seqs * max_blocks_per_seq,
625        ));
626        self.paged_batch_context_lens = Some(B::alloc_typed(
627            ferrum_kernels::backend::Dtype::U32,
628            max_seqs,
629        ));
630        self.paged_max_blocks_per_seq = max_blocks_per_seq;
631        self.paged_max_seqs = max_seqs;
632    }
633}
634
/// Llama-family model — decoder-only LLM, one per (backend, weights)
/// combination.
///
/// Holds all parameters, scratch space, RoPE cache, and per-sequence KV caches.
///
/// `B: MoeLlmBackend` is the only backend bound declared on the struct.
/// The decode hot path uses CUDA Graph capture/replay when the backend
/// supports it; non-graph backends (Metal/CPU) inherit no-op defaults.
/// NOTE(review): earlier wording listed
/// `BackendGraph + BackendQuantMarlin + BackendQuantGguf` as the bound —
/// presumably those are implied by `MoeLlmBackend`; confirm against the
/// trait definition.
///
/// `K: KvLayer<B> = KvFp16` (Dim 5): selects the KV cache element type;
/// `kv_caches` and `kv_free_pool` store `K::Layer` values.
/// NOTE(review): earlier wording described `K = KvFp16` / `K = KvInt8` as
/// constructing only `LayerKvCache::Fp16(...)` / `LayerKvCache::Int8(...)`
/// variants, with a `B: BackendKvDtype<KvInt8>` bound satisfied via stub
/// impls (Cpu/Metal) or the real impl (CUDA). Neither that enum nor that
/// bound appears in this declaration — confirm whether the note still applies.
pub struct LlamaFamilyModel<B: MoeLlmBackend, K: KvLayer<B> = KvFp16> {
    /// Full architecture config.
    pub cfg: LlamaFamilyConfig,
    /// Engine-facing subset of the config.
    pub runtime_cfg: LlmRuntimeConfig,

    /// Token embedding table. `None` for backbone-only models (e.g. the
    /// Qwen3-TTS Talker, which embeds inputs externally and feeds via
    /// `prefill_from_embeds`).
    pub embed: Option<B::Buffer>,
    /// Per-layer weights.
    // NOTE(review): presumably `cfg.num_layers` entries — confirm in the constructor.
    pub layers: Vec<LlamaFamilyLayer<B>>,
    /// Weight of the final norm applied before the LM head.
    pub final_norm_w: B::Buffer,
    /// LM output head. `None` for backbone-only models.
    pub lm_head: Option<Box<dyn Linear<B>>>,

    /// Precomputed RoPE cos/sin tables.
    pub rope: RopeCache<B>,
    /// Reusable forward-pass scratch buffers.
    pub scratch: LlamaFamilyScratch<B>,

    /// Per-sequence KV caches, one `Vec<K::Layer>` of length `num_layers`.
    ///
    /// Two layouts overlay this same map:
    /// - **Contiguous mode** (default): each cache holds its own
    ///   `[num_kv_heads, capacity, head_dim]` k/v buffers.
    /// - **Paged mode** (`FERRUM_METAL_PAGED_KV=1`): k/v are unused
    ///   placeholders; the real K/V live in [`Self::paged_pools`] and
    ///   the cache's `block_table` + `context_lens` index into them.
    pub kv_caches: HashMap<String, Vec<K::Layer>>,
    /// Free pool of pre-allocated KV cache slots. Released caches return
    /// here instead of being dropped, so their device pointers stay valid
    /// across requests — critical for graph capture (pointers baked into
    /// the captured graph would otherwise dangle).
    kv_free_pool: Vec<Vec<K::Layer>>,

    // ── Paged-KV multi-seq state (Phase 4) ─────────────────────────────
    //
    // Only populated when `FERRUM_METAL_PAGED_KV=1`. When set, every
    // `kv_caches` entry becomes a "view" into the shared pool: its
    // `k` / `v` buffers are placeholders; reads / writes go through
    // `paged_pools[layer].k` / `.v` indexed via the cache's
    // `block_table`. Multiple cache_ids share the same pool, with
    // disjoint physical block sets owned by `paged_block_alloc`.
    //
    /// Shared K/V pools, one per layer. Sized at model load time for the
    /// configured `MAX_RUNNING_REQUESTS × max_blocks_per_seq` blocks.
    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
    /// Block allocator hands out physical block indices from the pool.
    /// `Mutex` because the engine can call `ensure_kv` / `release_kv`
    /// from multiple threads in concurrent serving.
    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
    /// Paged-batch dispatch dimensions `(max_seqs, max_blocks_per_seq)`,
    /// pinned at the first `ensure_kv` when paged-KV is on. Stored on
    /// the model (not on scratch) so `ensure_scratch`'s realloc can
    /// re-call `enable_paged_batch` after wiping scratch's
    /// `paged_batch_block_tables` / `paged_batch_q` etc.
    pub paged_dims: Option<(usize, usize)>,

    // ── Graph capture state (CUDA only; harmless no-op on other backends) ──
    /// Count of eager decode steps run so far. After `GRAPH_WARMUP`, the
    /// next step captures the decode flow as a graph.
    pub(crate) graph_warmup: usize,
    /// True if capture was attempted but failed (e.g. backend doesn't
    /// support graph capture). Stops further attempts, falls back to eager.
    pub(crate) graph_capture_failed: bool,
    /// Same warmup counter for the batched-decode path.
    pub(crate) batched_graph_warmup: usize,
    /// True if batched graph capture failed.
    pub(crate) batched_graph_failed: bool,
    /// Set of `m_padded` values (as u64 graph keys) for which a batched
    /// graph has been captured. Multi-slot via cuda.rs's HashMap-keyed
    /// graph cache — different batch shapes don't thrash a single slot.
    pub(crate) batched_graph_keys_seen: std::collections::HashSet<u64>,
    /// Cache IDs for which device-pointer scratch is currently populated.
    /// Populate only re-runs when the batch composition changes (new
    /// requests joined / requests finished). Hot-path optimization:
    /// avoids 3 sync cuMemcpyHtoD_v2's per decode token (~5% TPOT).
    pub(crate) batched_pointers_for: Option<Vec<String>>,
    /// CUDA-graph state for the unified_forward path. Mirrors the
    /// `batched_graph_*` triple but keyed on `(m_total, num_seqs)`
    /// so different concurrency levels each get their own cached
    /// graph instead of thrashing a single slot.
    pub(crate) unified_graph_warmup: usize,
    /// True if unified-path graph capture failed.
    // NOTE(review): meaning inferred from the parallel `batched_graph_failed` field — confirm.
    pub(crate) unified_graph_failed: bool,
    /// Graph keys already captured for the unified path.
    // NOTE(review): presumably derived from `(m_total, num_seqs)` — confirm the encoding.
    pub(crate) unified_graph_keys_seen: std::collections::HashSet<u64>,
}
731
732impl<B: MoeLlmBackend, K: KvLayer<B>> LlamaFamilyModel<B, K> {
733    /// Build a Qwen3 model from weights provided by the loader.
734    ///
735    /// The loader decides per-projection whether to instantiate DenseLinear,
736    /// GptqLinear, etc. — this code doesn't care.
737    pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
738        // Invalidate any graph from a previously-loaded model. The captured
739        // graph references the old model's scratch buffers; a fresh model
740        // gets fresh scratch, so reusing the graph would read/write freed
741        // pointers. Matters for test suites where multiple models coexist.
742        {
743            let mut ctx = B::new_context();
744            B::reset_all_graphs(&mut ctx);
745        }
746        let rope = build_rope_cache::<B>(&cfg);
747        let scratch = LlamaFamilyScratch::alloc(&cfg, 1); // decode-sized; prefill resizes
748
749        // Embedding: plain tensor (no projection math, just lookup).
750        let embed = loader.load_tensor("model.embed_tokens.weight")?;
751
752        // Per-layer weights.
753        let mut layers = Vec::with_capacity(cfg.num_layers);
754        for li in 0..cfg.num_layers {
755            let prefix = format!("model.layers.{li}");
756            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
757            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
758            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
759            let post_ln_w =
760                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
761            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
762            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
763
764            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
765                let q = loader
766                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
767                    .ok();
768                let k = loader
769                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
770                    .ok();
771                (q, k)
772            } else {
773                (None, None)
774            };
775
776            layers.push(LlamaFamilyLayer {
777                input_ln_w,
778                qkv_proj,
779                q_norm_w,
780                k_norm_w,
781                o_proj,
782                post_ln_w,
783                gate_up_proj,
784                down_proj,
785            });
786        }
787
788        let final_norm_w = loader.load_tensor("model.norm.weight")?;
789
790        // LM head: either dedicated `lm_head.weight` or tied to embedding.
791        // Many models (Qwen3-4B, Llama-3.2-1B, some Qwen2.5) use TIED
792        // embeddings — lm_head shares weights with model.embed_tokens. When
793        // no dedicated lm_head tensor exists, re-load the embed tensor as a
794        // DenseLinear. This duplicates the buffer (memory cost = vocab*h*2
795        // bytes, e.g. ~770MB for Qwen3-4B) but keeps the Linear trait's
796        // owned-weights invariant. Sharing via Arc is a future optimisation.
797        let lm_head = if loader.has_tensor("lm_head.weight") {
798            loader.load_linear("lm_head")?
799        } else {
800            tracing::info!(
801                "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
802            );
803            let as_linear = loader.load_linear("model.embed_tokens")?;
804            // Sanity check: shape must be [vocab, hidden].
805            if as_linear.out_features() != cfg.vocab_size
806                || as_linear.in_features() != cfg.hidden_size
807            {
808                return Err(ferrum_types::FerrumError::model(format!(
809                    "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
810                    as_linear.out_features(),
811                    as_linear.in_features(),
812                    cfg.vocab_size,
813                    cfg.hidden_size
814                )));
815            }
816            as_linear
817        };
818
819        let runtime_cfg = cfg.to_runtime();
820        Ok(Self {
821            cfg,
822            runtime_cfg,
823            embed: Some(embed),
824            layers,
825            final_norm_w,
826            lm_head: Some(lm_head),
827            rope,
828            scratch,
829            kv_caches: HashMap::new(),
830            kv_free_pool: Vec::new(),
831            paged_pools: None,
832            paged_block_alloc: None,
833            paged_dims: None,
834            graph_warmup: 0,
835            graph_capture_failed: false,
836            batched_graph_warmup: 0,
837            batched_graph_failed: false,
838            batched_graph_keys_seen: std::collections::HashSet::new(),
839            batched_pointers_for: None,
840            unified_graph_warmup: 0,
841            unified_graph_failed: false,
842            unified_graph_keys_seen: std::collections::HashSet::new(),
843        })
844    }
845
846    /// Build a backbone-only Qwen3 transformer stack (no embed, no lm_head).
847    ///
848    /// Intended for composing the transformer inside a larger model where
849    /// embedding and output-head logic differs from the standard LLM path —
850    /// e.g. Qwen3-TTS Talker uses dual text/codec embeddings with a projection
851    /// MLP, and a codec_head output. The caller drives forward via
852    /// `prefill_from_embeds` / `decode_from_embed`.
853    ///
854    /// Loader must provide: per-layer weights under `model.layers.{i}.*` and
855    /// the final `model.norm.weight`. `model.embed_tokens` and `lm_head`
856    /// are NOT read.
857    pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
858        // See `new` — invalidate stale graph referring to prior model's scratch.
859        {
860            let mut ctx = B::new_context();
861            B::reset_all_graphs(&mut ctx);
862        }
863        let rope = build_rope_cache::<B>(&cfg);
864        let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
865
866        let mut layers = Vec::with_capacity(cfg.num_layers);
867        for li in 0..cfg.num_layers {
868            let prefix = format!("model.layers.{li}");
869            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
870            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
871            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
872            let post_ln_w =
873                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
874            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
875            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
876
877            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
878                let q = loader
879                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
880                    .ok();
881                let k = loader
882                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
883                    .ok();
884                (q, k)
885            } else {
886                (None, None)
887            };
888
889            layers.push(LlamaFamilyLayer {
890                input_ln_w,
891                qkv_proj,
892                q_norm_w,
893                k_norm_w,
894                o_proj,
895                post_ln_w,
896                gate_up_proj,
897                down_proj,
898            });
899        }
900
901        let final_norm_w = loader.load_tensor("model.norm.weight")?;
902
903        let runtime_cfg = cfg.to_runtime();
904        Ok(Self {
905            cfg,
906            runtime_cfg,
907            embed: None,
908            layers,
909            final_norm_w,
910            lm_head: None,
911            rope,
912            scratch,
913            kv_caches: HashMap::new(),
914            kv_free_pool: Vec::new(),
915            paged_pools: None,
916            paged_block_alloc: None,
917            paged_dims: None,
918            graph_warmup: 0,
919            graph_capture_failed: false,
920            batched_graph_warmup: 0,
921            batched_graph_failed: false,
922            batched_graph_keys_seen: std::collections::HashSet::new(),
923            batched_pointers_for: None,
924            unified_graph_warmup: 0,
925            unified_graph_failed: false,
926            unified_graph_keys_seen: std::collections::HashSet::new(),
927        })
928    }
929
930    /// Grow scratch buffers if `tokens` exceeds the current sizing.
931    pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
932        if self.scratch.max_tokens < tokens {
933            // Any captured decode graph holds pointers to the old scratch
934            // buffers; those are about to be freed. Invalidate ALL captured
935            // graphs (both single-item and per-m_padded batched) — every
936            // captured kernel-arg pointer into scratch is stale.
937            {
938                let mut ctx = B::new_context();
939                B::reset_all_graphs(&mut ctx);
940            }
941            self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
942            self.graph_warmup = 0;
943            self.graph_capture_failed = false;
944            self.batched_graph_keys_seen.clear();
945            self.batched_graph_warmup = 0;
946            self.batched_graph_failed = false;
947            self.unified_graph_keys_seen.clear();
948            self.unified_graph_warmup = 0;
949            self.unified_graph_failed = false;
950            // Realloc wiped paged_batch_*. Re-enable using the dims
951            // pinned at first ensure_kv. Without this, the next
952            // `forward_layer_batched_decode` panics on
953            // `paged_batch_block_tables missing`.
954            if let Some((max_seqs, max_blocks_per_seq)) = self.paged_dims {
955                self.scratch
956                    .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
957            }
958        }
959    }
960
961    /// Ensure per-layer KV caches exist for `cache_id`, pre-allocated to
962    /// `max_seq_len` slots per head. Enables the in-place
963    /// `kv_cache_append_head_major` path — no realloc per layer.
964    pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
965        if self.kv_caches.contains_key(cache_id) {
966            return;
967        }
968        let nkv = self.cfg.num_kv_heads;
969        let hd = self.cfg.head_dim;
970        // KV capacity defaults to a chat-friendly 4096 to keep the working
971        // set sane on a 32 GB Mac (a 48-layer / 4-kv-head / 128-head_dim
972        // model spends ~786 MB on 4096 KV slots, vs ~6 GB on the model's
973        // declared 32K which would push the 17 GB Qwen3-30B-A3B model into
974        // swap). `FERRUM_KV_CAPACITY=N` overrides; clamp to the model's
975        // declared max so we never lie to the model about its window.
976        let model_max = self.cfg.max_seq_len;
977        // 512 in 0.7.2 — matches the value used in
978        // docs/bench/macos-2026-05-02 to get the published numbers.
979        // pre-0.7.2 default of 4096 was safe only because paged-KV was
980        // opt-in (pool wasn't allocated). With paged-KV now on by
981        // default + MAX_SEQS=32, the pool occupies physical memory:
982        // ~3 GB on Qwen3-30B-A3B Q4_K_M leaves 18 GB weights + 3 GB pool
983        // = 21 GB, fits comfortably on a 32 GB Mac. Long-context users
984        // can `FERRUM_KV_CAPACITY=4096` and accept lower max_seqs.
985        let runtime_env = llama_family_runtime_env();
986        let max = runtime_env.kv_capacity_for_model(model_max);
987
988        // Paged-KV mode: `FERRUM_METAL_PAGED_KV=1` switches every cache
989        // for this model into block-table-indirect layout. Kernels from
990        // PR #68 (decode read) + PR #69 (decode write) handle the
991        // indirect addressing; the LlamaFamily decode path below picks
992        // them up automatically by checking `cache.block_size > 0`.
993        //
994        // Pool sizing: round capacity up to a multiple of block_size,
995        // identity-assign logical→physical block. Memory footprint is
996        // the same as contiguous (within block_size rounding); the
997        // benefit only shows up under multi-seq sharing in Phase 4+.
998        // Default ON when the backend supports paged-KV (Metal). Users
999        // can force off with `FERRUM_METAL_PAGED_KV=0`. The flag was
1000        // opt-in pre-0.7.2; flipping the default so default `ferrum
1001        // serve` matches the bench-quality numbers without requiring
1002        // env-var knowledge.
1003        let paged = runtime_env.paged_kv_enabled::<B>();
1004        const PAGED_BLOCK_SIZE: usize = 16;
1005
1006        // Phase 4 shared-pool sizing. The pool sees ALL concurrent
1007        // sequences; per-cache_id state just owns indices into it.
1008        // Default 32: covers c=16 burst with 2× headroom for the
1009        // fresh-cache-id-per-request pattern that bench/server harnesses
1010        // use. Pool memory is `max_seqs × max_blocks_per_seq` total
1011        // blocks — we lowered DEFAULT_KV_CAPACITY to 2048 so this 2× max_seqs
1012        // bump keeps the pool footprint identical to the pre-0.7.2 default.
1013        let max_seqs = runtime_env.paged_max_seqs;
1014        let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
1015        let total_pool_blocks = max_seqs * max_blocks_per_seq;
1016
1017        // Lazy-allocate the shared paged pools on the FIRST paged
1018        // ensure_kv call. Pools are big — for Llama-8B (8 kv_heads,
1019        // head_dim=128) at 16 seqs × 256 blocks × 16 slots = 65536 KV
1020        // slots: 65536 * 8 * 128 * 4 = 256 MB per layer × 32 layers
1021        // = 8 GB total. Sized this large only because `max_seqs=16`
1022        // is the default; lower it via env to shrink.
1023        if paged && self.paged_pools.is_none() {
1024            let mut pools = Vec::with_capacity(self.cfg.num_layers);
1025            for _ in 0..self.cfg.num_layers {
1026                let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
1027                pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
1028            }
1029            self.paged_pools = Some(pools);
1030            self.paged_block_alloc = Some(std::sync::Mutex::new(
1031                crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
1032            ));
1033        }
1034        // Phase 4b: ensure batched-dispatch scratch is allocated whenever
1035        // paged is on. Idempotent — re-init is a no-op if already
1036        // sized. Has to live outside the `paged_pools.is_none()` branch
1037        // because `ensure_scratch` may have replaced the scratch struct
1038        // since the pools were first allocated.
1039        if paged {
1040            self.scratch
1041                .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
1042            // Pin dims on the model so `ensure_scratch`'s realloc can
1043            // re-init paged_batch_* after wiping scratch.
1044            self.paged_dims = Some((max_seqs, max_blocks_per_seq));
1045        }
1046
1047        // Try pool first — reused buffers have stable device pointers,
1048        // so a captured decode graph can be replayed for this request too.
1049        // K::NAME selects which `LayerKvCache` variant to construct:
1050        // K-aware allocation: K::alloc_paged / K::alloc_contig pick the
1051        // right cache layout (FP16 → KvCache, INT8 → KvCacheQuant) per the
1052        // model's K marker. INT8 supports paged mode only — KvInt8::alloc_contig
1053        // panics, surfacing the misconfiguration here at first ensure_kv.
1054        let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
1055            (0..self.cfg.num_layers)
1056                .map(|_| {
1057                    if paged {
1058                        K::alloc_paged(max_blocks_per_seq, PAGED_BLOCK_SIZE, nkv, hd)
1059                    } else {
1060                        K::alloc_contig(max, nkv, hd)
1061                    }
1062                })
1063                .collect()
1064        });
1065
1066        // Allocate physical blocks for THIS cache_id from the shared
1067        // pool. We allocate all `max_blocks_per_seq` upfront for
1068        // simplicity (matches contig's "pre-alloc to capacity"
1069        // semantics); a smarter Phase 4b can grow on demand to save
1070        // pool occupancy.
1071        if paged {
1072            let alloc_arc = self
1073                .paged_block_alloc
1074                .as_ref()
1075                .expect("paged_block_alloc must be initialised when paged=true");
1076            // Recover from a previously-poisoned mutex instead of panicking
1077            // (poison just means a prior holder panicked; the BlockAllocator
1078            // state is still intact since allocate_n is fail-safe).
1079            let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1080            let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
1081                Ok(idx) => idx,
1082                Err(e) => {
1083                    // Pool exhaustion is a back-pressure signal, not a crash.
1084                    // Drop the lock, return the cache to the free pool, and
1085                    // bail before inserting it into kv_caches. The downstream
1086                    // call will then fail with a clean per-request error
1087                    // ("ensure_kv must be called before ...") instead of
1088                    // dragging every other in-flight request down with it.
1089                    drop(alloc);
1090                    self.kv_free_pool.push(caches);
1091                    eprintln!(
1092                        "[ferrum] paged KV pool exhausted on ensure_kv for \
1093                         cache_id={cache_id:?}: {e}. Increase \
1094                         FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
1095                         throttle concurrent requests.",
1096                    );
1097                    return;
1098                }
1099            };
1100            // Write the block table to each layer's cache. All layers
1101            // share the same logical→physical mapping for this seq.
1102            // Also stash the host-side index list so release_kv can
1103            // return them to the allocator without a device readback.
1104            let mut padded = block_indices.clone();
1105            padded.resize(max_blocks_per_seq, 0);
1106            let mut ctx_tmp = B::new_context();
1107            for c in caches.iter_mut() {
1108                if let Some(bt) = K::block_table_mut(c) {
1109                    B::write_typed::<u32>(&mut ctx_tmp, bt, &padded);
1110                }
1111                *K::paged_block_indices_mut(c) = block_indices.clone();
1112            }
1113            B::sync(&mut ctx_tmp);
1114        }
1115
1116        // Reset logical length; buffers stay. No need to zero the memory —
1117        // the kv_cache_append writes new K/V in place, and attention only
1118        // reads up to `cache_len`.
1119        for c in caches.iter_mut() {
1120            K::set_len(c, 0);
1121            if let Some(cl) = K::context_lens_mut(c) {
1122                let mut ctx_tmp = B::new_context();
1123                B::write_typed::<u32>(&mut ctx_tmp, cl, &[0u32]);
1124                B::sync(&mut ctx_tmp);
1125            }
1126        }
1127        self.kv_caches.insert(cache_id.to_string(), caches);
1128    }
1129
1130    /// Run one transformer layer. Mutates `residual` in place.
1131    ///
1132    /// `pos_offset` is the absolute position of token 0 in this batch
1133    /// (decode: `pos`; prefill: 0). `tokens` is the batch size.
1134    #[allow(clippy::too_many_arguments)]
1135    pub(crate) fn forward_layer(
1136        &mut self,
1137        ctx: &mut B::Context,
1138        li: usize,
1139        cache_id: &str,
1140        residual: &mut B::Buffer,
1141        pos_offset: usize,
1142        tokens: usize,
1143    ) {
1144        let layer = &self.layers[li];
1145        let cfg = &self.cfg;
1146        let h = cfg.hidden_size;
1147        let nh = cfg.num_heads;
1148        let nkv = cfg.num_kv_heads;
1149        let hd = cfg.head_dim;
1150        let im = cfg.intermediate_size;
1151        let eps = cfg.rms_norm_eps;
1152        let q_dim = nh * hd;
1153        let kv_dim = nkv * hd;
1154
1155        // 1. Input RMSNorm
1156        let _t0 = if llama_family_runtime_env().decode_op_profile {
1157            B::sync(ctx);
1158            Some(std::time::Instant::now())
1159        } else {
1160            None
1161        };
1162        B::rms_norm(
1163            ctx,
1164            residual,
1165            &layer.input_ln_w,
1166            eps,
1167            &mut self.scratch.norm_out,
1168            tokens,
1169            h,
1170        );
1171        if let Some(t0) = _t0 {
1172            B::sync(ctx);
1173            NORM_TIME_US.fetch_add(
1174                t0.elapsed().as_micros() as u64,
1175                std::sync::atomic::Ordering::Relaxed,
1176            );
1177            NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1178        }
1179
1180        // 2. Fused QKV projection (Linear dispatches to Dense/GPTQ/AWQ/GGUF)
1181        let _t0 = if llama_family_runtime_env().decode_op_profile {
1182            B::sync(ctx);
1183            Some(std::time::Instant::now())
1184        } else {
1185            None
1186        };
1187        layer.qkv_proj.forward(
1188            ctx,
1189            &self.scratch.norm_out,
1190            &mut self.scratch.qkv_out,
1191            tokens,
1192        );
1193        if let Some(t0) = _t0 {
1194            B::sync(ctx);
1195            MATMUL_TIME_US.fetch_add(
1196                t0.elapsed().as_micros() as u64,
1197                std::sync::atomic::Ordering::Relaxed,
1198            );
1199            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1200        }
1201
1202        // 3-5. Fused split-QKV + QK-norm + RoPE + cache-write.
1203        //
1204        // Single Metal dispatch replaces the (split_qkv → 3× qk_norm_rope
1205        // → kv_cache_append_head_major) five-launch chain on the decode
1206        // hot path. Reads qkv_out once, writes Q to head-major scratch
1207        // and K/V straight into the pre-allocated KV cache slot at
1208        // `cache_len + tok`. Saves 4 dispatches per layer when the
1209        // backend implements the fused kernel; CPU and other backends
1210        // keep using the unfused chain via the Unsupported fallbacks.
1211        //
1212        // qk_mode: 1 = norm + half-split RoPE (Qwen3/Qwen HF);
1213        //          2 = half-split RoPE only;
1214        //          3 = interleaved RoPE only (GGUF LLaMA / llama.cpp layout).
1215        // V always passes apply_norm=0.
1216        let qk_mode: i32 = if cfg.has_qk_norm {
1217            1
1218        } else if cfg.rope_interleaved {
1219            3
1220        } else {
1221            2
1222        };
1223        let dummy = &layer.input_ln_w;
1224        let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
1225        let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);
1226
1227        // Grab the per-layer KV cache up front so the deepest fusion can
1228        // write K/V straight into it.
1229        //
1230        // Paged mode: also need this layer's shared pool buffers
1231        // (self.paged_pools[li]). The pool is a separate field from
1232        // kv_caches, so we take a raw pointer to its (k, v) here while
1233        // we still hold &mut self, then deref via unsafe inside the
1234        // paged dispatch below. Safety: paged_pools is allocated once
1235        // and never resized; we don't touch self.paged_pools while the
1236        // pointer is in use.
1237        let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
1238            if let Some(pools) = self.paged_pools.as_mut() {
1239                let pool = &mut pools[li];
1240                Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
1241            } else {
1242                None
1243            };
1244        let caches = self
1245            .kv_caches
1246            .get_mut(cache_id)
1247            .expect("ensure_kv must be called before forward_layer");
1248        // Read shared metadata (variant-agnostic) once. The K-aware
1249        // attn section below re-borrows the right enum variant.
1250        let cache_len_before = K::len(&caches[li]);
1251        let cache_capacity = K::capacity(&caches[li]);
1252        let cache_block_size = K::block_size(&caches[li]);
1253
1254        // Defense in depth: refuse to write past the KV buffer. The
1255        // graceful path is the caller pre-checking via `kv_capacity()`
1256        // and either compacting or refusing the request; this panic only
1257        // fires when that contract is broken (and silent overflow would
1258        // otherwise corrupt the cache + adjacent allocations).
1259        if cache_len_before + tokens > cache_capacity {
1260            panic!(
1261                "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
1262                cache_len_before + tokens
1263            );
1264        }
1265
1266        // Paged path: K::paged_write fuses split_qkv_norm_rope + cache append
1267        // (FP16: into_paged_cache; INT8: split_qkv_norm_rope + int8_kv_append_paged).
1268        // K::paged_decode_attention reads from layer-local INT8 buffers or the
1269        // shared FP16 pool depending on K. Then K-agnostic post-attn tail.
1270        if cache_block_size > 0 {
1271            let (pool_k_ptr, pool_v_ptr) =
1272                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
1273            // SAFETY: paged_pools is allocated once and never resized; the
1274            // raw pointers don't outlive this method scope.
1275            let pool_k = unsafe { &mut *pool_k_ptr };
1276            let pool_v = unsafe { &mut *pool_v_ptr };
1277
1278            K::paged_write(
1279                ctx,
1280                &mut caches[li],
1281                &self.scratch.qkv_out,
1282                q_norm_w,
1283                k_norm_w,
1284                &self.rope.cos,
1285                &self.rope.sin,
1286                &mut self.scratch.q_head_major,
1287                &mut self.scratch.k_head_major,
1288                &mut self.scratch.v_head_major,
1289                pool_k,
1290                pool_v,
1291                tokens,
1292                nh,
1293                nkv,
1294                hd,
1295                pos_offset,
1296                eps,
1297                qk_mode,
1298            )
1299            .expect("K::paged_write");
1300
1301            let new_len = cache_len_before + tokens;
1302            K::set_len(&mut caches[li], new_len);
1303
1304            let pool_k_imm = unsafe { &*pool_k_ptr };
1305            let pool_v_imm = unsafe { &*pool_v_ptr };
1306            K::paged_decode_attention(
1307                ctx,
1308                &mut caches[li],
1309                &self.scratch.q_head_major,
1310                pool_k_imm,
1311                pool_v_imm,
1312                &mut self.scratch.attn_head_major_out,
1313                nh,
1314                nkv,
1315                hd,
1316                new_len,
1317                tokens,
1318            )
1319            .expect("K::paged_decode_attention");
1320
1321            return self.forward_layer_post_attn(ctx, li, residual, tokens);
1322        }
1323
1324        // Non-paged (contig) path. INT8 path doesn't reach here:
1325        // KvInt8::alloc_contig panics in ensure_kv.
1326        let _qkr_t0 = if llama_family_runtime_env().decode_op_profile {
1327            B::sync(ctx);
1328            Some(std::time::Instant::now())
1329        } else {
1330            None
1331        };
1332        K::contig_write(
1333            ctx,
1334            &mut caches[li],
1335            &self.scratch.qkv_out,
1336            q_norm_w,
1337            k_norm_w,
1338            &self.rope.cos,
1339            &self.rope.sin,
1340            &mut self.scratch.q_head_major,
1341            &mut self.scratch.k_head_major,
1342            &mut self.scratch.v_head_major,
1343            &mut self.scratch.q_buf,
1344            &mut self.scratch.k_buf,
1345            &mut self.scratch.v_buf,
1346            tokens,
1347            nh,
1348            nkv,
1349            hd,
1350            pos_offset,
1351            eps,
1352            qk_mode,
1353        )
1354        .expect("K::contig_write");
1355        if let Some(t0) = _qkr_t0 {
1356            B::sync(ctx);
1357            QKR_TIME_US.fetch_add(
1358                t0.elapsed().as_micros() as u64,
1359                std::sync::atomic::Ordering::Relaxed,
1360            );
1361            QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1362        }
1363        let new_len = cache_len_before + tokens;
1364        K::set_len(&mut caches[li], new_len);
1365        let kv_stride = cache_capacity;
1366
1367        let _attn_t0 = if llama_family_runtime_env().decode_op_profile {
1368            B::sync(ctx);
1369            Some(std::time::Instant::now())
1370        } else {
1371            None
1372        };
1373        let attn_cfg = ferrum_kernels::backend::AttnConfig {
1374            num_heads: nh,
1375            num_kv_heads: nkv,
1376            head_dim: hd,
1377            causal: true,
1378            scale: 1.0 / (hd as f32).sqrt(),
1379            kv_seq_stride: kv_stride,
1380            sliding_window: cfg.sliding_window,
1381        };
1382        K::contig_decode_attention(
1383            ctx,
1384            &caches[li],
1385            &self.scratch.q_head_major,
1386            &mut self.scratch.attn_head_major_out,
1387            attn_cfg,
1388            tokens,
1389            pos_offset,
1390        )
1391        .expect("K::contig_decode_attention");
1392        let _ = q_dim;
1393        let _ = kv_dim;
1394        let _ = dummy;
1395        if let Some(t0) = _attn_t0 {
1396            B::sync(ctx);
1397            ATTN_TIME_US.fetch_add(
1398                t0.elapsed().as_micros() as u64,
1399                std::sync::atomic::Ordering::Relaxed,
1400            );
1401            ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1402        }
1403
1404        self.forward_layer_post_attn(ctx, li, residual, tokens);
1405    }
1406
1407    /// Post-attention tail of `forward_layer`: untranspose Q (if needed),
1408    /// O-proj, fused residual+post_norm, gate_up_proj, SwiGLU, down_proj,
1409    /// final residual add. K-agnostic — reads `self.scratch.attn_head_major_out`
1410    /// which both the FP16 and INT8 attn paths populate.
1411    pub(crate) fn forward_layer_post_attn(
1412        &mut self,
1413        ctx: &mut B::Context,
1414        li: usize,
1415        residual: &mut B::Buffer,
1416        tokens: usize,
1417    ) {
1418        let layer = &self.layers[li];
1419        let cfg = &self.cfg;
1420        let h = cfg.hidden_size;
1421        let nh = cfg.num_heads;
1422        let hd = cfg.head_dim;
1423        let im = cfg.intermediate_size;
1424        let eps = cfg.rms_norm_eps;
1425
1426        // 7. Untranspose head-major → token-major for O-proj input.
1427        let attn_token_major = if tokens == 1 {
1428            &self.scratch.attn_head_major_out
1429        } else {
1430            B::transpose_head_to_token(
1431                ctx,
1432                &self.scratch.attn_head_major_out,
1433                &mut self.scratch.attn_flat,
1434                tokens,
1435                nh,
1436                hd,
1437            );
1438            &self.scratch.attn_flat
1439        };
1440
1441        // 8. O projection.
1442        layer
1443            .o_proj
1444            .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1445
1446        // 9. Fused residual-add + post-attention RMSNorm.
1447        B::fused_add_rms_norm(
1448            ctx,
1449            residual,
1450            &self.scratch.o_proj_out,
1451            &layer.post_ln_w,
1452            eps,
1453            &mut self.scratch.norm_out,
1454            tokens,
1455            h,
1456        );
1457
1458        // 10. Fused gate+up projection.
1459        layer.gate_up_proj.forward(
1460            ctx,
1461            &self.scratch.norm_out,
1462            &mut self.scratch.gate_up_out,
1463            tokens,
1464        );
1465
1466        // 11. SwiGLU: silu(gate) * up.
1467        B::fused_silu_mul_split(
1468            ctx,
1469            &self.scratch.gate_up_out,
1470            &mut self.scratch.silu_out,
1471            tokens,
1472            im,
1473        );
1474
1475        // 12. Down projection.
1476        layer.down_proj.forward(
1477            ctx,
1478            &self.scratch.silu_out,
1479            &mut self.scratch.mlp_out,
1480            tokens,
1481        );
1482
1483        // 13. Final residual add.
1484        B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
1485    }
1486
1487    /// Multi-position decode-verify: run one forward pass over `tokens`
1488    /// starting at the cache's current end position, write their K/V
1489    /// into the KV cache, and return logits for ALL `tokens.len()`
1490    /// positions as a flat `Vec<f32>` of length `seq_len * vocab_size`.
1491    ///
1492    /// Used by speculative decoding: target receives
1493    /// `[last_token, draft_0, ..., draft_{N-1}]` (N+1 inputs) and produces
1494    /// N+1 logit rows in a single forward instead of N+1 sequential
1495    /// decode() calls. Positions are implicit — the model looks up
1496    /// `pos_offset = cache.len` the same way prefill_internal does, so
1497    /// chunked prefill semantics carry over for free.
1498    pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1499        let seq_len = tokens.len();
1500        assert!(seq_len > 0, "forward_verify called with empty tokens");
1501        self.ensure_scratch(seq_len);
1502        self.ensure_kv(cache_id);
1503
1504        let h = self.cfg.hidden_size;
1505        let vocab = self.cfg.vocab_size;
1506
1507        let pos_offset = self
1508            .kv_caches
1509            .get(cache_id)
1510            .and_then(|layers| layers.first())
1511            .map(|c| K::len(c))
1512            .unwrap_or(0);
1513
1514        let mut ctx = B::new_context();
1515        let mut residual = self
1516            .scratch
1517            .residual
1518            .take()
1519            .expect("scratch residual missing (previous call didn't restore)");
1520
1521        let embed = self
1522            .embed
1523            .as_ref()
1524            .expect("forward_verify called on backbone-only model (no embed)");
1525        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1526
1527        for li in 0..self.cfg.num_layers {
1528            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1529        }
1530
1531        // RMSNorm on ALL seq_len positions (prefill_internal only norms
1532        // the last one; verify needs the full grid).
1533        B::rms_norm(
1534            &mut ctx,
1535            &residual,
1536            &self.final_norm_w,
1537            self.cfg.rms_norm_eps,
1538            &mut self.scratch.norm_out,
1539            seq_len,
1540            h,
1541        );
1542
1543        // LM head applied to all positions → `seq_len * vocab` logits.
1544        // Reuses the existing `batch_logits` scratch (sized max_tokens *
1545        // vocab) so no extra allocation.
1546        let lm_head = self
1547            .lm_head
1548            .as_ref()
1549            .expect("forward_verify called on backbone-only model (no lm_head)");
1550        lm_head.forward(
1551            &mut ctx,
1552            &self.scratch.norm_out,
1553            &mut self.scratch.batch_logits,
1554            seq_len,
1555        );
1556
1557        B::sync(&mut ctx);
1558        self.scratch.residual = Some(residual);
1559        B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
1560    }
1561
1562    /// Prefill: process `tokens` prompt tokens in a single batch, return
1563    /// `[vocab_size]` logits for the last position.
1564    ///
1565    /// Supports incremental prefill: if the KV cache for `cache_id` already
1566    /// contains earlier tokens, the new chunk's positions are computed as
1567    /// `[kv_len, kv_len + tokens.len())` so RoPE and causal masking stay
1568    /// aligned. Used by the engine's chunked-prefill path.
1569    pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1570        let seq_len = tokens.len();
1571        assert!(seq_len > 0, "prefill called with empty token list");
1572        self.ensure_scratch(seq_len);
1573        self.ensure_kv(cache_id);
1574
1575        // Starting position for this chunk — 0 for a fresh prefill, kv_len
1576        // for the second+ chunk of a split prefill.
1577        let pos_offset = self
1578            .kv_caches
1579            .get(cache_id)
1580            .and_then(|layers| layers.first())
1581            .map(|c| K::len(c))
1582            .unwrap_or(0);
1583
1584        let h = self.cfg.hidden_size;
1585        let vocab = self.cfg.vocab_size;
1586        let mut ctx = B::new_context();
1587
1588        // Move `residual` out of `scratch` to work around the borrow checker:
1589        // `forward_layer` re-borrows `&mut self` to reach `self.layers` /
1590        // `self.kv_caches`, which would conflict with an outstanding
1591        // `&mut self.scratch.residual`. Use Option::take to move it out
1592        // (no placeholder alloc → no transient cuMemFreeAsync that could
1593        // corrupt stream pool state after graph ops on Blackwell).
1594        let mut residual = self
1595            .scratch
1596            .residual
1597            .take()
1598            .expect("scratch residual missing (previous call didn't restore)");
1599        let embed = self
1600            .embed
1601            .as_ref()
1602            .expect("prefill_internal called on backbone-only model (no embed)");
1603        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1604
1605        let prefill_profile = llama_family_runtime_env().prefill_op_profile;
1606        let prefill_t0 = if prefill_profile {
1607            B::sync(&mut ctx);
1608            Some(std::time::Instant::now())
1609        } else {
1610            None
1611        };
1612
1613        for li in 0..self.cfg.num_layers {
1614            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1615        }
1616
1617        if let Some(t0) = prefill_t0 {
1618            B::sync(&mut ctx);
1619            let total_us = t0.elapsed().as_micros() as u64;
1620            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1621            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1622            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1623            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1624            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1625            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1626            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1627            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1628            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1629            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1630            eprintln!(
1631                "[prefill-profile] tokens={} layers total={} ms",
1632                seq_len,
1633                total_us / 1000
1634            );
1635            let bucket = |label: &str, n: u64, us: u64| {
1636                if n > 0 {
1637                    eprintln!(
1638                        "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
1639                        n,
1640                        us / 1000,
1641                        us / n
1642                    );
1643                }
1644            };
1645            bucket("flash_attn", attn_n, attn_us);
1646            bucket("qk_norm_rope", qkr_n, qkr_us);
1647            bucket("matmuls", mm_n, mm_us);
1648            bucket("norms", norm_n, norm_us);
1649            bucket("other", other_n, other_us);
1650        }
1651
1652        // Take the last token's hidden state: residual[(seq_len-1)*h .. seq_len*h]
1653        B::copy_slice(
1654            &mut ctx,
1655            &residual,
1656            (seq_len - 1) * h,
1657            &mut self.scratch.last_hidden,
1658            0,
1659            h,
1660        );
1661
1662        // Final RMSNorm on the last hidden.
1663        B::rms_norm(
1664            &mut ctx,
1665            &self.scratch.last_hidden,
1666            &self.final_norm_w,
1667            self.cfg.rms_norm_eps,
1668            &mut self.scratch.last_normed,
1669            1,
1670            h,
1671        );
1672
1673        // LM head (m=1 — triggers GEMV on MetalBackend).
1674        let lm_head = self
1675            .lm_head
1676            .as_ref()
1677            .expect("prefill_internal called on backbone-only model (no lm_head)");
1678        lm_head.forward(
1679            &mut ctx,
1680            &self.scratch.last_normed,
1681            &mut self.scratch.logits,
1682            1,
1683        );
1684
1685        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
1686        // buffer's CPU pointer without flushing the command buffer, so the
1687        // GPU must complete all pending work first or we read stale/random
1688        // data. CUDA's to_vec does an internal stream.synchronize, making
1689        // the call redundant there (~50µs/step cost), but correctness on
1690        // Metal requires the explicit flush here.
1691        B::sync(&mut ctx);
1692
1693        // Restore residual into scratch for reuse on the next call.
1694        self.scratch.residual = Some(residual);
1695
1696        B::to_vec(&self.scratch.logits, vocab)
1697    }
1698
1699    /// Decode: process 1 token at position `pos`, return `[vocab_size]` logits.
1700    pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1701        self.ensure_scratch(1);
1702        self.ensure_kv(cache_id);
1703
1704        let h = self.cfg.hidden_size;
1705        let vocab = self.cfg.vocab_size;
1706
1707        // Context creation is cheap (CUDA reuses the process-global stream).
1708        // The captured graph lives in a process-global slot, not on ctx.
1709        let mut ctx = B::new_context();
1710
1711        // Graph capture is opt-in via FERRUM_CUDA_GRAPH=1. Replay is currently
1712        // single-request-only on Blackwell + CUDA 12.8 (see
1713        // docs/phase-e-cuda-status.md). In pure eager mode, we skip the
1714        // per-step device-state memcpy_htod trio entirely.
1715        const GRAPH_WARMUP: usize = 3;
1716        let graph_enabled = llama_family_runtime_env().cuda_graph;
1717
1718        if graph_enabled {
1719            // Refresh device-side dynamic state (token/pos/kv_len) before
1720            // replay — captured graph reads these from device buffers.
1721            B::set_decode_state(&mut ctx, token, pos);
1722
1723            // Fast path: graph replay (if available). Single-item path
1724            // uses key=SINGLE_ITEM_GRAPH_KEY (0) — separate from the
1725            // batched path's m_padded keys.
1726            match B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY) {
1727                Ok(true) => {
1728                    B::sync(&mut ctx);
1729                    return B::to_vec(&self.scratch.logits, vocab);
1730                }
1731                Ok(false) => { /* no graph yet, fall through to eager */ }
1732                Err(_) => { /* backend error or unsupported, eager */ }
1733            }
1734        }
1735
1736        let should_capture =
1737            graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
1738
1739        if should_capture {
1740            B::set_dev_state_mode(&mut ctx, true);
1741            if B::begin_graph_capture(&mut ctx).is_err() {
1742                self.graph_capture_failed = true;
1743                B::set_dev_state_mode(&mut ctx, false);
1744            }
1745        }
1746
1747        // Eager forward (records into graph if capture is active).
1748        // mem::replace needs a placeholder. B::alloc(0) was our choice but
1749        // cuMemAllocFromPoolAsync(stream, 0) can return CUDA_ERROR_INVALID_VALUE
1750        // on Blackwell after graph replay corrupts the pool state. Size-1 is
1751        // always valid and costs 2 bytes of transient VRAM per decode step.
1752        let mut residual = self
1753            .scratch
1754            .residual
1755            .take()
1756            .expect("scratch residual missing (previous call didn't restore)");
1757        let embed = self
1758            .embed
1759            .as_ref()
1760            .expect("decode_internal called on backbone-only model (no embed)");
1761        B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
1762
1763        // Per-layer wall-time profile (env-gated, off by default — adds
1764        // a B::sync between layers which serializes the pipeline). Helps
1765        // localise non-matmul bottlenecks during perf work.
1766        let layer_profile = llama_family_runtime_env().decode_layer_profile;
1767        let mut layer_times = if layer_profile {
1768            Some(Vec::with_capacity(self.cfg.num_layers))
1769        } else {
1770            None
1771        };
1772
1773        for li in 0..self.cfg.num_layers {
1774            if layer_profile {
1775                let t0 = std::time::Instant::now();
1776                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1777                B::sync(&mut ctx);
1778                let elapsed_us = t0.elapsed().as_micros() as u64;
1779                if let Some(v) = layer_times.as_mut() {
1780                    v.push(elapsed_us);
1781                }
1782            } else {
1783                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1784            }
1785        }
1786        if let Some(times) = layer_times.take() {
1787            let sum: u64 = times.iter().sum();
1788            let avg = sum / times.len() as u64;
1789            let mn = *times.iter().min().unwrap_or(&0);
1790            let mx = *times.iter().max().unwrap_or(&0);
1791            eprintln!(
1792                "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
1793                times.len(),
1794                sum / 1000,
1795                avg,
1796                mn,
1797                mx
1798            );
1799            for (i, t) in times.iter().enumerate() {
1800                eprint!("L{i}={}ms ", t / 1000);
1801                if (i + 1) % 6 == 0 {
1802                    eprintln!();
1803                }
1804            }
1805            eprintln!();
1806            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1807            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1808            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1809            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1810            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1811            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1812            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1813            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1814            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
1815            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
1816            eprintln!(
1817                "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
1818                attn_n,
1819                attn_us / 1000,
1820                if attn_n > 0 { attn_us / attn_n } else { 0 }
1821            );
1822            eprintln!(
1823                "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
1824                qkr_n,
1825                qkr_us / 1000,
1826                if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
1827            );
1828            eprintln!(
1829                "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
1830                mm_n,
1831                mm_us / 1000,
1832                if mm_n > 0 { mm_us / mm_n } else { 0 }
1833            );
1834            eprintln!(
1835                "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
1836                norm_n,
1837                norm_us / 1000,
1838                if norm_n > 0 { norm_us / norm_n } else { 0 }
1839            );
1840            eprintln!(
1841                "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
1842                other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
1843            );
1844        }
1845
1846        B::rms_norm(
1847            &mut ctx,
1848            &residual,
1849            &self.final_norm_w,
1850            self.cfg.rms_norm_eps,
1851            &mut self.scratch.last_normed,
1852            1,
1853            h,
1854        );
1855
1856        let lm_head = self
1857            .lm_head
1858            .as_ref()
1859            .expect("decode_internal called on backbone-only model (no lm_head)");
1860        lm_head.forward(
1861            &mut ctx,
1862            &self.scratch.last_normed,
1863            &mut self.scratch.logits,
1864            1,
1865        );
1866
1867        if should_capture && !self.graph_capture_failed {
1868            if B::end_graph_capture(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
1869                self.graph_capture_failed = true;
1870            } else {
1871                // Stream capture mode RECORDS ops into the graph without
1872                // executing them. scratch.logits still holds the previous
1873                // step's value. Replay the just-captured graph once to
1874                // actually execute and produce this step's logits. Without
1875                // this, the capture step's to_vec returns stale logits,
1876                // yielding a 1-token offset in the generated sequence.
1877                if B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
1878                    self.graph_capture_failed = true;
1879                }
1880            }
1881            B::set_dev_state_mode(&mut ctx, false);
1882        } else {
1883            self.graph_warmup += 1;
1884        }
1885
1886        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
1887        // buffer's CPU pointer without flushing the command buffer, so the
1888        // GPU must complete all pending work first or we read stale/random
1889        // data. CUDA's to_vec does an internal stream.synchronize, making
1890        // the call redundant there (~50µs/step cost), but correctness on
1891        // Metal requires the explicit flush here.
1892        B::sync(&mut ctx);
1893        self.scratch.residual = Some(residual);
1894
1895        B::to_vec(&self.scratch.logits, vocab)
1896    }
1897
1898    /// Prefill with pre-computed embeddings instead of token IDs.
1899    ///
1900    /// Used by models that embed inputs outside the LLM (e.g. Qwen3-TTS
1901    /// mixes text-embedding + codec-embedding before feeding the LM).
1902    /// Skips `final_norm` + `lm_head`; returns the last position's pre-norm
1903    /// hidden state. Caller applies its own output head.
1904    ///
1905    /// `embeds` is row-major `[seq_len * hidden_size]`, f32.
1906    pub fn prefill_from_embeds(
1907        &mut self,
1908        cache_id: &str,
1909        embeds: &[f32],
1910        seq_len: usize,
1911    ) -> Vec<f32> {
1912        let h = self.cfg.hidden_size;
1913        assert_eq!(
1914            embeds.len(),
1915            seq_len * h,
1916            "embeds length {} != seq_len * hidden_size {}",
1917            embeds.len(),
1918            seq_len * h
1919        );
1920        assert!(seq_len > 0, "prefill_from_embeds called with zero length");
1921
1922        self.ensure_scratch(seq_len);
1923        self.ensure_kv(cache_id);
1924
1925        let mut ctx = B::new_context();
1926        let mut residual = self
1927            .scratch
1928            .residual
1929            .take()
1930            .expect("scratch residual missing (previous call didn't restore)");
1931
1932        // Upload embeds → residual[0 .. seq_len*h].
1933        let embed_buf = B::from_slice(embeds);
1934        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1935
1936        for li in 0..self.cfg.num_layers {
1937            self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
1938        }
1939
1940        B::copy_slice(
1941            &mut ctx,
1942            &residual,
1943            (seq_len - 1) * h,
1944            &mut self.scratch.last_hidden,
1945            0,
1946            h,
1947        );
1948        B::sync(&mut ctx);
1949        self.scratch.residual = Some(residual);
1950        B::to_vec(&self.scratch.last_hidden, h)
1951    }
1952
1953    /// Decode with a single pre-computed embedding (shape `[hidden]`).
1954    /// Returns the pre-norm hidden state for the position `pos`. Caller
1955    /// applies final norm + its own output head.
1956    pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
1957        let h = self.cfg.hidden_size;
1958        assert_eq!(
1959            embed.len(),
1960            h,
1961            "embed length {} != hidden_size {}",
1962            embed.len(),
1963            h
1964        );
1965
1966        self.ensure_scratch(1);
1967        self.ensure_kv(cache_id);
1968
1969        let mut ctx = B::new_context();
1970        let mut residual = self
1971            .scratch
1972            .residual
1973            .take()
1974            .expect("scratch residual missing (previous call didn't restore)");
1975
1976        let embed_buf = B::from_slice(embed);
1977        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1978
1979        for li in 0..self.cfg.num_layers {
1980            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1981        }
1982
1983        B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
1984        B::sync(&mut ctx);
1985        self.scratch.residual = Some(residual);
1986        B::to_vec(&self.scratch.last_hidden, h)
1987    }
1988
1989    /// Variant of `prefill_from_embeds` that applies `final_norm` to every
1990    /// position and returns the whole `[seq_len * hidden_size]` vector.
1991    /// Accepts `pos_offset` so callers can continue an existing sequence
1992    /// (e.g. Qwen3-TTS voice-clone: one prefill for the role prefix, a
1993    /// follow-up prefill for the reference-audio ICL block, then
1994    /// autoregressive decoding — all against the same KV cache).
1995    ///
1996    /// Used by TTS where `forward_step` in the candle-based wrapper is
1997    /// expected to return **post-norm all-positions** hidden state so
1998    /// `codec_head` can be applied on candle side.
1999    pub fn prefill_all_post_norm(
2000        &mut self,
2001        cache_id: &str,
2002        embeds: &[f32],
2003        seq_len: usize,
2004        pos_offset: usize,
2005    ) -> Vec<f32> {
2006        let h = self.cfg.hidden_size;
2007        assert_eq!(
2008            embeds.len(),
2009            seq_len * h,
2010            "embeds length {} != seq_len * hidden_size {}",
2011            embeds.len(),
2012            seq_len * h
2013        );
2014        assert!(seq_len > 0);
2015
2016        self.ensure_scratch(seq_len);
2017        self.ensure_kv(cache_id);
2018
2019        let mut ctx = B::new_context();
2020        let mut residual = self
2021            .scratch
2022            .residual
2023            .take()
2024            .expect("scratch residual missing (previous call didn't restore)");
2025
2026        let embed_buf = B::from_slice(embeds);
2027        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
2028
2029        for li in 0..self.cfg.num_layers {
2030            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2031        }
2032
2033        // Apply final_norm over all seq_len positions → scratch.norm_out.
2034        B::rms_norm(
2035            &mut ctx,
2036            &residual,
2037            &self.final_norm_w,
2038            self.cfg.rms_norm_eps,
2039            &mut self.scratch.norm_out,
2040            seq_len,
2041            h,
2042        );
2043        B::sync(&mut ctx);
2044        self.scratch.residual = Some(residual);
2045        B::to_vec(&self.scratch.norm_out, seq_len * h)
2046    }
2047
2048    /// Decode-side companion to `prefill_all_post_norm`. Runs a single-token
2049    /// decode step at `pos`, applies `final_norm`, and returns the post-norm
2050    /// hidden state `[hidden_size]`.
2051    pub fn decode_post_norm_from_embed(
2052        &mut self,
2053        cache_id: &str,
2054        embed: &[f32],
2055        pos: u32,
2056    ) -> Vec<f32> {
2057        let h = self.cfg.hidden_size;
2058        assert_eq!(embed.len(), h);
2059
2060        self.ensure_scratch(1);
2061        self.ensure_kv(cache_id);
2062
2063        let mut ctx = B::new_context();
2064        let mut residual = self
2065            .scratch
2066            .residual
2067            .take()
2068            .expect("scratch residual missing (previous call didn't restore)");
2069
2070        let embed_buf = B::from_slice(embed);
2071        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
2072
2073        for li in 0..self.cfg.num_layers {
2074            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2075        }
2076
2077        B::rms_norm(
2078            &mut ctx,
2079            &residual,
2080            &self.final_norm_w,
2081            self.cfg.rms_norm_eps,
2082            &mut self.scratch.last_normed,
2083            1,
2084            h,
2085        );
2086        B::sync(&mut ctx);
2087        self.scratch.residual = Some(residual);
2088        B::to_vec(&self.scratch.last_normed, h)
2089    }
2090}
2091
2092// FP16 DecoderOnlyLLM impl — full path with batched + unified-forward overrides.
2093impl<B: MoeLlmBackend> DecoderOnlyLLM for LlamaFamilyModel<B, KvFp16> {
    /// Returns a reference to the runtime configuration stored on the model
    /// (`self.runtime_cfg`).
    fn config(&self) -> &LlmRuntimeConfig {
        &self.runtime_cfg
    }
2097
2098    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2099        self.ensure_scratch(max_tokens);
2100        self.ensure_kv(cache_id);
2101
2102        const WARMUP_CACHE: &str = "__ferrum_warmup__";
2103        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2104        if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2105            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2106                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2107                if let Some(c0) = caches.first() {
2108                    if !c0.paged_block_indices.is_empty() {
2109                        alloc.free(&c0.paged_block_indices);
2110                    }
2111                }
2112                for c in caches.iter_mut() {
2113                    c.paged_block_indices.clear();
2114                }
2115            }
2116            self.kv_free_pool.push(caches);
2117        }
2118    }
2119
    /// Returns the KV capacity that the runtime-env helper
    /// `kv_capacity_for_model` derives from the model's configured
    /// `max_seq_len`.
    fn kv_capacity(&self) -> usize {
        llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
    }
2123
    /// Trait entry point for prefill; forwards unchanged to
    /// `prefill_internal` and returns its result.
    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
        self.prefill_internal(cache_id, tokens)
    }
2127
    /// Trait entry point for single-token decode; forwards unchanged to
    /// `decode_internal` and returns its result.
    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
        self.decode_internal(cache_id, token, pos)
    }
2131
    /// Trait entry point for batched decode; forwards `batch` unchanged to
    /// `decode_batch_internal` and returns its result.
    /// NOTE(review): the tuple fields look like `(cache_id, token, pos)` by
    /// analogy with `decode` — confirm against `decode_batch_internal`.
    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
        self.decode_batch_internal(batch)
    }
2135
2136    fn unified_forward(
2137        &mut self,
2138        items: &[(String, Vec<u32>, usize, bool)],
2139    ) -> std::result::Result<Vec<Option<Vec<f32>>>, ferrum_types::FerrumError> {
2140        if items.is_empty() {
2141            return Ok(Vec::new());
2142        }
2143        if !B::supports_varlen_qkv() {
2144            return Err(ferrum_types::FerrumError::unsupported(
2145                "LlamaFamilyModel::unified_forward: backend lacks varlen \
2146                 QKV kernels. Engine will fall back to per-item dispatch.",
2147            ));
2148        }
2149        self.ensure_kv(&items[0].0);
2150        if self.paged_pools.is_none() {
2151            return Err(ferrum_types::FerrumError::unsupported(
2152                "LlamaFamilyModel::unified_forward: paged KV required; \
2153                 enable via FERRUM_METAL_PAGED_KV=1 (cross-backend env). \
2154                 Engine will fall back to per-item dispatch.",
2155            ));
2156        }
2157        for (cid, _, _, _) in items {
2158            self.ensure_kv(cid);
2159            if !self.kv_caches.contains_key(cid) {
2160                return Err(ferrum_types::FerrumError::resource_exhausted(format!(
2161                    "paged KV pool exhausted for cache_id={cid:?}; back off"
2162                )));
2163            }
2164        }
2165        Ok(self.unified_forward_internal(items))
2166    }
2167
    /// Trait entry point for verification; forwards to
    /// `LlamaFamilyModel::<B, KvFp16>::forward_verify` using a fully
    /// qualified path.
    // NOTE(review): the qualified path presumably selects an inherent
    // `forward_verify` on the model type rather than this trait method —
    // confirm that inherent method exists, otherwise this would recurse.
    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
        LlamaFamilyModel::<B, KvFp16>::forward_verify(self, cache_id, tokens)
    }
2171
2172    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2173        if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2174            for c in caches.iter_mut() {
2175                if new_len < c.len {
2176                    c.len = new_len;
2177                }
2178            }
2179        }
2180        let mut ctx = B::new_context();
2181        B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2182        self.graph_warmup = 0;
2183        self.graph_capture_failed = false;
2184    }
2185
2186    fn release(&mut self, cache_id: &str) {
2187        let mut ctx = B::new_context();
2188        B::sync(&mut ctx);
2189        self.graph_warmup = 0;
2190        self.graph_capture_failed = false;
2191        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2192            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2193                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2194                if let Some(c0) = caches.first() {
2195                    if !c0.paged_block_indices.is_empty() {
2196                        alloc.free(&c0.paged_block_indices);
2197                    }
2198                }
2199                for c in caches.iter_mut() {
2200                    c.paged_block_indices.clear();
2201                }
2202            }
2203            self.kv_free_pool.push(caches);
2204        }
2205    }
2206
2207    fn reset(&mut self) {
2208        let mut ctx = B::new_context();
2209        B::sync(&mut ctx);
2210        B::reset_all_graphs(&mut ctx);
2211        B::sync(&mut ctx);
2212        self.graph_warmup = 0;
2213        self.graph_capture_failed = false;
2214        self.batched_graph_keys_seen.clear();
2215        self.batched_graph_warmup = 0;
2216        self.batched_graph_failed = false;
2217        self.kv_caches.clear();
2218        self.kv_free_pool.clear();
2219    }
2220}
2221
2222// INT8 DecoderOnlyLLM impl — minimal: no batched / unified-forward overrides
2223// (default trait impl falls back to per-item decode). PR D will add INT8
2224// batched paths once the kernels stabilize.
2225impl<B: MoeLlmBackend + BackendInt8KvOps> DecoderOnlyLLM for LlamaFamilyModel<B, KvInt8> {
2226    fn config(&self) -> &LlmRuntimeConfig {
2227        &self.runtime_cfg
2228    }
2229
2230    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2231        self.ensure_scratch(max_tokens);
2232        self.ensure_kv(cache_id);
2233
2234        const WARMUP_CACHE: &str = "__ferrum_warmup__";
2235        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2236        if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2237            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2238                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2239                if let Some(c0) = caches.first() {
2240                    if !c0.paged_block_indices.is_empty() {
2241                        alloc.free(&c0.paged_block_indices);
2242                    }
2243                }
2244                for c in caches.iter_mut() {
2245                    c.paged_block_indices.clear();
2246                }
2247            }
2248            self.kv_free_pool.push(caches);
2249        }
2250    }
2251
2252    fn kv_capacity(&self) -> usize {
2253        llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
2254    }
2255
2256    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2257        self.prefill_internal(cache_id, tokens)
2258    }
2259
2260    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2261        self.decode_internal(cache_id, token, pos)
2262    }
2263
2264    // decode_batch + unified_forward use trait defaults (per-item fallback).
2265
2266    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2267        LlamaFamilyModel::<B, KvInt8>::forward_verify(self, cache_id, tokens)
2268    }
2269
2270    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2271        if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2272            for c in caches.iter_mut() {
2273                if new_len < c.len {
2274                    c.len = new_len;
2275                }
2276            }
2277        }
2278        let mut ctx = B::new_context();
2279        B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2280        self.graph_warmup = 0;
2281        self.graph_capture_failed = false;
2282    }
2283
2284    fn release(&mut self, cache_id: &str) {
2285        let mut ctx = B::new_context();
2286        B::sync(&mut ctx);
2287        self.graph_warmup = 0;
2288        self.graph_capture_failed = false;
2289        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2290            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2291                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2292                if let Some(c0) = caches.first() {
2293                    if !c0.paged_block_indices.is_empty() {
2294                        alloc.free(&c0.paged_block_indices);
2295                    }
2296                }
2297                for c in caches.iter_mut() {
2298                    c.paged_block_indices.clear();
2299                }
2300            }
2301            self.kv_free_pool.push(caches);
2302        }
2303    }
2304
2305    fn reset(&mut self) {
2306        let mut ctx = B::new_context();
2307        B::sync(&mut ctx);
2308        B::reset_all_graphs(&mut ctx);
2309        B::sync(&mut ctx);
2310        self.graph_warmup = 0;
2311        self.graph_capture_failed = false;
2312        self.kv_caches.clear();
2313        self.kv_free_pool.clear();
2314    }
2315}
2316
2317fn build_rope_cache<B: QuantLlmBackend + BackendMoeFused>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
2318    let hd = cfg.head_dim;
2319    let half = hd / 2;
2320    let max = cfg.max_seq_len;
2321    let mut cos = vec![0.0f32; max * half];
2322    let mut sin = vec![0.0f32; max * half];
2323    for pos in 0..max {
2324        for i in 0..half {
2325            let freq = rope_freq(cfg, i);
2326            let angle = pos as f64 * freq;
2327            cos[pos * half + i] = angle.cos() as f32;
2328            sin[pos * half + i] = angle.sin() as f32;
2329        }
2330    }
2331    RopeCache {
2332        cos: B::from_slice(&cos),
2333        sin: B::from_slice(&sin),
2334    }
2335}
2336
2337fn rope_freq(cfg: &LlamaFamilyConfig, pair_idx: usize) -> f64 {
2338    let base_freq = 1.0f64
2339        / cfg
2340            .rope_theta
2341            .powf((2 * pair_idx) as f64 / cfg.head_dim as f64);
2342    match &cfg.rope_scaling {
2343        Some(RopeScalingConfig::Llama3 {
2344            factor,
2345            low_freq_factor,
2346            high_freq_factor,
2347            original_max_position_embeddings,
2348        }) => scale_llama3_rope_freq(
2349            base_freq,
2350            *factor,
2351            *low_freq_factor,
2352            *high_freq_factor,
2353            *original_max_position_embeddings,
2354        ),
2355        None => base_freq,
2356    }
2357}
2358
/// Applies Llama-3 style frequency scaling to a single RoPE frequency.
///
/// The frequency's wavelength (`2π / freq`) is compared against two
/// thresholds derived from the original context length:
///
/// * shorter than `original_max_position_embeddings / high_freq_factor`:
///   the frequency is returned unchanged;
/// * longer than `original_max_position_embeddings / low_freq_factor`:
///   the frequency is divided by `factor`;
/// * otherwise: a linear blend between the scaled and unscaled frequency,
///   weighted by where the wavelength falls between the two thresholds.
fn scale_llama3_rope_freq(
    freq: f64,
    factor: f64,
    low_freq_factor: f64,
    high_freq_factor: f64,
    original_max_position_embeddings: f64,
) -> f64 {
    let orig_ctx = original_max_position_embeddings;
    let wavelen = std::f64::consts::TAU / freq;

    // High-frequency band: left untouched.
    if wavelen < orig_ctx / high_freq_factor {
        return freq;
    }
    // Low-frequency band: fully scaled down.
    if wavelen > orig_ctx / low_freq_factor {
        return freq / factor;
    }

    // Transition band: `blend` weights the unscaled frequency, and
    // `1 - blend` weights the scaled one.
    let blend = (orig_ctx / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor);
    (1.0 - blend) * freq / factor + blend * freq
}
2379
// Unit tests for `LlamaFamilyRuntimeEnv::from_env_vars`, driven by explicit
// `(name, value)` pairs passed straight into the parser.
#[cfg(test)]
mod tests {
    use super::{LlamaFamilyRuntimeEnv, DEFAULT_KV_CAPACITY};

    /// Well-formed startup knobs are parsed into their typed fields.
    #[test]
    fn llama_family_runtime_env_parses_startup_knobs() {
        let env = LlamaFamilyRuntimeEnv::from_env_vars([
            ("FERRUM_KV_CAPACITY", "4096"),
            ("FERRUM_METAL_PAGED_KV", "0"),
            ("FERRUM_PAGED_MAX_SEQS", "64"),
            ("FERRUM_DECODE_OP_PROFILE", "0"),
            ("FERRUM_PREFILL_OP_PROFILE", ""),
            ("FERRUM_CUDA_GRAPH", ""),
            ("FERRUM_DECODE_LAYER_PROFILE", "false"),
        ]);

        assert_eq!(env.kv_capacity, Some(4096));
        assert_eq!(env.metal_paged_kv, Some(false));
        assert_eq!(env.paged_max_seqs, 64);
        // These four flags are expected to be on even though their values are
        // `0`, empty, or `false`: the test treats them as set-by-presence,
        // unlike `FERRUM_METAL_PAGED_KV` above, whose `0` parses to `false`.
        assert!(env.decode_op_profile);
        assert!(env.prefill_op_profile);
        assert!(env.cuda_graph);
        assert!(env.decode_layer_profile);
        // A 4096 override does not raise the capacity above a 2048 model limit.
        assert_eq!(env.kv_capacity_for_model(2048), 2048);
    }

    /// Unparseable numeric values fall back to defaults instead of failing.
    #[test]
    fn llama_family_runtime_env_uses_defaults_for_invalid_values() {
        let env = LlamaFamilyRuntimeEnv::from_env_vars([
            ("FERRUM_KV_CAPACITY", "bad"),
            ("FERRUM_PAGED_MAX_SEQS", "bad"),
            ("FERRUM_METAL_PAGED_KV", "1"),
        ]);

        assert_eq!(env.kv_capacity, None);
        assert_eq!(env.metal_paged_kv, Some(true));
        // 32 is the value expected when `FERRUM_PAGED_MAX_SEQS` cannot be parsed.
        assert_eq!(env.paged_max_seqs, 32);
        // With no capacity override, a model limit of twice the default still
        // yields `DEFAULT_KV_CAPACITY`.
        assert_eq!(
            env.kv_capacity_for_model(DEFAULT_KV_CAPACITY * 2),
            DEFAULT_KV_CAPACITY
        );
    }
}