Skip to main content

ferrum_models/models/
llama_family.rs

1//! Llama-family decoder model as explicit code.
2//!
3//! Covers all "standard Llama-style GQA + SwiGLU + RoPE" decoders:
4//!   - Llama / Llama-2 / Llama-3  (no QK-norm)
5//!   - Qwen2 / Qwen2.5            (no QK-norm, structurally Llama)
6//!   - Qwen3                      (QK-norm per head, larger rope_theta)
7//!   - Mistral                    (sliding-window attention — not yet
8//!                                 supported on the forward path; will
9//!                                 require an `AttnConfig.sliding_window`
10//!                                 field + shader support in Phase D)
11//!
//! Variant differences are captured by `LlamaFamilyConfig` — chiefly
//! `has_qk_norm` and `rope_theta`, plus `rope_scaling`, `rope_interleaved`
//! and `sliding_window`. Weight loading accepts both fused (`qkv_proj`,
14//! `gate_up_proj`) and split (`q_proj`+`k_proj`+`v_proj`,
15//! `gate_proj`+`up_proj`) projection layouts — the loader (e.g.
16//! `CandleVarBuilderLoader`) fuses split weights on load so model code
17//! sees a uniform `qkv_proj` / `gate_up_proj` Linear.
18
19use std::collections::HashMap;
20use std::sync::{atomic::AtomicU64, OnceLock};
21
22use ferrum_interfaces::kv_dtype::{KvDtypeKind, KvFp16, KvInt8};
23use ferrum_kernels::backend::{
24    Backend, BackendGraph, BackendInt8KvOps, BackendMoeFused, BackendPagedKv, BackendQuantGguf,
25    BackendQuantMarlin, KvCache, KvLayer, LlmBackend, MoeLlmBackend, QuantLlmBackend,
26    MAX_LAYERS_FOR_GRAPH,
27};
28
/// Graph-cache key reserved for the single-item decode path
/// (`decode_internal`). It never collides with the `m_padded`-derived keys
/// registered by the batched path.
pub(crate) const SINGLE_ITEM_GRAPH_KEY: u64 = 0;

// ── Diagnostic counters ────────────────────────────────────────────────
// Process-wide atomics, all starting at zero. This section only declares
// them; the code that bumps / reports them lives elsewhere in the crate.

/// Batched-graph dispatches that replayed a captured graph.
pub(crate) static BATCHED_GRAPH_REPLAY_COUNT: AtomicU64 = AtomicU64::new(0);
/// Batched-graph dispatches that fell back to eager execution.
pub(crate) static BATCHED_GRAPH_EAGER_COUNT: AtomicU64 = AtomicU64::new(0);

// Per-op-bucket profiling pairs. By naming convention `*_TIME_US` is an
// accumulated duration in microseconds and `*_CALLS` the matching call
// count — NOTE(review): confirm against the write sites.
pub(crate) static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_CALLS: AtomicU64 = AtomicU64::new(0);
47use ferrum_quantization::{Linear, WeightLoader};
48use ferrum_types::Result;
49
50use crate::common::paged_pool::block_hash_chain;
51use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
52use crate::lora::{load_runtime_lora_adapter, ActiveLoraAdapter, RuntimeLoraAdapter};
53
/// Fallback KV-cache capacity used by `LlamaFamilyRuntimeEnv::kv_capacity_for_model`
/// when no `FERRUM_KV_CAPACITY` override is in effect; it is still clamped
/// to the `model_max` passed to that method.
const DEFAULT_KV_CAPACITY: usize = 512;
55
/// Snapshot of the `FERRUM_*` environment variables that tune the
/// Llama-family runtime. The process-wide copy is cached by
/// `llama_family_runtime_env`.
#[derive(Clone, Debug, Eq, PartialEq)]
struct LlamaFamilyRuntimeEnv {
    /// `FERRUM_KV_CAPACITY` override; `None` means "use `DEFAULT_KV_CAPACITY`".
    kv_capacity: Option<usize>,
    /// `FERRUM_METAL_PAGED_KV` override; `None` defers to the backend's
    /// `supports_paged_kv()`.
    metal_paged_kv: Option<bool>,
    /// `FERRUM_PAGED_MAX_SEQS` — max concurrent sequences in paged mode.
    paged_max_seqs: usize,
    /// `FERRUM_DECODE_OP_PROFILE`.
    decode_op_profile: bool,
    /// `FERRUM_PREFILL_OP_PROFILE`.
    prefill_op_profile: bool,
    /// `FERRUM_PREFIX_CACHE`.
    prefix_cache: bool,
    /// `FERRUM_CUDA_GRAPH`.
    cuda_graph: bool,
    /// `FERRUM_DECODE_LAYER_PROFILE`.
    decode_layer_profile: bool,
}
67
68impl LlamaFamilyRuntimeEnv {
69    fn from_env() -> Self {
70        Self::from_env_vars(std::env::vars())
71    }
72
73    fn from_env_vars<I, K, V>(vars: I) -> Self
74    where
75        I: IntoIterator<Item = (K, V)>,
76        K: AsRef<str>,
77        V: AsRef<str>,
78    {
79        let mut config = Self {
80            kv_capacity: None,
81            metal_paged_kv: None,
82            paged_max_seqs: 32,
83            decode_op_profile: false,
84            prefill_op_profile: false,
85            prefix_cache: false,
86            cuda_graph: false,
87            decode_layer_profile: false,
88        };
89        for (name, value) in vars {
90            let value = value.as_ref();
91            match name.as_ref() {
92                "FERRUM_KV_CAPACITY" => config.kv_capacity = value.parse::<usize>().ok(),
93                "FERRUM_METAL_PAGED_KV" => config.metal_paged_kv = Some(value != "0"),
94                "FERRUM_PAGED_MAX_SEQS" => {
95                    if let Ok(max_seqs) = value.parse::<usize>() {
96                        config.paged_max_seqs = max_seqs;
97                    }
98                }
99                "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
100                "FERRUM_PREFILL_OP_PROFILE" => config.prefill_op_profile = true,
101                "FERRUM_PREFIX_CACHE" => config.prefix_cache = value == "1",
102                "FERRUM_CUDA_GRAPH" => config.cuda_graph = true,
103                "FERRUM_DECODE_LAYER_PROFILE" => config.decode_layer_profile = true,
104                _ => {}
105            }
106        }
107        config
108    }
109
110    fn kv_capacity_for_model(&self, model_max: usize) -> usize {
111        self.kv_capacity
112            .map(|cap| cap.min(model_max))
113            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
114    }
115
116    fn paged_kv_enabled<B: BackendPagedKv>(&self) -> bool {
117        self.metal_paged_kv
118            .unwrap_or_else(|| B::supports_paged_kv())
119    }
120}
121
122fn llama_family_runtime_env() -> &'static LlamaFamilyRuntimeEnv {
123    static CONFIG: OnceLock<LlamaFamilyRuntimeEnv> = OnceLock::new();
124    CONFIG.get_or_init(LlamaFamilyRuntimeEnv::from_env)
125}
126
/// RoPE frequency-scaling scheme parsed from a checkpoint's `rope_scaling`
/// block. Only the Llama-3 scheme is modelled today.
#[derive(Debug, Clone, PartialEq)]
pub enum RopeScalingConfig {
    /// Meta Llama 3.1/3.2/3.3 long-context RoPE scaling.
    Llama3 {
        /// Mirrors `rope_scaling.factor`.
        factor: f64,
        /// Mirrors `rope_scaling.low_freq_factor`.
        low_freq_factor: f64,
        /// Mirrors `rope_scaling.high_freq_factor`.
        high_freq_factor: f64,
        /// Mirrors `rope_scaling.original_max_position_embeddings`.
        original_max_position_embeddings: f64,
    },
}

impl RopeScalingConfig {
    /// Llama-3 scaling with Ferrum's built-in constants: `factor = 8`,
    /// band factors `1` / `4`, original context 8192 (the values shipped in
    /// Llama 3.1 configs; other releases may differ).
    pub fn llama3_default() -> Self {
        let (factor, low_freq_factor, high_freq_factor) = (8.0, 1.0, 4.0);
        let original_max_position_embeddings = 8192.0;
        Self::Llama3 {
            factor,
            low_freq_factor,
            high_freq_factor,
            original_max_position_embeddings,
        }
    }
}
148
149/// Full Qwen3 architecture config (everything the model code needs, not just
150/// the engine-facing subset in `LlmRuntimeConfig`).
151#[derive(Clone, Debug, PartialEq)]
152pub struct LlamaFamilyConfig {
153    pub hidden_size: usize,
154    pub intermediate_size: usize,
155    pub num_heads: usize,
156    pub num_kv_heads: usize,
157    pub head_dim: usize,
158    pub num_layers: usize,
159    pub vocab_size: usize,
160    pub max_seq_len: usize,
161    pub rms_norm_eps: f32,
162    pub rope_theta: f64,
163    pub rope_scaling: Option<RopeScalingConfig>,
164    /// GGUF LLaMA stores Q/K in the llama.cpp interleaved RoPE layout.
165    /// HF safetensors Qwen/LLaMA definitions keep the default half-split
166    /// GPT-NeoX layout used by existing Ferrum kernels.
167    pub rope_interleaved: bool,
168    /// Whether the checkpoint has `q_norm` / `k_norm` per layer. All known
169    /// Qwen3 checkpoints do; some derivatives may strip them.
170    pub has_qk_norm: bool,
171    /// Sliding-window attention size. `0` disables (full causal).
172    /// Mistral v0.1 sets 4096; Mistral v0.2+ removed the limitation (0).
173    pub sliding_window: usize,
174}
175
176impl LlamaFamilyConfig {
177    pub fn to_runtime(&self) -> LlmRuntimeConfig {
178        LlmRuntimeConfig {
179            hidden_size: self.hidden_size,
180            num_layers: self.num_layers,
181            num_kv_heads: self.num_kv_heads,
182            head_dim: self.head_dim,
183            vocab_size: self.vocab_size,
184            max_seq_len: self.max_seq_len,
185        }
186    }
187
188    /// Build config from a `ModelDefinition`, shared field extraction.
189    /// Variant-specific constructors below set `has_qk_norm` and fall back
190    /// to different `rope_theta` defaults when the checkpoint doesn't set one.
191    fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
192        let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
193        let head_dim = def
194            .extra_params
195            .get("head_dim")
196            .and_then(|v| v.as_u64())
197            .map(|v| v as usize)
198            .unwrap_or(def.hidden_size / def.num_attention_heads);
199        // Mistral / Gemma: "sliding_window" may be null (v0.2+) or a positive
200        // integer (v0.1). Non-null value passes through; missing/null → 0.
201        let sliding_window = def
202            .extra_params
203            .get("sliding_window")
204            .and_then(|v| v.as_u64())
205            .map(|v| v as usize)
206            .unwrap_or(0);
207
208        LlamaFamilyConfigBase {
209            hidden_size: def.hidden_size,
210            intermediate_size: def.intermediate_size,
211            num_heads: def.num_attention_heads,
212            num_kv_heads,
213            head_dim,
214            num_layers: def.num_hidden_layers,
215            vocab_size: def.vocab_size,
216            max_seq_len: def.max_position_embeddings,
217            rms_norm_eps: def.norm_eps as f32,
218            rope_theta_opt: def.rope_theta,
219            rope_scaling: rope_scaling_from_model_def(def),
220            sliding_window,
221        }
222    }
223
224    fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
225        Self {
226            hidden_size: b.hidden_size,
227            intermediate_size: b.intermediate_size,
228            num_heads: b.num_heads,
229            num_kv_heads: b.num_kv_heads,
230            head_dim: b.head_dim,
231            num_layers: b.num_layers,
232            vocab_size: b.vocab_size,
233            max_seq_len: b.max_seq_len,
234            rms_norm_eps: b.rms_norm_eps,
235            rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
236            rope_scaling: b.rope_scaling,
237            rope_interleaved: false,
238            has_qk_norm,
239            sliding_window: b.sliding_window,
240        }
241    }
242
243    /// Qwen3: QK-norm on, rope_theta default 1e6.
244    pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
245        Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
246    }
247
248    /// Llama / Llama-2 / Llama-3: no QK-norm; rope_theta varies by version
249    /// (10k for Llama-2, 500k for Llama-3.1+) — use the checkpoint value or
250    /// fall back to the most common modern value.
251    pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
252        Self::from_base(Self::from_def_base(def), 500_000.0, false)
253    }
254
255    /// Qwen2 / Qwen2.5: structurally Llama; no QK-norm; rope_theta default 1e6.
256    pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
257        Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
258    }
259
260    /// Mistral: no QK-norm; `rope_theta` commonly 10_000 (v0.1 / v0.2).
261    /// Picks up `sliding_window` from the checkpoint's config.json
262    /// (Mistral v0.1: 4096; Mistral v0.2+: 0 / null).
263    pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
264        Self::from_base(Self::from_def_base(def), 10_000.0, false)
265    }
266}
267
268struct LlamaFamilyConfigBase {
269    hidden_size: usize,
270    intermediate_size: usize,
271    num_heads: usize,
272    num_kv_heads: usize,
273    head_dim: usize,
274    num_layers: usize,
275    vocab_size: usize,
276    max_seq_len: usize,
277    rms_norm_eps: f32,
278    rope_theta_opt: Option<f64>,
279    rope_scaling: Option<RopeScalingConfig>,
280    sliding_window: usize,
281}
282
283fn rope_scaling_from_model_def(
284    def: &crate::definition::ModelDefinition,
285) -> Option<RopeScalingConfig> {
286    let value = def.extra_params.get("rope_scaling")?;
287    let obj = value.as_object()?;
288    let rope_type = obj
289        .get("rope_type")
290        .or_else(|| obj.get("type"))
291        .and_then(|v| v.as_str())?;
292    if rope_type != "llama3" {
293        return None;
294    }
295    let factor = json_f64(obj.get("factor"))?;
296    let low_freq_factor = json_f64(obj.get("low_freq_factor"))?;
297    let high_freq_factor = json_f64(obj.get("high_freq_factor"))?;
298    let original_max_position_embeddings = json_f64(obj.get("original_max_position_embeddings"))
299        .or_else(|| {
300            def.extra_params
301                .get("original_max_position_embeddings")
302                .and_then(|v| json_f64(Some(v)))
303        })
304        .unwrap_or(8192.0);
305    if factor <= 0.0
306        || low_freq_factor <= 0.0
307        || high_freq_factor <= low_freq_factor
308        || original_max_position_embeddings <= 0.0
309    {
310        return None;
311    }
312    Some(RopeScalingConfig::Llama3 {
313        factor,
314        low_freq_factor,
315        high_freq_factor,
316        original_max_position_embeddings,
317    })
318}
319
320fn json_f64(value: Option<&serde_json::Value>) -> Option<f64> {
321    match value? {
322        serde_json::Value::Number(n) => n.as_f64(),
323        _ => None,
324    }
325}
326
327/// Per-layer weights. `Box<dyn Linear<B>>` means each projection can be
328/// Dense / GPTQ / AWQ / GGUF without the surrounding code caring.
329pub struct LlamaFamilyLayer<B: QuantLlmBackend + BackendMoeFused> {
330    pub input_ln_w: B::Buffer,
331    pub qkv_proj: Box<dyn Linear<B>>,
332    /// QK-norm weight per head: `[head_dim]`. Optional for non-Qwen3 derivatives.
333    pub q_norm_w: Option<B::Buffer>,
334    pub k_norm_w: Option<B::Buffer>,
335    pub o_proj: Box<dyn Linear<B>>,
336    pub post_ln_w: B::Buffer,
337    pub gate_up_proj: Box<dyn Linear<B>>,
338    pub down_proj: Box<dyn Linear<B>>,
339}
340
341/// Precomputed RoPE cos/sin tables (shape `[max_seq, head_dim / 2]` each).
342pub struct RopeCache<B: QuantLlmBackend + BackendMoeFused> {
343    pub cos: B::Buffer,
344    pub sin: B::Buffer,
345}
346
347/// Reusable per-layer scratch buffers sized for `max_tokens` tokens of a
348/// single forward pass (prefill or decode step).
349///
350/// Sized lazily on first use so tiny decode steps don't pay for prefill-sized
351/// buffers. Grows monotonically when a larger prefill arrives.
352pub struct LlamaFamilyScratch<B: QuantLlmBackend + BackendMoeFused> {
353    /// Residual stream — wrapped in Option so decode_internal can
354    /// `.take()` it without needing an alloc placeholder.
355    ///
356    /// Why this matters for graph capture: the old pattern was
357    /// `mem::replace(&mut scratch.residual, B::alloc(1))` which creates a
358    /// 1-element buffer at every decode step. When graph capture is on,
359    /// that alloc-during-capture + drop-after-capture pair surfaces as
360    /// cuMemFreeAsync(INVALID_VALUE) because the free tries to release a
361    /// pointer the captured graph may still reference. Option::take leaves
362    /// None and moves the real buffer into a local — no spurious alloc.
363    pub residual: Option<B::Buffer>,
364    pub norm_out: B::Buffer,
365    pub qkv_out: B::Buffer,
366    // ── Per-item scratch for batched decode path ──────────────────────
367    // decode_batch_internal runs tokens=M batched ops for the GEMM-heavy
368    // half (norm, qkv_proj, split_qkv, o_proj, post_norm, gate_up, silu,
369    // down, residual_add) but must loop per-item for rope + KV append +
370    // attention (each item has its own KV cache at a different kv_len).
371    // These single-item buffers hold item i's slice during that loop.
372    /// Item-scope q_buf slice, sized `q_dim`.
373    pub q_single: B::Buffer,
374    pub k_single: B::Buffer,
375    pub v_single: B::Buffer,
376    pub q_head_major_single: B::Buffer,
377    pub k_head_major_single: B::Buffer,
378    pub v_head_major_single: B::Buffer,
379    pub attn_head_major_single: B::Buffer,
380    pub attn_flat_single: B::Buffer,
381    /// Batched logits output, sized `max_tokens * vocab_size`. Used only
382    /// in decode_batch; prefill/single-decode use the regular `logits`.
383    pub batch_logits: B::Buffer,
384    /// Token-major Q/K/V right after `split_qkv`. Stride: heads * hd per row.
385    pub q_buf: B::Buffer,
386    pub k_buf: B::Buffer,
387    pub v_buf: B::Buffer,
388    /// Head-major Q produced by `qk_norm_rope` — fed into `flash_attention`.
389    pub q_head_major: B::Buffer,
390    /// Head-major K/V staging — produced by `qk_norm_rope`, consumed by
391    /// `kv_cache_append_head_major` (no reuse after append).
392    pub k_head_major: B::Buffer,
393    pub v_head_major: B::Buffer,
394    /// Head-major attention output from `flash_attention`.
395    pub attn_head_major_out: B::Buffer,
396    /// Token-major attention output after `transpose_head_to_token`.
397    pub attn_flat: B::Buffer,
398    pub o_proj_out: B::Buffer,
399    pub gate_up_out: B::Buffer,
400    pub silu_out: B::Buffer,
401    pub mlp_out: B::Buffer,
402    /// Paged batched dispatch scratch (Phase 4b). Sized for
403    /// `FERRUM_PAGED_MAX_SEQS × q_dim` so multi-seq decode can fan
404    /// in M items' Q into a single buffer for one batched
405    /// `paged_decode_attention(num_seqs=M)` call. `None` when paged
406    /// mode is off.
407    pub paged_batch_q: Option<B::Buffer>,
408    pub paged_batch_o: Option<B::Buffer>,
409    /// Stacked per-seq block tables for batched paged dispatch.
410    /// Layout: `[max_M, max_blocks_per_seq]` u32. Written
411    /// host-side per decode_batch step.
412    pub paged_batch_block_tables: Option<B::Buffer>,
413    /// Stacked per-seq context lengths for batched paged dispatch
414    /// (`[max_M]` u32).
415    pub paged_batch_context_lens: Option<B::Buffer>,
416    /// `max_blocks_per_seq` value baked into the stacked block_tables
417    /// stride. Set when `paged_batch_block_tables` is allocated.
418    pub paged_max_blocks_per_seq: usize,
419    /// Engine-side max concurrent sequences (= `FERRUM_PAGED_MAX_SEQS`).
420    /// Pinned at the first `enable_paged_batch` so the unified-forward
421    /// scratch sizes (`unified_cu_seqlens_q`, `unified_pos_offsets`,
422    /// `unified_block_tables`, `unified_packed_*`) are big enough for
423    /// any subsequent batch up to that bound.
424    pub paged_max_seqs: usize,
425    /// Per-item RoPE positions for the batched-decode path (`[max_M]`
426    /// i32 / u32). Written host-side once per batched-decode step from
427    /// each request's `pos` field, read by the batched
428    /// `qk_norm_rope_batched_per_item` CUDA kernel.
429    pub batch_positions: B::Buffer,
430    /// Per-item input token ids for the batched-decode path (`[max_M]`
431    /// u32). Written once per call before forward; read by the
432    /// graph-capture-friendly `embedding_lookup_batched_dyn` variant.
433    pub batch_tokens: B::Buffer,
434    /// Per-item KV-cache length BEFORE this step's kv_append
435    /// (`[max_M]` u32). Used by `kv_cache_append_batched_per_cache_dyn`
436    /// to write at the right slot for graph replay.
437    pub batch_kv_lens_pre: B::Buffer,
438    /// Per-item KV-cache length AFTER this step's kv_append
439    /// (`[max_M]` u32 = pre + 1). Used by
440    /// `flash_attention_batched_per_cache_dyn` for the attention
441    /// reduce window length.
442    pub batch_kv_lens_post: B::Buffer,
443    /// Output buffers for the batched per-item qk_norm_rope kernel.
444    /// Same shape as q_buf / k_buf / v_buf — separate so the kernel
445    /// API can take `&input` and `&mut output` without aliasing.
446    pub q_normed_batched: B::Buffer,
447    pub k_normed_batched: B::Buffer,
448    pub v_normed_batched: B::Buffer,
449
450    // ── Unified mixed-batch scratch (chunked-prefill path; Step 5b) ─────
451    // Buffers sized for `M_total = sum(items[i].q_tokens.len())`. Grown
452    // on demand by `ensure_unified_scratch(M_total_max)`. Used only by
453    // the new `unified_forward_internal`; legacy `forward_layer_batched_decode`
454    // continues to use the per-item-stride scratch above.
455    pub unified_capacity: usize, // current allocated M_total slots
456    pub unified_residual: Option<B::Buffer>,
457    pub unified_norm_out: Option<B::Buffer>,
458    pub unified_qkv_out: Option<B::Buffer>,
459    pub unified_packed_q: Option<B::Buffer>,
460    pub unified_attn_out: Option<B::Buffer>,
461    pub unified_o_proj_out: Option<B::Buffer>,
462    pub unified_gate_up_out: Option<B::Buffer>,
463    pub unified_silu_out: Option<B::Buffer>,
464    pub unified_mlp_out: Option<B::Buffer>,
465    /// Per-item index buffers (i32-stored-as-f16): cu_seqlens_q is
466    /// length `max_seqs+1`, pos_offsets is `max_seqs`, block_tables is
467    /// `max_seqs * max_blocks_per_seq`. Sized once at first use to
468    /// match `paged_batch_*` capacity since they share `max_seqs`.
469    pub unified_cu_seqlens_q: Option<B::Buffer>,
470    pub unified_pos_offsets: Option<B::Buffer>,
471    pub unified_block_tables: Option<B::Buffer>,
472    /// Packed last-token hidden states for is_final_chunk items
473    /// (`[num_sampled, h]`). Used as input to lm_head.
474    pub unified_packed_normed: Option<B::Buffer>,
475    /// Packed logits output (`[num_sampled, vocab]`).
476    pub unified_packed_logits: Option<B::Buffer>,
477    /// Last token's hidden state (`[h]`). For prefill this is populated via
478    /// `copy_slice(residual, (seq_len-1)*h, ..)`; for decode `residual` already
479    /// holds only 1 row so `last_hidden` is unused on that path.
480    pub last_hidden: B::Buffer,
481    /// Final-norm output for the last token (`[h]`).
482    pub last_normed: B::Buffer,
483    /// lm_head logits (`[vocab]`).
484    pub logits: B::Buffer,
485    /// The max tokens-per-step this scratch has been sized for.
486    pub max_tokens: usize,
487}
488
489impl<B: QuantLlmBackend + BackendMoeFused> LlamaFamilyScratch<B> {
490    fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
491        let h = cfg.hidden_size;
492        let im = cfg.intermediate_size;
493        let q_dim = cfg.num_heads * cfg.head_dim;
494        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
495        let qkv_dim = q_dim + 2 * kv_dim;
496        let t = max_tokens;
497        Self {
498            residual: Some(B::alloc(t * h)),
499            norm_out: B::alloc(t * h),
500            qkv_out: B::alloc(t * qkv_dim),
501            q_buf: B::alloc(t * q_dim),
502            k_buf: B::alloc(t * kv_dim),
503            v_buf: B::alloc(t * kv_dim),
504            q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
505            k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
506            v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
507            attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
508            attn_flat: B::alloc(t * q_dim),
509            o_proj_out: B::alloc(t * h),
510            gate_up_out: B::alloc(t * 2 * im),
511            silu_out: B::alloc(t * im),
512            mlp_out: B::alloc(t * h),
513            last_hidden: B::alloc(h),
514            last_normed: B::alloc(h),
515            logits: B::alloc(cfg.vocab_size),
516            q_single: B::alloc(q_dim),
517            k_single: B::alloc(kv_dim),
518            v_single: B::alloc(kv_dim),
519            q_head_major_single: B::alloc(q_dim),
520            k_head_major_single: B::alloc(kv_dim),
521            v_head_major_single: B::alloc(kv_dim),
522            attn_head_major_single: B::alloc(q_dim),
523            attn_flat_single: B::alloc(q_dim),
524            batch_logits: B::alloc(t * cfg.vocab_size),
525            // Paged batched dispatch scratch. None until `enable_paged_batch`
526            // is called from `ensure_kv` once the model knows max_seqs +
527            // max_blocks_per_seq. This avoids paying the alloc cost when
528            // paged mode is off.
529            paged_batch_q: None,
530            paged_batch_o: None,
531            paged_batch_block_tables: None,
532            paged_batch_context_lens: None,
533            paged_max_blocks_per_seq: 0,
534            paged_max_seqs: 0,
535            batch_positions: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
536            batch_tokens: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
537            batch_kv_lens_pre: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
538            batch_kv_lens_post: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
539            q_normed_batched: B::alloc(t * q_dim),
540            k_normed_batched: B::alloc(t * kv_dim),
541            v_normed_batched: B::alloc(t * kv_dim),
542            unified_capacity: 0,
543            unified_residual: None,
544            unified_norm_out: None,
545            unified_qkv_out: None,
546            unified_packed_q: None,
547            unified_attn_out: None,
548            unified_o_proj_out: None,
549            unified_gate_up_out: None,
550            unified_silu_out: None,
551            unified_mlp_out: None,
552            unified_cu_seqlens_q: None,
553            unified_pos_offsets: None,
554            unified_block_tables: None,
555            unified_packed_normed: None,
556            unified_packed_logits: None,
557            max_tokens: t,
558        }
559    }
560
561    /// Grow unified-path scratch buffers to accommodate `m_total` query
562    /// tokens. Called lazily from `unified_forward_internal` so single-
563    /// path workloads (no chunked prefill) don't pay the alloc cost.
564    pub(crate) fn ensure_unified_scratch(
565        &mut self,
566        cfg: &LlamaFamilyConfig,
567        m_total: usize,
568        max_seqs: usize,
569        max_blocks_per_seq: usize,
570    ) {
571        if m_total <= self.unified_capacity
572            && self.unified_residual.is_some()
573            && self.unified_cu_seqlens_q.is_some()
574        {
575            return;
576        }
577        let cap = m_total.max(self.unified_capacity).max(1);
578        let h = cfg.hidden_size;
579        let q_dim = cfg.num_heads * cfg.head_dim;
580        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
581        let qkv_dim = q_dim + 2 * kv_dim;
582        let im = cfg.intermediate_size;
583        let v = cfg.vocab_size;
584        self.unified_residual = Some(B::alloc(cap * h));
585        self.unified_norm_out = Some(B::alloc(cap * h));
586        self.unified_qkv_out = Some(B::alloc(cap * qkv_dim));
587        self.unified_packed_q = Some(B::alloc(cap * q_dim));
588        self.unified_attn_out = Some(B::alloc(cap * q_dim));
589        self.unified_o_proj_out = Some(B::alloc(cap * h));
590        self.unified_gate_up_out = Some(B::alloc(cap * 2 * im));
591        self.unified_silu_out = Some(B::alloc(cap * im));
592        self.unified_mlp_out = Some(B::alloc(cap * h));
593        if self.unified_cu_seqlens_q.is_none() {
594            self.unified_cu_seqlens_q = Some(B::alloc_typed(
595                ferrum_kernels::backend::Dtype::U32,
596                max_seqs + 1,
597            ));
598            self.unified_pos_offsets = Some(B::alloc_typed(
599                ferrum_kernels::backend::Dtype::U32,
600                max_seqs,
601            ));
602            self.unified_block_tables = Some(B::alloc_typed(
603                ferrum_kernels::backend::Dtype::U32,
604                max_seqs * max_blocks_per_seq,
605            ));
606            self.unified_packed_normed = Some(B::alloc(max_seqs * h));
607            self.unified_packed_logits = Some(B::alloc(max_seqs * v));
608        }
609        self.unified_capacity = cap;
610    }
611
612    /// Allocate scratch for batched paged dispatch (Phase 4b). Called
613    /// lazily from `ensure_kv` once paged mode is enabled and we know
614    /// the pool dimensions. Idempotent.
615    fn enable_paged_batch(
616        &mut self,
617        cfg: &LlamaFamilyConfig,
618        max_seqs: usize,
619        max_blocks_per_seq: usize,
620    ) {
621        if self.paged_batch_q.is_some() {
622            return;
623        }
624        let q_dim = cfg.num_heads * cfg.head_dim;
625        self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
626        self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
627        self.paged_batch_block_tables = Some(B::alloc_typed(
628            ferrum_kernels::backend::Dtype::U32,
629            max_seqs * max_blocks_per_seq,
630        ));
631        self.paged_batch_context_lens = Some(B::alloc_typed(
632            ferrum_kernels::backend::Dtype::U32,
633            max_seqs,
634        ));
635        self.paged_max_blocks_per_seq = max_blocks_per_seq;
636        self.paged_max_seqs = max_seqs;
637    }
638}
639
640/// Qwen3 model — decoder-only LLM, one per (backend, weights) combination.
641///
642/// Holds all parameters, scratch space, RoPE cache, and per-sequence KV caches.
643///
644/// `B: BackendGraph + BackendQuantMarlin + BackendQuantGguf` because the decode hot path uses CUDA Graph capture/replay
645/// when the backend supports it; non-graph backends (Metal/CPU) inherit no-op
646/// defaults, so this bound is satisfied by every concrete `Backend`.
647///
648/// `K: KvDtypeKind = KvFp16` (Dim 5): selects the KV cache element type.
649/// `K = KvFp16` constructs only `LayerKvCache::Fp16(...)` variants;
650/// `K = KvInt8` constructs only `LayerKvCache::Int8(...)` variants. The
651/// `B: BackendKvDtype<KvInt8>` bound is structurally required by the
652/// enum and is satisfied by all backends via stub impls (Cpu/Metal) or
653/// the real impl (CUDA).
654pub struct LlamaFamilyModel<B: MoeLlmBackend, K: KvLayer<B> = KvFp16> {
655    pub cfg: LlamaFamilyConfig,
656    pub runtime_cfg: LlmRuntimeConfig,
657
658    /// Token embedding table. `None` for backbone-only models (e.g. the
659    /// Qwen3-TTS Talker, which embeds inputs externally and feeds via
660    /// `prefill_from_embeds`).
661    pub embed: Option<B::Buffer>,
662    pub layers: Vec<LlamaFamilyLayer<B>>,
663    pub final_norm_w: B::Buffer,
664    /// LM output head. `None` for backbone-only models.
665    pub lm_head: Option<Box<dyn Linear<B>>>,
666
667    pub rope: RopeCache<B>,
668    pub scratch: LlamaFamilyScratch<B>,
669
670    /// Per-sequence KV caches, one `Vec<KvCache<B>>` of length `num_layers`.
671    ///
672    /// Two layouts overlay this same map:
673    /// - **Contiguous mode** (default): each cache holds its own
674    ///   `[num_kv_heads, capacity, head_dim]` k/v buffers.
675    /// - **Paged mode** (`FERRUM_METAL_PAGED_KV=1`): k/v are unused
676    ///   placeholders; the real K/V live in [`Self::paged_pools`] and
677    ///   the cache's `block_table` + `context_lens` index into them.
678    pub kv_caches: HashMap<String, Vec<K::Layer>>,
679    /// Free pool of pre-allocated KV cache slots. Released caches return
680    /// here instead of being dropped, so their device pointers stay valid
681    /// across requests — critical for graph capture (pointers baked into
682    /// the captured graph would otherwise dangle).
683    kv_free_pool: Vec<Vec<K::Layer>>,
684
685    // ── Paged-KV multi-seq state (Phase 4) ─────────────────────────────
686    //
687    // Only populated when `FERRUM_METAL_PAGED_KV=1`. When set, every
688    // `kv_caches` entry becomes a "view" into the shared pool: its
689    // `k` / `v` buffers are placeholders; reads / writes go through
690    // `paged_pools[layer].k` / `.v` indexed via the cache's
691    // `block_table`. Multiple cache_ids share the same pool, with
692    // disjoint physical block sets owned by `paged_block_alloc`.
693    //
694    /// Shared K/V pools, one per layer. Sized at model load time for the
695    /// configured `MAX_RUNNING_REQUESTS × max_blocks_per_seq` blocks.
696    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
697    /// Block allocator hands out physical block indices from the pool.
698    /// `Mutex` because the engine can call `ensure_kv` / `release_kv`
699    /// from multiple threads in concurrent serving.
700    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
701    /// Paged-batch dispatch dimensions `(max_seqs, max_blocks_per_seq)`,
702    /// pinned at the first `ensure_kv` when paged-KV is on. Stored on
703    /// the model (not on scratch) so `ensure_scratch`'s realloc can
704    /// re-call `enable_paged_batch` after wiping scratch's
705    /// `paged_batch_block_tables` / `paged_batch_q` etc.
706    pub paged_dims: Option<(usize, usize)>,
707
708    // ── Graph capture state (CUDA only; harmless no-op on other backends) ──
709    /// Count of eager decode steps run so far. After `GRAPH_WARMUP`, the
710    /// next step captures the decode flow as a graph.
711    pub(crate) graph_warmup: usize,
712    /// True if capture was attempted but failed (e.g. backend doesn't
713    /// support graph capture). Stops further attempts, falls back to eager.
714    pub(crate) graph_capture_failed: bool,
715    /// Same warmup counter for the batched-decode path.
716    pub(crate) batched_graph_warmup: usize,
717    /// True if batched graph capture failed.
718    pub(crate) batched_graph_failed: bool,
719    /// Set of `m_padded` values (as u64 graph keys) for which a batched
720    /// graph has been captured. Multi-slot via cuda.rs's HashMap-keyed
721    /// graph cache — different batch shapes don't thrash a single slot.
722    pub(crate) batched_graph_keys_seen: std::collections::HashSet<u64>,
723    /// Cache IDs for which device-pointer scratch is currently populated.
724    /// Populate only re-runs when the batch composition changes (new
725    /// requests joined / requests finished). Hot-path optimization:
726    /// avoids 3 sync cuMemcpyHtoD_v2's per decode token (~5% TPOT).
727    pub(crate) batched_pointers_for: Option<Vec<String>>,
728    /// CUDA-graph state for the unified_forward path. Mirrors the
729    /// `batched_graph_*` triple but keyed on `(m_total, num_seqs)`
730    /// so different concurrency levels each get their own cached
731    /// graph instead of thrashing a single slot.
732    pub(crate) unified_graph_warmup: usize,
733    pub(crate) unified_graph_failed: bool,
734    pub(crate) unified_graph_keys_seen: std::collections::HashSet<u64>,
735
736    // ── Real paged-KV prefix cache counters ─────────────────────────────
737    prefix_cache_hits: u64,
738    prefix_cache_misses: u64,
739    prefix_cache_saved_prefill_tokens: u64,
740
741    // ── Startup LoRA runtime state ──────────────────────────────────────
742    lora_adapters: HashMap<String, RuntimeLoraAdapter<B>>,
743    lora_cache_adapters: HashMap<String, String>,
744    lora_projection_applications: u64,
745}
746
747impl<B: MoeLlmBackend, K: KvLayer<B>> LlamaFamilyModel<B, K> {
748    /// Build a Qwen3 model from weights provided by the loader.
749    ///
750    /// The loader decides per-projection whether to instantiate DenseLinear,
751    /// GptqLinear, etc. — this code doesn't care.
752    pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
753        // Invalidate any graph from a previously-loaded model. The captured
754        // graph references the old model's scratch buffers; a fresh model
755        // gets fresh scratch, so reusing the graph would read/write freed
756        // pointers. Matters for test suites where multiple models coexist.
757        {
758            let mut ctx = B::new_context();
759            B::reset_all_graphs(&mut ctx);
760        }
761        let rope = build_rope_cache::<B>(&cfg);
762        let scratch = LlamaFamilyScratch::alloc(&cfg, 1); // decode-sized; prefill resizes
763
764        // Embedding: plain tensor (no projection math, just lookup).
765        let embed = loader.load_tensor("model.embed_tokens.weight")?;
766
767        // Per-layer weights.
768        let mut layers = Vec::with_capacity(cfg.num_layers);
769        for li in 0..cfg.num_layers {
770            let prefix = format!("model.layers.{li}");
771            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
772            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
773            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
774            let post_ln_w =
775                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
776            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
777            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
778
779            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
780                let q = loader
781                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
782                    .ok();
783                let k = loader
784                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
785                    .ok();
786                (q, k)
787            } else {
788                (None, None)
789            };
790
791            layers.push(LlamaFamilyLayer {
792                input_ln_w,
793                qkv_proj,
794                q_norm_w,
795                k_norm_w,
796                o_proj,
797                post_ln_w,
798                gate_up_proj,
799                down_proj,
800            });
801        }
802
803        let final_norm_w = loader.load_tensor("model.norm.weight")?;
804
805        // LM head: either dedicated `lm_head.weight` or tied to embedding.
806        // Many models (Qwen3-4B, Llama-3.2-1B, some Qwen2.5) use TIED
807        // embeddings — lm_head shares weights with model.embed_tokens. When
808        // no dedicated lm_head tensor exists, re-load the embed tensor as a
809        // DenseLinear. This duplicates the buffer (memory cost = vocab*h*2
810        // bytes, e.g. ~770MB for Qwen3-4B) but keeps the Linear trait's
811        // owned-weights invariant. Sharing via Arc is a future optimisation.
812        let lm_head = if loader.has_tensor("lm_head.weight") {
813            loader.load_linear("lm_head")?
814        } else {
815            tracing::info!(
816                "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
817            );
818            let as_linear = loader.load_linear("model.embed_tokens")?;
819            // Sanity check: shape must be [vocab, hidden].
820            if as_linear.out_features() != cfg.vocab_size
821                || as_linear.in_features() != cfg.hidden_size
822            {
823                return Err(ferrum_types::FerrumError::model(format!(
824                    "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
825                    as_linear.out_features(),
826                    as_linear.in_features(),
827                    cfg.vocab_size,
828                    cfg.hidden_size
829                )));
830            }
831            as_linear
832        };
833
834        let runtime_cfg = cfg.to_runtime();
835        Ok(Self {
836            cfg,
837            runtime_cfg,
838            embed: Some(embed),
839            layers,
840            final_norm_w,
841            lm_head: Some(lm_head),
842            rope,
843            scratch,
844            kv_caches: HashMap::new(),
845            kv_free_pool: Vec::new(),
846            paged_pools: None,
847            paged_block_alloc: None,
848            paged_dims: None,
849            graph_warmup: 0,
850            graph_capture_failed: false,
851            batched_graph_warmup: 0,
852            batched_graph_failed: false,
853            batched_graph_keys_seen: std::collections::HashSet::new(),
854            batched_pointers_for: None,
855            unified_graph_warmup: 0,
856            unified_graph_failed: false,
857            unified_graph_keys_seen: std::collections::HashSet::new(),
858            prefix_cache_hits: 0,
859            prefix_cache_misses: 0,
860            prefix_cache_saved_prefill_tokens: 0,
861            lora_adapters: HashMap::new(),
862            lora_cache_adapters: HashMap::new(),
863            lora_projection_applications: 0,
864        })
865    }
866
867    /// Build a backbone-only Qwen3 transformer stack (no embed, no lm_head).
868    ///
869    /// Intended for composing the transformer inside a larger model where
870    /// embedding and output-head logic differs from the standard LLM path —
871    /// e.g. Qwen3-TTS Talker uses dual text/codec embeddings with a projection
872    /// MLP, and a codec_head output. The caller drives forward via
873    /// `prefill_from_embeds` / `decode_from_embed`.
874    ///
875    /// Loader must provide: per-layer weights under `model.layers.{i}.*` and
876    /// the final `model.norm.weight`. `model.embed_tokens` and `lm_head`
877    /// are NOT read.
878    pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
879        // See `new` — invalidate stale graph referring to prior model's scratch.
880        {
881            let mut ctx = B::new_context();
882            B::reset_all_graphs(&mut ctx);
883        }
884        let rope = build_rope_cache::<B>(&cfg);
885        let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
886
887        let mut layers = Vec::with_capacity(cfg.num_layers);
888        for li in 0..cfg.num_layers {
889            let prefix = format!("model.layers.{li}");
890            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
891            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
892            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
893            let post_ln_w =
894                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
895            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
896            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
897
898            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
899                let q = loader
900                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
901                    .ok();
902                let k = loader
903                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
904                    .ok();
905                (q, k)
906            } else {
907                (None, None)
908            };
909
910            layers.push(LlamaFamilyLayer {
911                input_ln_w,
912                qkv_proj,
913                q_norm_w,
914                k_norm_w,
915                o_proj,
916                post_ln_w,
917                gate_up_proj,
918                down_proj,
919            });
920        }
921
922        let final_norm_w = loader.load_tensor("model.norm.weight")?;
923
924        let runtime_cfg = cfg.to_runtime();
925        Ok(Self {
926            cfg,
927            runtime_cfg,
928            embed: None,
929            layers,
930            final_norm_w,
931            lm_head: None,
932            rope,
933            scratch,
934            kv_caches: HashMap::new(),
935            kv_free_pool: Vec::new(),
936            paged_pools: None,
937            paged_block_alloc: None,
938            paged_dims: None,
939            graph_warmup: 0,
940            graph_capture_failed: false,
941            batched_graph_warmup: 0,
942            batched_graph_failed: false,
943            batched_graph_keys_seen: std::collections::HashSet::new(),
944            batched_pointers_for: None,
945            unified_graph_warmup: 0,
946            unified_graph_failed: false,
947            unified_graph_keys_seen: std::collections::HashSet::new(),
948            prefix_cache_hits: 0,
949            prefix_cache_misses: 0,
950            prefix_cache_saved_prefill_tokens: 0,
951            lora_adapters: HashMap::new(),
952            lora_cache_adapters: HashMap::new(),
953            lora_projection_applications: 0,
954        })
955    }
956
957    /// Grow scratch buffers if `tokens` exceeds the current sizing.
958    pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
959        if self.scratch.max_tokens < tokens {
960            // Any captured decode graph holds pointers to the old scratch
961            // buffers; those are about to be freed. Invalidate ALL captured
962            // graphs (both single-item and per-m_padded batched) — every
963            // captured kernel-arg pointer into scratch is stale.
964            {
965                let mut ctx = B::new_context();
966                B::reset_all_graphs(&mut ctx);
967            }
968            self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
969            self.graph_warmup = 0;
970            self.graph_capture_failed = false;
971            self.batched_graph_keys_seen.clear();
972            self.batched_graph_warmup = 0;
973            self.batched_graph_failed = false;
974            self.unified_graph_keys_seen.clear();
975            self.unified_graph_warmup = 0;
976            self.unified_graph_failed = false;
977            // Realloc wiped paged_batch_*. Re-enable using the dims
978            // pinned at first ensure_kv. Without this, the next
979            // `forward_layer_batched_decode` panics on
980            // `paged_batch_block_tables missing`.
981            if let Some((max_seqs, max_blocks_per_seq)) = self.paged_dims {
982                self.scratch
983                    .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
984            }
985        }
986    }
987
988    /// Ensure per-layer KV caches exist for `cache_id`, pre-allocated to
989    /// `max_seq_len` slots per head. Enables the in-place
990    /// `kv_cache_append_head_major` path — no realloc per layer.
991    pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
992        if self.kv_caches.contains_key(cache_id) {
993            return;
994        }
995        let nkv = self.cfg.num_kv_heads;
996        let hd = self.cfg.head_dim;
997        // KV capacity defaults to a chat-friendly 4096 to keep the working
998        // set sane on a 32 GB Mac (a 48-layer / 4-kv-head / 128-head_dim
999        // model spends ~786 MB on 4096 KV slots, vs ~6 GB on the model's
1000        // declared 32K which would push the 17 GB Qwen3-30B-A3B model into
1001        // swap). `FERRUM_KV_CAPACITY=N` overrides; clamp to the model's
1002        // declared max so we never lie to the model about its window.
1003        let model_max = self.cfg.max_seq_len;
1004        // 512 in 0.7.2 — matches the value used in
1005        // docs/bench/macos-2026-05-02 to get the published numbers.
1006        // pre-0.7.2 default of 4096 was safe only because paged-KV was
1007        // opt-in (pool wasn't allocated). With paged-KV now on by
1008        // default + MAX_SEQS=32, the pool occupies physical memory:
1009        // ~3 GB on Qwen3-30B-A3B Q4_K_M leaves 18 GB weights + 3 GB pool
1010        // = 21 GB, fits comfortably on a 32 GB Mac. Long-context users
1011        // can `FERRUM_KV_CAPACITY=4096` and accept lower max_seqs.
1012        let runtime_env = llama_family_runtime_env();
1013        let max = runtime_env.kv_capacity_for_model(model_max);
1014
1015        // Paged-KV mode: `FERRUM_METAL_PAGED_KV=1` switches every cache
1016        // for this model into block-table-indirect layout. Kernels from
1017        // PR #68 (decode read) + PR #69 (decode write) handle the
1018        // indirect addressing; the LlamaFamily decode path below picks
1019        // them up automatically by checking `cache.block_size > 0`.
1020        //
1021        // Pool sizing: round capacity up to a multiple of block_size,
1022        // identity-assign logical→physical block. Memory footprint is
1023        // the same as contiguous (within block_size rounding); the
1024        // benefit only shows up under multi-seq sharing in Phase 4+.
1025        // Default ON when the backend supports paged-KV (Metal). Users
1026        // can force off with `FERRUM_METAL_PAGED_KV=0`. The flag was
1027        // opt-in pre-0.7.2; flipping the default so default `ferrum
1028        // serve` matches the bench-quality numbers without requiring
1029        // env-var knowledge.
1030        let paged = runtime_env.paged_kv_enabled::<B>();
1031        const PAGED_BLOCK_SIZE: usize = 16;
1032
1033        // Phase 4 shared-pool sizing. The pool sees ALL concurrent
1034        // sequences; per-cache_id state just owns indices into it.
1035        // Default 32: covers c=16 burst with 2× headroom for the
1036        // fresh-cache-id-per-request pattern that bench/server harnesses
1037        // use. Pool memory is `max_seqs × max_blocks_per_seq` total
1038        // blocks — we lowered DEFAULT_KV_CAPACITY to 2048 so this 2× max_seqs
1039        // bump keeps the pool footprint identical to the pre-0.7.2 default.
1040        let max_seqs = runtime_env.paged_max_seqs;
1041        let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
1042        let total_pool_blocks = max_seqs * max_blocks_per_seq;
1043
1044        // Lazy-allocate the shared paged pools on the FIRST paged
1045        // ensure_kv call. Pools are big — for Llama-8B (8 kv_heads,
1046        // head_dim=128) at 16 seqs × 256 blocks × 16 slots = 65536 KV
1047        // slots: 65536 * 8 * 128 * 4 = 256 MB per layer × 32 layers
1048        // = 8 GB total. Sized this large only because `max_seqs=16`
1049        // is the default; lower it via env to shrink.
1050        if paged && self.paged_pools.is_none() {
1051            let mut pools = Vec::with_capacity(self.cfg.num_layers);
1052            for _ in 0..self.cfg.num_layers {
1053                let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
1054                pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
1055            }
1056            self.paged_pools = Some(pools);
1057            self.paged_block_alloc = Some(std::sync::Mutex::new(
1058                crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
1059            ));
1060        }
1061        // Phase 4b: ensure batched-dispatch scratch is allocated whenever
1062        // paged is on. Idempotent — re-init is a no-op if already
1063        // sized. Has to live outside the `paged_pools.is_none()` branch
1064        // because `ensure_scratch` may have replaced the scratch struct
1065        // since the pools were first allocated.
1066        if paged {
1067            self.scratch
1068                .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
1069            // Pin dims on the model so `ensure_scratch`'s realloc can
1070            // re-init paged_batch_* after wiping scratch.
1071            self.paged_dims = Some((max_seqs, max_blocks_per_seq));
1072        }
1073
1074        // Try pool first — reused buffers have stable device pointers,
1075        // so a captured decode graph can be replayed for this request too.
1076        // K::NAME selects which `LayerKvCache` variant to construct:
1077        // K-aware allocation: K::alloc_paged / K::alloc_contig pick the
1078        // right cache layout (FP16 → KvCache, INT8 → KvCacheQuant) per the
1079        // model's K marker. INT8 supports paged mode only — KvInt8::alloc_contig
1080        // panics, surfacing the misconfiguration here at first ensure_kv.
1081        let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
1082            (0..self.cfg.num_layers)
1083                .map(|_| {
1084                    if paged {
1085                        K::alloc_paged(max_blocks_per_seq, PAGED_BLOCK_SIZE, nkv, hd)
1086                    } else {
1087                        K::alloc_contig(max, nkv, hd)
1088                    }
1089                })
1090                .collect()
1091        });
1092
1093        // Allocate physical blocks for THIS cache_id from the shared
1094        // pool. We allocate all `max_blocks_per_seq` upfront for
1095        // simplicity (matches contig's "pre-alloc to capacity"
1096        // semantics); a smarter Phase 4b can grow on demand to save
1097        // pool occupancy.
1098        if paged {
1099            let alloc_arc = self
1100                .paged_block_alloc
1101                .as_ref()
1102                .expect("paged_block_alloc must be initialised when paged=true");
1103            // Recover from a previously-poisoned mutex instead of panicking
1104            // (poison just means a prior holder panicked; the BlockAllocator
1105            // state is still intact since allocate_n is fail-safe).
1106            let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1107            let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
1108                Ok(idx) => idx,
1109                Err(e) => {
1110                    // Pool exhaustion is a back-pressure signal, not a crash.
1111                    // Drop the lock, return the cache to the free pool, and
1112                    // bail before inserting it into kv_caches. The downstream
1113                    // call will then fail with a clean per-request error
1114                    // ("ensure_kv must be called before ...") instead of
1115                    // dragging every other in-flight request down with it.
1116                    drop(alloc);
1117                    self.kv_free_pool.push(caches);
1118                    eprintln!(
1119                        "[ferrum] paged KV pool exhausted on ensure_kv for \
1120                         cache_id={cache_id:?}: {e}. Increase \
1121                         FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
1122                         throttle concurrent requests.",
1123                    );
1124                    return;
1125                }
1126            };
1127            // Write the block table to each layer's cache. All layers
1128            // share the same logical→physical mapping for this seq.
1129            // Also stash the host-side index list so release_kv can
1130            // return them to the allocator without a device readback.
1131            let mut padded = block_indices.clone();
1132            padded.resize(max_blocks_per_seq, 0);
1133            let mut ctx_tmp = B::new_context();
1134            for c in caches.iter_mut() {
1135                if let Some(bt) = K::block_table_mut(c) {
1136                    B::write_typed::<u32>(&mut ctx_tmp, bt, &padded);
1137                }
1138                *K::paged_block_indices_mut(c) = block_indices.clone();
1139            }
1140            B::sync(&mut ctx_tmp);
1141        }
1142
1143        // Reset logical length; buffers stay. No need to zero the memory —
1144        // the kv_cache_append writes new K/V in place, and attention only
1145        // reads up to `cache_len`.
1146        for c in caches.iter_mut() {
1147            K::set_len(c, 0);
1148            if let Some(cl) = K::context_lens_mut(c) {
1149                let mut ctx_tmp = B::new_context();
1150                B::write_typed::<u32>(&mut ctx_tmp, cl, &[0u32]);
1151                B::sync(&mut ctx_tmp);
1152            }
1153        }
1154        self.kv_caches.insert(cache_id.to_string(), caches);
1155    }
1156
1157    fn record_prefix_cache_probe(&mut self, saved_tokens: usize) {
1158        if saved_tokens > 0 {
1159            self.prefix_cache_hits += 1;
1160            self.prefix_cache_saved_prefill_tokens += saved_tokens as u64;
1161        } else {
1162            self.prefix_cache_misses += 1;
1163        }
1164    }
1165
1166    fn try_acquire_prefix_cache(&mut self, cache_id: &str, tokens: &[u32]) -> usize {
1167        let Some(alloc_arc) = self.paged_block_alloc.as_ref() else {
1168            return 0;
1169        };
1170        let caches = match self.kv_caches.get(cache_id) {
1171            Some(caches) => caches,
1172            None => return 0,
1173        };
1174        let block_size = caches.first().map(K::block_size).unwrap_or(0);
1175        if block_size == 0 {
1176            return 0;
1177        }
1178
1179        let token_ids: Vec<ferrum_types::TokenId> = tokens
1180            .iter()
1181            .map(|&token| ferrum_types::TokenId::new(token))
1182            .collect();
1183        let hashes = block_hash_chain(&token_ids, block_size);
1184        if hashes.is_empty() {
1185            return 0;
1186        }
1187
1188        let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1189        let mut matched = Vec::with_capacity(hashes.len());
1190        for hash in hashes {
1191            match alloc.try_acquire_by_hash(hash) {
1192                Some(block) => matched.push(block),
1193                None => break,
1194            }
1195        }
1196        if matched.is_empty() {
1197            return 0;
1198        }
1199        let n_matched = matched.len();
1200
1201        let displaced = caches
1202            .first()
1203            .map(|cache| K::paged_block_indices(cache)[..n_matched].to_vec())
1204            .unwrap_or_default();
1205        alloc.free(&displaced);
1206        drop(alloc);
1207
1208        let caches_mut = self.kv_caches.get_mut(cache_id).expect("cache present");
1209        let max_blocks = caches_mut
1210            .first()
1211            .map(|cache| K::paged_block_indices(cache).len())
1212            .unwrap_or(0);
1213        let new_len = n_matched * block_size;
1214        let mut ctx = B::new_context();
1215        for cache in caches_mut.iter_mut() {
1216            {
1217                let indices = K::paged_block_indices_mut(cache);
1218                for (idx, &block) in matched.iter().enumerate() {
1219                    indices[idx] = block;
1220                }
1221            }
1222            K::set_len(cache, new_len);
1223            let padded = {
1224                let mut padded = K::paged_block_indices(cache).to_vec();
1225                padded.resize(max_blocks, 0);
1226                padded
1227            };
1228            if let Some(block_table) = K::block_table_mut(cache) {
1229                B::write_typed::<u32>(&mut ctx, block_table, &padded);
1230            }
1231            if let Some(context_lens) = K::context_lens_mut(cache) {
1232                B::write_typed::<u32>(&mut ctx, context_lens, &[new_len as u32]);
1233            }
1234        }
1235        B::sync(&mut ctx);
1236
1237        new_len
1238    }
1239
1240    fn register_prefix_cache(
1241        &mut self,
1242        cache_id: &str,
1243        all_tokens: &[u32],
1244        prior_cached_tokens: usize,
1245    ) {
1246        let Some(alloc_arc) = self.paged_block_alloc.as_ref() else {
1247            return;
1248        };
1249        let caches = match self.kv_caches.get(cache_id) {
1250            Some(caches) => caches,
1251            None => return,
1252        };
1253        let cache0 = match caches.first() {
1254            Some(cache) => cache,
1255            None => return,
1256        };
1257        let block_size = K::block_size(cache0);
1258        if block_size == 0 {
1259            return;
1260        }
1261
1262        let token_ids: Vec<ferrum_types::TokenId> = all_tokens
1263            .iter()
1264            .map(|&token| ferrum_types::TokenId::new(token))
1265            .collect();
1266        let hashes = block_hash_chain(&token_ids, block_size);
1267        if hashes.is_empty() {
1268            return;
1269        }
1270
1271        let start_block = prior_cached_tokens / block_size;
1272        let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1273        for i in start_block..hashes.len().min(K::paged_block_indices(cache0).len()) {
1274            let block_end_token = (i + 1) * block_size;
1275            if block_end_token > K::len(cache0) {
1276                break;
1277            }
1278            alloc.register_block_hash(K::paged_block_indices(cache0)[i], hashes[i]);
1279        }
1280    }
1281
1282    fn prefix_cache_snapshot_json(&self) -> serde_json::Value {
1283        let (entries, block_size) = self
1284            .paged_block_alloc
1285            .as_ref()
1286            .and_then(|alloc| {
1287                let alloc = alloc.lock().ok()?;
1288                let block_size = self
1289                    .kv_caches
1290                    .values()
1291                    .find_map(|layers| layers.first().map(K::block_size))
1292                    .unwrap_or(16);
1293                Some((alloc.hash_table_size() as u64, block_size))
1294            })
1295            .unwrap_or((0, 16));
1296        let bytes_per_entry = (block_size
1297            * self.cfg.num_layers
1298            * self.cfg.num_kv_heads
1299            * self.cfg.head_dim
1300            * K::BYTES_PER_ELEM
1301            * 2) as u64;
1302        serde_json::json!({
1303            "position": "real-kv-reuse",
1304            "source": "llama-family-paged-block-prefix-cache",
1305            "enabled": llama_family_runtime_env().prefix_cache,
1306            "hits": self.prefix_cache_hits,
1307            "misses": self.prefix_cache_misses,
1308            "evictions": 0u64,
1309            "saved_prefill_tokens": self.prefix_cache_saved_prefill_tokens,
1310            "entries": entries,
1311            "bytes": entries.saturating_mul(bytes_per_entry),
1312            "block_size": block_size,
1313            "kv_dtype": K::NAME,
1314        })
1315    }
1316
1317    fn lora_projection_shape(
1318        &self,
1319        layer_index: usize,
1320        target_module: &str,
1321    ) -> Option<(usize, usize)> {
1322        let layer = self.layers.get(layer_index)?;
1323        match target_module {
1324            "qkv_proj" => Some((layer.qkv_proj.in_features(), layer.qkv_proj.out_features())),
1325            "o_proj" => Some((layer.o_proj.in_features(), layer.o_proj.out_features())),
1326            "gate_up_proj" => Some((
1327                layer.gate_up_proj.in_features(),
1328                layer.gate_up_proj.out_features(),
1329            )),
1330            "down_proj" => Some((
1331                layer.down_proj.in_features(),
1332                layer.down_proj.out_features(),
1333            )),
1334            _ => None,
1335        }
1336    }
1337
1338    fn validate_lora_adapter(&self, adapter: &RuntimeLoraAdapter<B>) -> Result<()> {
1339        if adapter.linears.is_empty() {
1340            return Err(ferrum_types::FerrumError::config(format!(
1341                "LoRA adapter {} has no runtime tensors",
1342                adapter.name
1343            )));
1344        }
1345        for linear in &adapter.linears {
1346            let layer_index = linear.layer_index.ok_or_else(|| {
1347                ferrum_types::FerrumError::config(format!(
1348                    "LoRA tensor for target {} must include model.layers.<N> in its tensor name",
1349                    linear.target_module
1350                ))
1351            })?;
1352            let Some((expected_in, expected_out)) =
1353                self.lora_projection_shape(layer_index, &linear.target_module)
1354            else {
1355                return Err(ferrum_types::FerrumError::unsupported(format!(
1356                    "LoRA target {} is not supported by Llama-family runtime; supported targets: qkv_proj, o_proj, gate_up_proj, down_proj",
1357                    linear.target_module
1358                )));
1359            };
1360            if linear.in_features != expected_in || linear.out_features != expected_out {
1361                return Err(ferrum_types::FerrumError::config(format!(
1362                    "LoRA tensor shape mismatch for layer {} target {}: got out={} in={}, expected out={} in={}",
1363                    layer_index,
1364                    linear.target_module,
1365                    linear.out_features,
1366                    linear.in_features,
1367                    expected_out,
1368                    expected_in
1369                )));
1370            }
1371        }
1372        Ok(())
1373    }
1374
1375    fn ensure_lora_adapter_loaded(&mut self, adapter: ActiveLoraAdapter) -> Result<()> {
1376        if self.lora_adapters.contains_key(&adapter.name) {
1377            return Ok(());
1378        }
1379        let runtime = load_runtime_lora_adapter::<B>(&adapter)?;
1380        self.validate_lora_adapter(&runtime)?;
1381        self.lora_adapters.insert(adapter.name.clone(), runtime);
1382        Ok(())
1383    }
1384
1385    fn active_lora_adapter_for_cache(&self, cache_id: &str) -> Option<&RuntimeLoraAdapter<B>> {
1386        let adapter_name = self.lora_cache_adapters.get(cache_id)?;
1387        self.lora_adapters.get(adapter_name)
1388    }
1389
1390    fn active_lora_adapter_ptr_for_cache(
1391        &self,
1392        cache_id: &str,
1393    ) -> Option<*const RuntimeLoraAdapter<B>> {
1394        self.active_lora_adapter_for_cache(cache_id)
1395            .map(|adapter| adapter as *const RuntimeLoraAdapter<B>)
1396    }
1397
1398    fn lora_metrics_snapshot_json(&self) -> serde_json::Value {
1399        serde_json::json!({
1400            "enabled": !self.lora_adapters.is_empty(),
1401            "adapter_count": self.lora_adapters.len() as u64,
1402            "active_cache_bindings": self.lora_cache_adapters.len() as u64,
1403            "projection_applications": self.lora_projection_applications,
1404            "position": "real-inference",
1405            "source": "llama-family-runtime-lora",
1406        })
1407    }
1408
    /// Run one transformer layer. Mutates `residual` in place.
    ///
    /// `pos_offset` is the absolute position of token 0 in this batch.
    /// - Decode passes `pos`.
    /// - Prefill and verify pass the KV cache length before this chunk
    ///   (0 for a fresh prefill; non-zero for chunked prefill or a reused
    ///   prefix).
    ///
    /// `tokens` is the batch size.
    ///
    /// Steps:
    /// 1. Input RMSNorm.
    /// 2. Fused QKV projection, plus the LoRA delta when `cache_id` is bound
    ///    to an adapter.
    /// 3. Fused split / QK-norm / RoPE / KV-cache write.
    /// 4. Attention over the cache.
    /// 5. [`Self::forward_layer_post_attn`] (O-proj + MLP).
    ///
    /// Paged caches (`block_size > 0`) use the `K::paged_*` ops; contiguous
    /// caches use `K::contig_*`.
    ///
    /// Profiling: when `decode_op_profile` is enabled (checked here even when
    /// the caller is prefill or verify), ops are bracketed with `B::sync` and
    /// their wall time is added to the `*_TIME_US` / `*_CALLS` counters.
    /// NOTE(review): the paged branch records no QKR/ATTN timings, and the
    /// LoRA delta is counted inside the QKV "matmul" bucket.
    ///
    /// # Panics
    /// - if `ensure_kv(cache_id)` was not called before this layer;
    /// - if writing `tokens` entries would exceed the KV cache capacity;
    /// - if the cache is paged but `paged_pools` is not allocated;
    /// - if a `K::*` kernel dispatch or a validated LoRA application
    ///   returns an error.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn forward_layer(
        &mut self,
        ctx: &mut B::Context,
        li: usize,
        cache_id: &str,
        residual: &mut B::Buffer,
        pos_offset: usize,
        tokens: usize,
    ) {
        let layer = &self.layers[li];
        let cfg = &self.cfg;
        let h = cfg.hidden_size;
        let nh = cfg.num_heads;
        let nkv = cfg.num_kv_heads;
        let hd = cfg.head_dim;
        // NOTE(review): `im` is not read anywhere in this function (the MLP
        // runs in `forward_layer_post_attn`); expect an unused-variable warning.
        let im = cfg.intermediate_size;
        let eps = cfg.rms_norm_eps;
        let q_dim = nh * hd;
        let kv_dim = nkv * hd;

        // 1. Input RMSNorm
        // When profiling, sync first so the timer measures only this op.
        let _t0 = if llama_family_runtime_env().decode_op_profile {
            B::sync(ctx);
            Some(std::time::Instant::now())
        } else {
            None
        };
        B::rms_norm(
            ctx,
            residual,
            &layer.input_ln_w,
            eps,
            &mut self.scratch.norm_out,
            tokens,
            h,
        );
        if let Some(t0) = _t0 {
            B::sync(ctx);
            NORM_TIME_US.fetch_add(
                t0.elapsed().as_micros() as u64,
                std::sync::atomic::Ordering::Relaxed,
            );
            NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }

        // 2. Fused QKV projection (Linear dispatches to Dense/GPTQ/AWQ/GGUF)
        let _t0 = if llama_family_runtime_env().decode_op_profile {
            B::sync(ctx);
            Some(std::time::Instant::now())
        } else {
            None
        };
        layer.qkv_proj.forward(
            ctx,
            &self.scratch.norm_out,
            &mut self.scratch.qkv_out,
            tokens,
        );
        // Optional LoRA delta on the same input/output buffers. A raw pointer
        // is used so the adapter can be read while `self.scratch` is
        // mutably borrowed.
        if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
            // SAFETY: `lora_adapters` is not mutated during a forward pass;
            // the pointer is used only for this immediate read-only call.
            let applied = unsafe { &*adapter }
                .apply_projection(
                    ctx,
                    li,
                    "qkv_proj",
                    &self.scratch.norm_out,
                    &mut self.scratch.qkv_out,
                    tokens,
                )
                .expect("validated LoRA qkv_proj");
            self.lora_projection_applications += applied as u64;
        }
        if let Some(t0) = _t0 {
            B::sync(ctx);
            MATMUL_TIME_US.fetch_add(
                t0.elapsed().as_micros() as u64,
                std::sync::atomic::Ordering::Relaxed,
            );
            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }

        // 3-5. Fused split-QKV + QK-norm + RoPE + cache-write.
        //
        // Single Metal dispatch replaces the (split_qkv → 3× qk_norm_rope
        // → kv_cache_append_head_major) five-launch chain on the decode
        // hot path. Reads qkv_out once, writes Q to head-major scratch
        // and K/V straight into the pre-allocated KV cache slot at
        // `cache_len + tok`. Saves 4 dispatches per layer when the
        // backend implements the fused kernel; CPU and other backends
        // keep using the unfused chain via the Unsupported fallbacks.
        //
        // qk_mode: 1 = norm + half-split RoPE (Qwen3/Qwen HF);
        //          2 = half-split RoPE only;
        //          3 = interleaved RoPE only (GGUF LLaMA / llama.cpp layout).
        // V always passes apply_norm=0.
        let qk_mode: i32 = if cfg.has_qk_norm {
            1
        } else if cfg.rope_interleaved {
            3
        } else {
            2
        };
        // Models without QK-norm still have to pass *some* weight buffer to
        // the fused kernels. `input_ln_w` stands in; with qk_mode 2/3 it is
        // presumably ignored by the kernel (confirm per backend).
        let dummy = &layer.input_ln_w;
        let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
        let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);

        // Grab the per-layer KV cache up front so the deepest fusion can
        // write K/V straight into it.
        //
        // Paged mode: also need this layer's shared pool buffers
        // (self.paged_pools[li]). The pool is a separate field from
        // kv_caches, so we take a raw pointer to its (k, v) here while
        // we still hold &mut self, then deref via unsafe inside the
        // paged dispatch below. Safety: paged_pools is allocated once
        // and never resized; we don't touch self.paged_pools while the
        // pointer is in use.
        let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
            if let Some(pools) = self.paged_pools.as_mut() {
                let pool = &mut pools[li];
                Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
            } else {
                None
            };
        let caches = self
            .kv_caches
            .get_mut(cache_id)
            .expect("ensure_kv must be called before forward_layer");
        // Read shared metadata (variant-agnostic) once. The K-aware
        // attn section below re-borrows the right enum variant.
        let cache_len_before = K::len(&caches[li]);
        let cache_capacity = K::capacity(&caches[li]);
        let cache_block_size = K::block_size(&caches[li]);

        // Defense in depth: refuse to write past the KV buffer. The
        // graceful path is the caller pre-checking via `kv_capacity()`
        // and either compacting or refusing the request; this panic only
        // fires when that contract is broken (and silent overflow would
        // otherwise corrupt the cache + adjacent allocations).
        if cache_len_before + tokens > cache_capacity {
            panic!(
                "KV cache overflow on layer {li}: would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
                cache_len_before + tokens
            );
        }

        // Paged path: K::paged_write fuses split_qkv_norm_rope + cache append
        // (FP16: into_paged_cache; INT8: split_qkv_norm_rope + int8_kv_append_paged).
        // K::paged_decode_attention reads from layer-local INT8 buffers or the
        // shared FP16 pool depending on K. Then K-agnostic post-attn tail.
        if cache_block_size > 0 {
            let (pool_k_ptr, pool_v_ptr) =
                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
            // SAFETY: paged_pools is allocated once and never resized; the
            // raw pointers don't outlive this method scope.
            let pool_k = unsafe { &mut *pool_k_ptr };
            let pool_v = unsafe { &mut *pool_v_ptr };

            K::paged_write(
                ctx,
                &mut caches[li],
                &self.scratch.qkv_out,
                q_norm_w,
                k_norm_w,
                &self.rope.cos,
                &self.rope.sin,
                &mut self.scratch.q_head_major,
                &mut self.scratch.k_head_major,
                &mut self.scratch.v_head_major,
                pool_k,
                pool_v,
                tokens,
                nh,
                nkv,
                hd,
                pos_offset,
                eps,
                qk_mode,
            )
            .expect("K::paged_write");

            // Publish the new length before attention so it covers the
            // tokens just written.
            let new_len = cache_len_before + tokens;
            K::set_len(&mut caches[li], new_len);

            // Re-derive shared references; the `&mut` pool borrows above are
            // not used past `paged_write`.
            let pool_k_imm = unsafe { &*pool_k_ptr };
            let pool_v_imm = unsafe { &*pool_v_ptr };
            K::paged_decode_attention(
                ctx,
                &mut caches[li],
                &self.scratch.q_head_major,
                pool_k_imm,
                pool_v_imm,
                &mut self.scratch.attn_head_major_out,
                nh,
                nkv,
                hd,
                new_len,
                tokens,
            )
            .expect("K::paged_decode_attention");

            return self.forward_layer_post_attn(ctx, li, cache_id, residual, tokens);
        }

        // Non-paged (contig) path. INT8 path doesn't reach here:
        // KvInt8::alloc_contig panics in ensure_kv.
        let _qkr_t0 = if llama_family_runtime_env().decode_op_profile {
            B::sync(ctx);
            Some(std::time::Instant::now())
        } else {
            None
        };
        K::contig_write(
            ctx,
            &mut caches[li],
            &self.scratch.qkv_out,
            q_norm_w,
            k_norm_w,
            &self.rope.cos,
            &self.rope.sin,
            &mut self.scratch.q_head_major,
            &mut self.scratch.k_head_major,
            &mut self.scratch.v_head_major,
            &mut self.scratch.q_buf,
            &mut self.scratch.k_buf,
            &mut self.scratch.v_buf,
            tokens,
            nh,
            nkv,
            hd,
            pos_offset,
            eps,
            qk_mode,
        )
        .expect("K::contig_write");
        if let Some(t0) = _qkr_t0 {
            B::sync(ctx);
            QKR_TIME_US.fetch_add(
                t0.elapsed().as_micros() as u64,
                std::sync::atomic::Ordering::Relaxed,
            );
            QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }
        let new_len = cache_len_before + tokens;
        K::set_len(&mut caches[li], new_len);
        // `kv_seq_stride` is the full cache capacity. Presumably each head's
        // rows are spaced `capacity` apart in the contig layout; confirm in
        // the backend.
        let kv_stride = cache_capacity;

        let _attn_t0 = if llama_family_runtime_env().decode_op_profile {
            B::sync(ctx);
            Some(std::time::Instant::now())
        } else {
            None
        };
        let attn_cfg = ferrum_kernels::backend::AttnConfig {
            num_heads: nh,
            num_kv_heads: nkv,
            head_dim: hd,
            causal: true,
            scale: 1.0 / (hd as f32).sqrt(),
            kv_seq_stride: kv_stride,
            sliding_window: cfg.sliding_window,
        };
        K::contig_decode_attention(
            ctx,
            &caches[li],
            &self.scratch.q_head_major,
            &mut self.scratch.attn_head_major_out,
            attn_cfg,
            tokens,
            pos_offset,
        )
        .expect("K::contig_decode_attention");
        // Discard locals that neither path reads (`q_dim`, `kv_dim`).
        // `dummy` is already read above via `unwrap_or`, so its discard is
        // redundant.
        let _ = q_dim;
        let _ = kv_dim;
        let _ = dummy;
        if let Some(t0) = _attn_t0 {
            B::sync(ctx);
            ATTN_TIME_US.fetch_add(
                t0.elapsed().as_micros() as u64,
                std::sync::atomic::Ordering::Relaxed,
            );
            ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }

        self.forward_layer_post_attn(ctx, li, cache_id, residual, tokens);
    }
1700
1701    /// Post-attention tail of `forward_layer`: untranspose Q (if needed),
1702    /// O-proj, fused residual+post_norm, gate_up_proj, SwiGLU, down_proj,
1703    /// final residual add. K-agnostic — reads `self.scratch.attn_head_major_out`
1704    /// which both the FP16 and INT8 attn paths populate.
1705    pub(crate) fn forward_layer_post_attn(
1706        &mut self,
1707        ctx: &mut B::Context,
1708        li: usize,
1709        cache_id: &str,
1710        residual: &mut B::Buffer,
1711        tokens: usize,
1712    ) {
1713        let layer = &self.layers[li];
1714        let cfg = &self.cfg;
1715        let h = cfg.hidden_size;
1716        let nh = cfg.num_heads;
1717        let hd = cfg.head_dim;
1718        let im = cfg.intermediate_size;
1719        let eps = cfg.rms_norm_eps;
1720
1721        // 7. Untranspose head-major → token-major for O-proj input.
1722        let attn_token_major = if tokens == 1 {
1723            &self.scratch.attn_head_major_out
1724        } else {
1725            B::transpose_head_to_token(
1726                ctx,
1727                &self.scratch.attn_head_major_out,
1728                &mut self.scratch.attn_flat,
1729                tokens,
1730                nh,
1731                hd,
1732            );
1733            &self.scratch.attn_flat
1734        };
1735
1736        // 8. O projection.
1737        layer
1738            .o_proj
1739            .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
1740        if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
1741            // SAFETY: see qkv_proj LoRA application above.
1742            let applied = unsafe { &*adapter }
1743                .apply_projection(
1744                    ctx,
1745                    li,
1746                    "o_proj",
1747                    attn_token_major,
1748                    &mut self.scratch.o_proj_out,
1749                    tokens,
1750                )
1751                .expect("validated LoRA o_proj");
1752            self.lora_projection_applications += applied as u64;
1753        }
1754
1755        // 9. Fused residual-add + post-attention RMSNorm.
1756        B::fused_add_rms_norm(
1757            ctx,
1758            residual,
1759            &self.scratch.o_proj_out,
1760            &layer.post_ln_w,
1761            eps,
1762            &mut self.scratch.norm_out,
1763            tokens,
1764            h,
1765        );
1766
1767        // 10. Fused gate+up projection.
1768        layer.gate_up_proj.forward(
1769            ctx,
1770            &self.scratch.norm_out,
1771            &mut self.scratch.gate_up_out,
1772            tokens,
1773        );
1774        if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
1775            // SAFETY: see qkv_proj LoRA application above.
1776            let applied = unsafe { &*adapter }
1777                .apply_projection(
1778                    ctx,
1779                    li,
1780                    "gate_up_proj",
1781                    &self.scratch.norm_out,
1782                    &mut self.scratch.gate_up_out,
1783                    tokens,
1784                )
1785                .expect("validated LoRA gate_up_proj");
1786            self.lora_projection_applications += applied as u64;
1787        }
1788
1789        // 11. SwiGLU: silu(gate) * up.
1790        B::fused_silu_mul_split(
1791            ctx,
1792            &self.scratch.gate_up_out,
1793            &mut self.scratch.silu_out,
1794            tokens,
1795            im,
1796        );
1797
1798        // 12. Down projection.
1799        layer.down_proj.forward(
1800            ctx,
1801            &self.scratch.silu_out,
1802            &mut self.scratch.mlp_out,
1803            tokens,
1804        );
1805        if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
1806            // SAFETY: see qkv_proj LoRA application above.
1807            let applied = unsafe { &*adapter }
1808                .apply_projection(
1809                    ctx,
1810                    li,
1811                    "down_proj",
1812                    &self.scratch.silu_out,
1813                    &mut self.scratch.mlp_out,
1814                    tokens,
1815                )
1816                .expect("validated LoRA down_proj");
1817            self.lora_projection_applications += applied as u64;
1818        }
1819
1820        // 13. Final residual add.
1821        B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
1822    }
1823
1824    /// Multi-position decode-verify: run one forward pass over `tokens`
1825    /// starting at the cache's current end position, write their K/V
1826    /// into the KV cache, and return logits for ALL `tokens.len()`
1827    /// positions as a flat `Vec<f32>` of length `seq_len * vocab_size`.
1828    ///
1829    /// Used by speculative decoding: target receives
1830    /// `[last_token, draft_0, ..., draft_{N-1}]` (N+1 inputs) and produces
1831    /// N+1 logit rows in a single forward instead of N+1 sequential
1832    /// decode() calls. Positions are implicit — the model looks up
1833    /// `pos_offset = cache.len` the same way prefill_internal does, so
1834    /// chunked prefill semantics carry over for free.
1835    pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1836        let seq_len = tokens.len();
1837        assert!(seq_len > 0, "forward_verify called with empty tokens");
1838        self.ensure_scratch(seq_len);
1839        self.ensure_kv(cache_id);
1840
1841        let h = self.cfg.hidden_size;
1842        let vocab = self.cfg.vocab_size;
1843
1844        let pos_offset = self
1845            .kv_caches
1846            .get(cache_id)
1847            .and_then(|layers| layers.first())
1848            .map(|c| K::len(c))
1849            .unwrap_or(0);
1850
1851        let mut ctx = B::new_context();
1852        let mut residual = self
1853            .scratch
1854            .residual
1855            .take()
1856            .expect("scratch residual missing (previous call didn't restore)");
1857
1858        let embed = self
1859            .embed
1860            .as_ref()
1861            .expect("forward_verify called on backbone-only model (no embed)");
1862        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1863
1864        for li in 0..self.cfg.num_layers {
1865            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1866        }
1867
1868        // RMSNorm on ALL seq_len positions (prefill_internal only norms
1869        // the last one; verify needs the full grid).
1870        B::rms_norm(
1871            &mut ctx,
1872            &residual,
1873            &self.final_norm_w,
1874            self.cfg.rms_norm_eps,
1875            &mut self.scratch.norm_out,
1876            seq_len,
1877            h,
1878        );
1879
1880        // LM head applied to all positions → `seq_len * vocab` logits.
1881        // Reuses the existing `batch_logits` scratch (sized max_tokens *
1882        // vocab) so no extra allocation.
1883        let lm_head = self
1884            .lm_head
1885            .as_ref()
1886            .expect("forward_verify called on backbone-only model (no lm_head)");
1887        lm_head.forward(
1888            &mut ctx,
1889            &self.scratch.norm_out,
1890            &mut self.scratch.batch_logits,
1891            seq_len,
1892        );
1893
1894        B::sync(&mut ctx);
1895        self.scratch.residual = Some(residual);
1896        B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
1897    }
1898
1899    /// Prefill: process `tokens` prompt tokens in a single batch, return
1900    /// `[vocab_size]` logits for the last position.
1901    ///
1902    /// Supports incremental prefill: if the KV cache for `cache_id` already
1903    /// contains earlier tokens, the new chunk's positions are computed as
1904    /// `[kv_len, kv_len + tokens.len())` so RoPE and causal masking stay
1905    /// aligned. Used by the engine's chunked-prefill path.
1906    pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1907        assert!(!tokens.is_empty(), "prefill called with empty token list");
1908        self.ensure_kv(cache_id);
1909
1910        let cache_len_before = self
1911            .kv_caches
1912            .get(cache_id)
1913            .and_then(|layers| layers.first())
1914            .map(K::len)
1915            .unwrap_or(0);
1916        let mut cached_prefix_tokens =
1917            if llama_family_runtime_env().prefix_cache && cache_len_before == 0 {
1918                self.try_acquire_prefix_cache(cache_id, tokens)
1919            } else {
1920                0
1921            };
1922        if cached_prefix_tokens >= tokens.len() {
1923            let block_size = self
1924                .kv_caches
1925                .get(cache_id)
1926                .and_then(|layers| layers.first())
1927                .map(K::block_size)
1928                .unwrap_or(16);
1929            cached_prefix_tokens = cached_prefix_tokens
1930                .saturating_sub(block_size)
1931                .min(tokens.len() - 1);
1932        }
1933        if llama_family_runtime_env().prefix_cache && cache_len_before == 0 {
1934            self.record_prefix_cache_probe(cached_prefix_tokens);
1935        }
1936
1937        if cached_prefix_tokens > 0 {
1938            let caches_mut = self.kv_caches.get_mut(cache_id).expect("cache present");
1939            let mut ctx_tmp = B::new_context();
1940            for cache in caches_mut.iter_mut() {
1941                if K::len(cache) != cached_prefix_tokens {
1942                    K::set_len(cache, cached_prefix_tokens);
1943                    if let Some(context_lens) = K::context_lens_mut(cache) {
1944                        B::write_typed::<u32>(
1945                            &mut ctx_tmp,
1946                            context_lens,
1947                            &[cached_prefix_tokens as u32],
1948                        );
1949                    }
1950                }
1951            }
1952            B::sync(&mut ctx_tmp);
1953        }
1954
1955        let suffix_tokens = &tokens[cached_prefix_tokens..];
1956        let seq_len = suffix_tokens.len();
1957        assert!(
1958            seq_len > 0,
1959            "prefix cache must leave at least one suffix token"
1960        );
1961        self.ensure_scratch(seq_len);
1962
1963        // Starting position for this chunk — 0 for a fresh prefill, kv_len
1964        // for the second+ chunk of a split prefill.
1965        let pos_offset = self
1966            .kv_caches
1967            .get(cache_id)
1968            .and_then(|layers| layers.first())
1969            .map(|c| K::len(c))
1970            .unwrap_or(0);
1971
1972        let h = self.cfg.hidden_size;
1973        let vocab = self.cfg.vocab_size;
1974        let mut ctx = B::new_context();
1975
1976        // Move `residual` out of `scratch` to work around the borrow checker:
1977        // `forward_layer` re-borrows `&mut self` to reach `self.layers` /
1978        // `self.kv_caches`, which would conflict with an outstanding
1979        // `&mut self.scratch.residual`. Use Option::take to move it out
1980        // (no placeholder alloc → no transient cuMemFreeAsync that could
1981        // corrupt stream pool state after graph ops on Blackwell).
1982        let mut residual = self
1983            .scratch
1984            .residual
1985            .take()
1986            .expect("scratch residual missing (previous call didn't restore)");
1987        let embed = self
1988            .embed
1989            .as_ref()
1990            .expect("prefill_internal called on backbone-only model (no embed)");
1991        B::embedding_lookup(&mut ctx, embed, suffix_tokens, &mut residual, h);
1992
1993        let prefill_profile = llama_family_runtime_env().prefill_op_profile;
1994        let prefill_t0 = if prefill_profile {
1995            B::sync(&mut ctx);
1996            Some(std::time::Instant::now())
1997        } else {
1998            None
1999        };
2000
2001        for li in 0..self.cfg.num_layers {
2002            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2003        }
2004
2005        if let Some(t0) = prefill_t0 {
2006            B::sync(&mut ctx);
2007            let total_us = t0.elapsed().as_micros() as u64;
2008            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2009            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2010            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2011            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2012            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2013            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2014            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2015            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2016            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2017            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2018            eprintln!(
2019                "[prefill-profile] tokens={} layers total={} ms",
2020                seq_len,
2021                total_us / 1000
2022            );
2023            let bucket = |label: &str, n: u64, us: u64| {
2024                if n > 0 {
2025                    eprintln!(
2026                        "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
2027                        n,
2028                        us / 1000,
2029                        us / n
2030                    );
2031                }
2032            };
2033            bucket("flash_attn", attn_n, attn_us);
2034            bucket("qk_norm_rope", qkr_n, qkr_us);
2035            bucket("matmuls", mm_n, mm_us);
2036            bucket("norms", norm_n, norm_us);
2037            bucket("other", other_n, other_us);
2038        }
2039
2040        // Take the last token's hidden state: residual[(seq_len-1)*h .. seq_len*h]
2041        B::copy_slice(
2042            &mut ctx,
2043            &residual,
2044            (seq_len - 1) * h,
2045            &mut self.scratch.last_hidden,
2046            0,
2047            h,
2048        );
2049
2050        // Final RMSNorm on the last hidden.
2051        B::rms_norm(
2052            &mut ctx,
2053            &self.scratch.last_hidden,
2054            &self.final_norm_w,
2055            self.cfg.rms_norm_eps,
2056            &mut self.scratch.last_normed,
2057            1,
2058            h,
2059        );
2060
2061        // LM head (m=1 — triggers GEMV on MetalBackend).
2062        let lm_head = self
2063            .lm_head
2064            .as_ref()
2065            .expect("prefill_internal called on backbone-only model (no lm_head)");
2066        lm_head.forward(
2067            &mut ctx,
2068            &self.scratch.last_normed,
2069            &mut self.scratch.logits,
2070            1,
2071        );
2072
2073        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
2074        // buffer's CPU pointer without flushing the command buffer, so the
2075        // GPU must complete all pending work first or we read stale/random
2076        // data. CUDA's to_vec does an internal stream.synchronize, making
2077        // the call redundant there (~50µs/step cost), but correctness on
2078        // Metal requires the explicit flush here.
2079        B::sync(&mut ctx);
2080
2081        // Restore residual into scratch for reuse on the next call.
2082        self.scratch.residual = Some(residual);
2083        if llama_family_runtime_env().prefix_cache && cache_len_before == 0 {
2084            self.register_prefix_cache(cache_id, tokens, cached_prefix_tokens);
2085        }
2086
2087        B::to_vec(&self.scratch.logits, vocab)
2088    }
2089
2090    /// Decode: process 1 token at position `pos`, return `[vocab_size]` logits.
2091    pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2092        self.ensure_scratch(1);
2093        self.ensure_kv(cache_id);
2094
2095        let h = self.cfg.hidden_size;
2096        let vocab = self.cfg.vocab_size;
2097
2098        // Context creation is cheap (CUDA reuses the process-global stream).
2099        // The captured graph lives in a process-global slot, not on ctx.
2100        let mut ctx = B::new_context();
2101
2102        // Graph capture is opt-in via FERRUM_CUDA_GRAPH=1. Replay is currently
2103        // single-request-only on Blackwell + CUDA 12.8 (see
2104        // docs/phase-e-cuda-status.md). In pure eager mode, we skip the
2105        // per-step device-state memcpy_htod trio entirely.
2106        const GRAPH_WARMUP: usize = 3;
2107        let graph_enabled = llama_family_runtime_env().cuda_graph;
2108
2109        if graph_enabled {
2110            // Refresh device-side dynamic state (token/pos/kv_len) before
2111            // replay — captured graph reads these from device buffers.
2112            B::set_decode_state(&mut ctx, token, pos);
2113
2114            // Fast path: graph replay (if available). Single-item path
2115            // uses key=SINGLE_ITEM_GRAPH_KEY (0) — separate from the
2116            // batched path's m_padded keys.
2117            match B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY) {
2118                Ok(true) => {
2119                    B::sync(&mut ctx);
2120                    return B::to_vec(&self.scratch.logits, vocab);
2121                }
2122                Ok(false) => { /* no graph yet, fall through to eager */ }
2123                Err(_) => { /* backend error or unsupported, eager */ }
2124            }
2125        }
2126
2127        let should_capture =
2128            graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
2129
2130        if should_capture {
2131            B::set_dev_state_mode(&mut ctx, true);
2132            if B::begin_graph_capture(&mut ctx).is_err() {
2133                self.graph_capture_failed = true;
2134                B::set_dev_state_mode(&mut ctx, false);
2135            }
2136        }
2137
2138        // Eager forward (records into graph if capture is active).
2139        // mem::replace needs a placeholder. B::alloc(0) was our choice but
2140        // cuMemAllocFromPoolAsync(stream, 0) can return CUDA_ERROR_INVALID_VALUE
2141        // on Blackwell after graph replay corrupts the pool state. Size-1 is
2142        // always valid and costs 2 bytes of transient VRAM per decode step.
2143        let mut residual = self
2144            .scratch
2145            .residual
2146            .take()
2147            .expect("scratch residual missing (previous call didn't restore)");
2148        let embed = self
2149            .embed
2150            .as_ref()
2151            .expect("decode_internal called on backbone-only model (no embed)");
2152        B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
2153
2154        // Per-layer wall-time profile (env-gated, off by default — adds
2155        // a B::sync between layers which serializes the pipeline). Helps
2156        // localise non-matmul bottlenecks during perf work.
2157        let layer_profile = llama_family_runtime_env().decode_layer_profile;
2158        let mut layer_times = if layer_profile {
2159            Some(Vec::with_capacity(self.cfg.num_layers))
2160        } else {
2161            None
2162        };
2163
2164        for li in 0..self.cfg.num_layers {
2165            if layer_profile {
2166                let t0 = std::time::Instant::now();
2167                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2168                B::sync(&mut ctx);
2169                let elapsed_us = t0.elapsed().as_micros() as u64;
2170                if let Some(v) = layer_times.as_mut() {
2171                    v.push(elapsed_us);
2172                }
2173            } else {
2174                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2175            }
2176        }
2177        if let Some(times) = layer_times.take() {
2178            let sum: u64 = times.iter().sum();
2179            let avg = sum / times.len() as u64;
2180            let mn = *times.iter().min().unwrap_or(&0);
2181            let mx = *times.iter().max().unwrap_or(&0);
2182            eprintln!(
2183                "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
2184                times.len(),
2185                sum / 1000,
2186                avg,
2187                mn,
2188                mx
2189            );
2190            for (i, t) in times.iter().enumerate() {
2191                eprint!("L{i}={}ms ", t / 1000);
2192                if (i + 1) % 6 == 0 {
2193                    eprintln!();
2194                }
2195            }
2196            eprintln!();
2197            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2198            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2199            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2200            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2201            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2202            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2203            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2204            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2205            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2206            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2207            eprintln!(
2208                "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
2209                attn_n,
2210                attn_us / 1000,
2211                if attn_n > 0 { attn_us / attn_n } else { 0 }
2212            );
2213            eprintln!(
2214                "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
2215                qkr_n,
2216                qkr_us / 1000,
2217                if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
2218            );
2219            eprintln!(
2220                "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
2221                mm_n,
2222                mm_us / 1000,
2223                if mm_n > 0 { mm_us / mm_n } else { 0 }
2224            );
2225            eprintln!(
2226                "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
2227                norm_n,
2228                norm_us / 1000,
2229                if norm_n > 0 { norm_us / norm_n } else { 0 }
2230            );
2231            eprintln!(
2232                "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
2233                other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
2234            );
2235        }
2236
2237        B::rms_norm(
2238            &mut ctx,
2239            &residual,
2240            &self.final_norm_w,
2241            self.cfg.rms_norm_eps,
2242            &mut self.scratch.last_normed,
2243            1,
2244            h,
2245        );
2246
2247        let lm_head = self
2248            .lm_head
2249            .as_ref()
2250            .expect("decode_internal called on backbone-only model (no lm_head)");
2251        lm_head.forward(
2252            &mut ctx,
2253            &self.scratch.last_normed,
2254            &mut self.scratch.logits,
2255            1,
2256        );
2257
2258        if should_capture && !self.graph_capture_failed {
2259            if B::end_graph_capture(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
2260                self.graph_capture_failed = true;
2261            } else {
2262                // Stream capture mode RECORDS ops into the graph without
2263                // executing them. scratch.logits still holds the previous
2264                // step's value. Replay the just-captured graph once to
2265                // actually execute and produce this step's logits. Without
2266                // this, the capture step's to_vec returns stale logits,
2267                // yielding a 1-token offset in the generated sequence.
2268                if B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
2269                    self.graph_capture_failed = true;
2270                }
2271            }
2272            B::set_dev_state_mode(&mut ctx, false);
2273        } else {
2274            self.graph_warmup += 1;
2275        }
2276
2277        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
2278        // buffer's CPU pointer without flushing the command buffer, so the
2279        // GPU must complete all pending work first or we read stale/random
2280        // data. CUDA's to_vec does an internal stream.synchronize, making
2281        // the call redundant there (~50µs/step cost), but correctness on
2282        // Metal requires the explicit flush here.
2283        B::sync(&mut ctx);
2284        self.scratch.residual = Some(residual);
2285
2286        B::to_vec(&self.scratch.logits, vocab)
2287    }
2288
2289    /// Prefill with pre-computed embeddings instead of token IDs.
2290    ///
2291    /// Used by models that embed inputs outside the LLM (e.g. Qwen3-TTS
2292    /// mixes text-embedding + codec-embedding before feeding the LM).
2293    /// Skips `final_norm` + `lm_head`; returns the last position's pre-norm
2294    /// hidden state. Caller applies its own output head.
2295    ///
2296    /// `embeds` is row-major `[seq_len * hidden_size]`, f32.
2297    pub fn prefill_from_embeds(
2298        &mut self,
2299        cache_id: &str,
2300        embeds: &[f32],
2301        seq_len: usize,
2302    ) -> Vec<f32> {
2303        let h = self.cfg.hidden_size;
2304        assert_eq!(
2305            embeds.len(),
2306            seq_len * h,
2307            "embeds length {} != seq_len * hidden_size {}",
2308            embeds.len(),
2309            seq_len * h
2310        );
2311        assert!(seq_len > 0, "prefill_from_embeds called with zero length");
2312
2313        self.ensure_scratch(seq_len);
2314        self.ensure_kv(cache_id);
2315
2316        let mut ctx = B::new_context();
2317        let mut residual = self
2318            .scratch
2319            .residual
2320            .take()
2321            .expect("scratch residual missing (previous call didn't restore)");
2322
2323        // Upload embeds → residual[0 .. seq_len*h].
2324        let embed_buf = B::from_slice(embeds);
2325        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
2326
2327        for li in 0..self.cfg.num_layers {
2328            self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
2329        }
2330
2331        B::copy_slice(
2332            &mut ctx,
2333            &residual,
2334            (seq_len - 1) * h,
2335            &mut self.scratch.last_hidden,
2336            0,
2337            h,
2338        );
2339        B::sync(&mut ctx);
2340        self.scratch.residual = Some(residual);
2341        B::to_vec(&self.scratch.last_hidden, h)
2342    }
2343
2344    /// Decode with a single pre-computed embedding (shape `[hidden]`).
2345    /// Returns the pre-norm hidden state for the position `pos`. Caller
2346    /// applies final norm + its own output head.
2347    pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
2348        let h = self.cfg.hidden_size;
2349        assert_eq!(
2350            embed.len(),
2351            h,
2352            "embed length {} != hidden_size {}",
2353            embed.len(),
2354            h
2355        );
2356
2357        self.ensure_scratch(1);
2358        self.ensure_kv(cache_id);
2359
2360        let mut ctx = B::new_context();
2361        let mut residual = self
2362            .scratch
2363            .residual
2364            .take()
2365            .expect("scratch residual missing (previous call didn't restore)");
2366
2367        let embed_buf = B::from_slice(embed);
2368        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
2369
2370        for li in 0..self.cfg.num_layers {
2371            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2372        }
2373
2374        B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
2375        B::sync(&mut ctx);
2376        self.scratch.residual = Some(residual);
2377        B::to_vec(&self.scratch.last_hidden, h)
2378    }
2379
2380    /// Variant of `prefill_from_embeds` that applies `final_norm` to every
2381    /// position and returns the whole `[seq_len * hidden_size]` vector.
2382    /// Accepts `pos_offset` so callers can continue an existing sequence
2383    /// (e.g. Qwen3-TTS voice-clone: one prefill for the role prefix, a
2384    /// follow-up prefill for the reference-audio ICL block, then
2385    /// autoregressive decoding — all against the same KV cache).
2386    ///
2387    /// Used by TTS where `forward_step` in the candle-based wrapper is
2388    /// expected to return **post-norm all-positions** hidden state so
2389    /// `codec_head` can be applied on candle side.
2390    pub fn prefill_all_post_norm(
2391        &mut self,
2392        cache_id: &str,
2393        embeds: &[f32],
2394        seq_len: usize,
2395        pos_offset: usize,
2396    ) -> Vec<f32> {
2397        let h = self.cfg.hidden_size;
2398        assert_eq!(
2399            embeds.len(),
2400            seq_len * h,
2401            "embeds length {} != seq_len * hidden_size {}",
2402            embeds.len(),
2403            seq_len * h
2404        );
2405        assert!(seq_len > 0);
2406
2407        self.ensure_scratch(seq_len);
2408        self.ensure_kv(cache_id);
2409
2410        let mut ctx = B::new_context();
2411        let mut residual = self
2412            .scratch
2413            .residual
2414            .take()
2415            .expect("scratch residual missing (previous call didn't restore)");
2416
2417        let embed_buf = B::from_slice(embeds);
2418        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
2419
2420        for li in 0..self.cfg.num_layers {
2421            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2422        }
2423
2424        // Apply final_norm over all seq_len positions → scratch.norm_out.
2425        B::rms_norm(
2426            &mut ctx,
2427            &residual,
2428            &self.final_norm_w,
2429            self.cfg.rms_norm_eps,
2430            &mut self.scratch.norm_out,
2431            seq_len,
2432            h,
2433        );
2434        B::sync(&mut ctx);
2435        self.scratch.residual = Some(residual);
2436        B::to_vec(&self.scratch.norm_out, seq_len * h)
2437    }
2438
2439    /// Decode-side companion to `prefill_all_post_norm`. Runs a single-token
2440    /// decode step at `pos`, applies `final_norm`, and returns the post-norm
2441    /// hidden state `[hidden_size]`.
2442    pub fn decode_post_norm_from_embed(
2443        &mut self,
2444        cache_id: &str,
2445        embed: &[f32],
2446        pos: u32,
2447    ) -> Vec<f32> {
2448        let h = self.cfg.hidden_size;
2449        assert_eq!(embed.len(), h);
2450
2451        self.ensure_scratch(1);
2452        self.ensure_kv(cache_id);
2453
2454        let mut ctx = B::new_context();
2455        let mut residual = self
2456            .scratch
2457            .residual
2458            .take()
2459            .expect("scratch residual missing (previous call didn't restore)");
2460
2461        let embed_buf = B::from_slice(embed);
2462        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
2463
2464        for li in 0..self.cfg.num_layers {
2465            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2466        }
2467
2468        B::rms_norm(
2469            &mut ctx,
2470            &residual,
2471            &self.final_norm_w,
2472            self.cfg.rms_norm_eps,
2473            &mut self.scratch.last_normed,
2474            1,
2475            h,
2476        );
2477        B::sync(&mut ctx);
2478        self.scratch.residual = Some(residual);
2479        B::to_vec(&self.scratch.last_normed, h)
2480    }
2481}
2482
2483// FP16 DecoderOnlyLLM impl — full path with batched + unified-forward overrides.
2484impl<B: MoeLlmBackend> DecoderOnlyLLM for LlamaFamilyModel<B, KvFp16> {
2485    fn config(&self) -> &LlmRuntimeConfig {
2486        &self.runtime_cfg
2487    }
2488
2489    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
2490        Some(self.prefix_cache_snapshot_json())
2491    }
2492
2493    fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
2494        Some(self.lora_metrics_snapshot_json())
2495    }
2496
2497    fn set_lora_adapter_for_cache(
2498        &mut self,
2499        cache_id: &str,
2500        adapter: Option<ActiveLoraAdapter>,
2501    ) -> std::result::Result<(), ferrum_types::FerrumError> {
2502        if let Some(adapter) = adapter {
2503            self.ensure_lora_adapter_loaded(adapter.clone())?;
2504            self.lora_cache_adapters
2505                .insert(cache_id.to_string(), adapter.name);
2506        } else {
2507            self.lora_cache_adapters.remove(cache_id);
2508        }
2509        Ok(())
2510    }
2511
2512    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2513        self.ensure_scratch(max_tokens);
2514        self.ensure_kv(cache_id);
2515
2516        const WARMUP_CACHE: &str = "__ferrum_warmup__";
2517        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2518        if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2519            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2520                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2521                if let Some(c0) = caches.first() {
2522                    if !c0.paged_block_indices.is_empty() {
2523                        alloc.free(&c0.paged_block_indices);
2524                    }
2525                }
2526                for c in caches.iter_mut() {
2527                    c.paged_block_indices.clear();
2528                }
2529            }
2530            self.kv_free_pool.push(caches);
2531        }
2532    }
2533
2534    fn kv_capacity(&self) -> usize {
2535        llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
2536    }
2537
2538    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2539        self.prefill_internal(cache_id, tokens)
2540    }
2541
2542    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2543        self.decode_internal(cache_id, token, pos)
2544    }
2545
2546    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2547        self.decode_batch_with_full_logits(batch, false)
2548    }
2549
2550    fn decode_batch_with_full_logits(
2551        &mut self,
2552        batch: &[(String, u32, u32)],
2553        force_full_logits: bool,
2554    ) -> Vec<Vec<f32>> {
2555        self.decode_batch_internal_with_full_logits(batch, force_full_logits)
2556    }
2557
2558    fn unified_forward(
2559        &mut self,
2560        items: &[(String, Vec<u32>, usize, bool)],
2561    ) -> std::result::Result<Vec<Option<Vec<f32>>>, ferrum_types::FerrumError> {
2562        if items.is_empty() {
2563            return Ok(Vec::new());
2564        }
2565        if llama_family_runtime_env().prefix_cache
2566            && items
2567                .iter()
2568                .any(|(_, tokens, pos_offset, _)| *pos_offset == 0 && tokens.len() > 1)
2569        {
2570            return Err(ferrum_types::FerrumError::unsupported(
2571                "LlamaFamilyModel::unified_forward: fresh prefill with prefix cache enabled \
2572                 routes through prefill_internal so real paged-block KV reuse can probe and \
2573                 register block hashes",
2574            ));
2575        }
2576        if !B::supports_varlen_qkv() {
2577            return Err(ferrum_types::FerrumError::unsupported(
2578                "LlamaFamilyModel::unified_forward: backend lacks varlen \
2579                 QKV kernels. Engine will fall back to per-item dispatch.",
2580            ));
2581        }
2582        if items
2583            .iter()
2584            .any(|(cache_id, _, _, _)| self.active_lora_adapter_for_cache(cache_id).is_some())
2585        {
2586            return Err(ferrum_types::FerrumError::unsupported(
2587                "LlamaFamilyModel::unified_forward: active LoRA adapter routes through \
2588                 per-item dispatch until unified LoRA supports row-selective adapters.",
2589            ));
2590        }
2591        self.ensure_kv(&items[0].0);
2592        if self.paged_pools.is_none() {
2593            return Err(ferrum_types::FerrumError::unsupported(
2594                "LlamaFamilyModel::unified_forward: paged KV required; \
2595                 enable via FERRUM_METAL_PAGED_KV=1 (cross-backend env). \
2596                 Engine will fall back to per-item dispatch.",
2597            ));
2598        }
2599        for (cid, _, _, _) in items {
2600            self.ensure_kv(cid);
2601            if !self.kv_caches.contains_key(cid) {
2602                return Err(ferrum_types::FerrumError::resource_exhausted(format!(
2603                    "paged KV pool exhausted for cache_id={cid:?}; back off"
2604                )));
2605            }
2606        }
2607        Ok(self.unified_forward_internal(items))
2608    }
2609
2610    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2611        LlamaFamilyModel::<B, KvFp16>::forward_verify(self, cache_id, tokens)
2612    }
2613
2614    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2615        if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2616            for c in caches.iter_mut() {
2617                if new_len < c.len {
2618                    c.len = new_len;
2619                }
2620            }
2621        }
2622        let mut ctx = B::new_context();
2623        B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2624        self.graph_warmup = 0;
2625        self.graph_capture_failed = false;
2626    }
2627
2628    fn release(&mut self, cache_id: &str) {
2629        let mut ctx = B::new_context();
2630        B::sync(&mut ctx);
2631        self.graph_warmup = 0;
2632        self.graph_capture_failed = false;
2633        self.lora_cache_adapters.remove(cache_id);
2634        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2635            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2636                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2637                if let Some(c0) = caches.first() {
2638                    if !c0.paged_block_indices.is_empty() {
2639                        alloc.free(&c0.paged_block_indices);
2640                    }
2641                }
2642                for c in caches.iter_mut() {
2643                    c.paged_block_indices.clear();
2644                }
2645            }
2646            self.kv_free_pool.push(caches);
2647        }
2648    }
2649
2650    fn reset(&mut self) {
2651        let mut ctx = B::new_context();
2652        B::sync(&mut ctx);
2653        B::reset_all_graphs(&mut ctx);
2654        B::sync(&mut ctx);
2655        self.graph_warmup = 0;
2656        self.graph_capture_failed = false;
2657        self.batched_graph_keys_seen.clear();
2658        self.batched_graph_warmup = 0;
2659        self.batched_graph_failed = false;
2660        self.kv_caches.clear();
2661        self.kv_free_pool.clear();
2662        self.lora_cache_adapters.clear();
2663    }
2664}
2665
2666// INT8 DecoderOnlyLLM impl — minimal: no batched / unified-forward overrides
2667// (default trait impl falls back to per-item decode). PR D will add INT8
2668// batched paths once the kernels stabilize.
2669impl<B: MoeLlmBackend + BackendInt8KvOps> DecoderOnlyLLM for LlamaFamilyModel<B, KvInt8> {
2670    fn config(&self) -> &LlmRuntimeConfig {
2671        &self.runtime_cfg
2672    }
2673
2674    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
2675        Some(self.prefix_cache_snapshot_json())
2676    }
2677
2678    fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
2679        Some(self.lora_metrics_snapshot_json())
2680    }
2681
2682    fn set_lora_adapter_for_cache(
2683        &mut self,
2684        cache_id: &str,
2685        adapter: Option<ActiveLoraAdapter>,
2686    ) -> std::result::Result<(), ferrum_types::FerrumError> {
2687        if let Some(adapter) = adapter {
2688            self.ensure_lora_adapter_loaded(adapter.clone())?;
2689            self.lora_cache_adapters
2690                .insert(cache_id.to_string(), adapter.name);
2691        } else {
2692            self.lora_cache_adapters.remove(cache_id);
2693        }
2694        Ok(())
2695    }
2696
2697    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2698        self.ensure_scratch(max_tokens);
2699        self.ensure_kv(cache_id);
2700
2701        const WARMUP_CACHE: &str = "__ferrum_warmup__";
2702        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2703        if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2704            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2705                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2706                if let Some(c0) = caches.first() {
2707                    if !c0.paged_block_indices.is_empty() {
2708                        alloc.free(&c0.paged_block_indices);
2709                    }
2710                }
2711                for c in caches.iter_mut() {
2712                    c.paged_block_indices.clear();
2713                }
2714            }
2715            self.kv_free_pool.push(caches);
2716        }
2717    }
2718
2719    fn kv_capacity(&self) -> usize {
2720        llama_family_runtime_env().kv_capacity_for_model(self.cfg.max_seq_len)
2721    }
2722
2723    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2724        self.prefill_internal(cache_id, tokens)
2725    }
2726
2727    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2728        self.decode_internal(cache_id, token, pos)
2729    }
2730
2731    // decode_batch + unified_forward use trait defaults (per-item fallback).
2732
2733    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2734        LlamaFamilyModel::<B, KvInt8>::forward_verify(self, cache_id, tokens)
2735    }
2736
2737    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2738        if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2739            for c in caches.iter_mut() {
2740                if new_len < c.len {
2741                    c.len = new_len;
2742                }
2743            }
2744        }
2745        let mut ctx = B::new_context();
2746        B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2747        self.graph_warmup = 0;
2748        self.graph_capture_failed = false;
2749    }
2750
2751    fn release(&mut self, cache_id: &str) {
2752        let mut ctx = B::new_context();
2753        B::sync(&mut ctx);
2754        self.graph_warmup = 0;
2755        self.graph_capture_failed = false;
2756        self.lora_cache_adapters.remove(cache_id);
2757        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2758            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2759                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2760                if let Some(c0) = caches.first() {
2761                    if !c0.paged_block_indices.is_empty() {
2762                        alloc.free(&c0.paged_block_indices);
2763                    }
2764                }
2765                for c in caches.iter_mut() {
2766                    c.paged_block_indices.clear();
2767                }
2768            }
2769            self.kv_free_pool.push(caches);
2770        }
2771    }
2772
2773    fn reset(&mut self) {
2774        let mut ctx = B::new_context();
2775        B::sync(&mut ctx);
2776        B::reset_all_graphs(&mut ctx);
2777        B::sync(&mut ctx);
2778        self.graph_warmup = 0;
2779        self.graph_capture_failed = false;
2780        self.kv_caches.clear();
2781        self.kv_free_pool.clear();
2782        self.lora_cache_adapters.clear();
2783    }
2784}
2785
2786fn build_rope_cache<B: QuantLlmBackend + BackendMoeFused>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
2787    let hd = cfg.head_dim;
2788    let half = hd / 2;
2789    let max = cfg.max_seq_len;
2790    let mut cos = vec![0.0f32; max * half];
2791    let mut sin = vec![0.0f32; max * half];
2792    for pos in 0..max {
2793        for i in 0..half {
2794            let freq = rope_freq(cfg, i);
2795            let angle = pos as f64 * freq;
2796            cos[pos * half + i] = angle.cos() as f32;
2797            sin[pos * half + i] = angle.sin() as f32;
2798        }
2799    }
2800    RopeCache {
2801        cos: B::from_slice(&cos),
2802        sin: B::from_slice(&sin),
2803    }
2804}
2805
2806fn rope_freq(cfg: &LlamaFamilyConfig, pair_idx: usize) -> f64 {
2807    let base_freq = 1.0f64
2808        / cfg
2809            .rope_theta
2810            .powf((2 * pair_idx) as f64 / cfg.head_dim as f64);
2811    match &cfg.rope_scaling {
2812        Some(RopeScalingConfig::Llama3 {
2813            factor,
2814            low_freq_factor,
2815            high_freq_factor,
2816            original_max_position_embeddings,
2817        }) => scale_llama3_rope_freq(
2818            base_freq,
2819            *factor,
2820            *low_freq_factor,
2821            *high_freq_factor,
2822            *original_max_position_embeddings,
2823        ),
2824        None => base_freq,
2825    }
2826}
2827
/// Apply Llama-3 style long-context rescaling to a single RoPE frequency.
///
/// Frequencies are bucketed by wavelength `2π / freq` against the original
/// training context:
/// - wavelength below `original_max_position_embeddings / high_freq_factor`:
///   returned unchanged;
/// - wavelength above `original_max_position_embeddings / low_freq_factor`:
///   divided by `factor`;
/// - in between: a linear blend of `freq / factor` and `freq`, weighted by
///   how far the wavelength sits inside the band.
fn scale_llama3_rope_freq(
    freq: f64,
    factor: f64,
    low_freq_factor: f64,
    high_freq_factor: f64,
    original_max_position_embeddings: f64,
) -> f64 {
    let ctx_len = original_max_position_embeddings;
    let wavelen = std::f64::consts::TAU / freq;

    // High-frequency band: short wavelengths are left untouched.
    if wavelen < ctx_len / high_freq_factor {
        return freq;
    }
    // Low-frequency band: long wavelengths are fully interpolated.
    if wavelen > ctx_len / low_freq_factor {
        return freq / factor;
    }
    // Medium band: smooth blend between the two regimes.
    let smooth = (ctx_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor);
    (1.0 - smooth) * freq / factor + smooth * freq
}
2848
#[cfg(test)]
mod tests {
    use super::{LlamaFamilyRuntimeEnv, DEFAULT_KV_CAPACITY};

    /// Valid startup knobs are parsed. The profiling / CUDA-graph flags are
    /// pinned as presence-based: they are expected to be enabled even for
    /// `"0"`, `""` and `"false"`. `FERRUM_METAL_PAGED_KV`, by contrast, is
    /// expected to parse `"0"` into an explicit `Some(false)`.
    #[test]
    fn llama_family_runtime_env_parses_startup_knobs() {
        let parsed = LlamaFamilyRuntimeEnv::from_env_vars([
            ("FERRUM_KV_CAPACITY", "4096"),
            ("FERRUM_METAL_PAGED_KV", "0"),
            ("FERRUM_PAGED_MAX_SEQS", "64"),
            ("FERRUM_DECODE_OP_PROFILE", "0"),
            ("FERRUM_PREFILL_OP_PROFILE", ""),
            ("FERRUM_CUDA_GRAPH", ""),
            ("FERRUM_DECODE_LAYER_PROFILE", "false"),
        ]);

        // Numeric knobs.
        assert_eq!(parsed.kv_capacity, Some(4096));
        assert_eq!(parsed.paged_max_seqs, 64);

        // Tri-state knob: "0" is an explicit opt-out, not "unset".
        assert_eq!(parsed.metal_paged_kv, Some(false));

        // Presence-based flags: any value, even "0"/""/"false", enables them.
        for (knob, enabled) in [
            ("decode_op_profile", parsed.decode_op_profile),
            ("prefill_op_profile", parsed.prefill_op_profile),
            ("cuda_graph", parsed.cuda_graph),
            ("decode_layer_profile", parsed.decode_layer_profile),
        ] {
            assert!(enabled, "{knob} should be enabled when its env var is present");
        }

        // With 4096 configured, a model limited to 2048 positions gets 2048.
        assert_eq!(parsed.kv_capacity_for_model(2048), 2048);
    }

    /// Unparseable numeric values fall back: capacity becomes `None` and the
    /// paged sequence limit its default of 32; a model allowing more than
    /// `DEFAULT_KV_CAPACITY` positions then gets `DEFAULT_KV_CAPACITY`.
    #[test]
    fn llama_family_runtime_env_uses_defaults_for_invalid_values() {
        let parsed = LlamaFamilyRuntimeEnv::from_env_vars([
            ("FERRUM_KV_CAPACITY", "bad"),
            ("FERRUM_PAGED_MAX_SEQS", "bad"),
            ("FERRUM_METAL_PAGED_KV", "1"),
        ]);

        assert_eq!(parsed.kv_capacity, None);
        assert_eq!(parsed.paged_max_seqs, 32);
        assert_eq!(parsed.metal_paged_kv, Some(true));

        let roomy_model = DEFAULT_KV_CAPACITY * 2;
        assert_eq!(parsed.kv_capacity_for_model(roomy_model), DEFAULT_KV_CAPACITY);
    }
}