Skip to main content

ferrum_models/models/
llama_family.rs

1//! Llama-family decoder model as explicit code.
2//!
3//! Covers all "standard Llama-style GQA + SwiGLU + RoPE" decoders:
4//!   - Llama / Llama-2 / Llama-3  (no QK-norm)
5//!   - Qwen2 / Qwen2.5            (no QK-norm, structurally Llama)
6//!   - Qwen3                      (QK-norm per head, larger rope_theta)
//!   - Mistral                    (sliding-window attention: the window size
//!                                 is parsed into `LlamaFamilyConfig::sliding_window`,
//!                                 but the forward path does not apply it yet;
//!                                 that needs an `AttnConfig.sliding_window`
//!                                 field + shader support in Phase D)
11//!
//! Variant differences are controlled by `LlamaFamilyConfig::has_qk_norm`,
//! the RoPE parameters (`rope_theta`, optional `rope_scaling`) and
//! `sliding_window`. Weight loading accepts both fused (`qkv_proj`,
//! `gate_up_proj`) and split (`q_proj`+`k_proj`+`v_proj`,
//! `gate_proj`+`up_proj`) projection layouts — the loader (e.g.
//! `CandleVarBuilderLoader`) fuses split weights on load so model code
//! sees a uniform `qkv_proj` / `gate_up_proj` Linear.
18
19use std::collections::HashMap;
20use std::ops::Range;
21use std::sync::atomic::AtomicU64;
22
23use ferrum_interfaces::kv_dtype::{KvDtypeKind, KvFp16, KvInt8};
24use ferrum_kernels::backend::{
25    Backend, BackendGraph, BackendInt8KvOps, BackendMoeFused, BackendPagedKv, BackendQuantGguf,
26    BackendQuantMarlin, KvCache, KvLayer, LlmBackend, MoeLlmBackend, QuantLlmBackend,
27    MAX_LAYERS_FOR_GRAPH,
28};
29
/// Graph-cache key reserved for the single-item decode path (`decode_internal`).
///
/// It is kept distinct from any `m_padded`-based key the batched path uses.
pub(crate) const SINGLE_ITEM_GRAPH_KEY: u64 = 0;

// ── Batched graph dispatcher diagnostics (replay vs eager) ─────────────────
/// Diagnostic counter: batched dispatches that replayed a captured graph.
pub(crate) static BATCHED_GRAPH_REPLAY_COUNT: AtomicU64 = AtomicU64::new(0);
/// Diagnostic counter: batched dispatches that ran eagerly (no graph replay).
pub(crate) static BATCHED_GRAPH_EAGER_COUNT: AtomicU64 = AtomicU64::new(0);

// ── Per-op-class timing counters ───────────────────────────────────────────
// Each `*_TIME_US` / `*_CALLS` pair accumulates elapsed microseconds and a
// call count for one op class: attention, QKR (presumably QK-norm + RoPE),
// matmul, norm, and everything else.
// NOTE(review): the writers live outside this view; presumably they are only
// bumped while op profiling is enabled — confirm at the call sites.
pub(crate) static ATTN_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static ATTN_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static QKR_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static MATMUL_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static NORM_CALLS: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_TIME_US: AtomicU64 = AtomicU64::new(0);
pub(crate) static OTHER_CALLS: AtomicU64 = AtomicU64::new(0);
48use ferrum_quantization::{Linear, WeightLoader};
49use ferrum_types::Result;
50
51use crate::common::paged_pool::block_hash_chain;
52use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
53use crate::lora::{load_runtime_lora_adapter, ActiveLoraAdapter, RuntimeLoraAdapter};
54
/// Per-sequence KV capacity used when no `FERRUM_KV_CAPACITY` override is
/// set (further clamped to the model's maximum sequence length).
const DEFAULT_KV_CAPACITY: usize = 512;
56
/// Microseconds elapsed since `t0`, saturated to `u64::MAX` and floored at 1.
///
/// The result is never 0, even for an instant taken immediately before the call.
pub(crate) fn elapsed_micros_u64_floor1(t0: std::time::Instant) -> u64 {
    let micros: u128 = t0.elapsed().as_micros();
    let saturated = if micros > u128::from(u64::MAX) {
        u64::MAX
    } else {
        micros as u64
    };
    std::cmp::max(saturated, 1)
}
60
/// Timing breakdown (microseconds, per the `_us` suffix) for the stage
/// hidden-state bridge: the overall bridge time alongside the host-copy and
/// device-copy times.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct LlamaStageHiddenBridgeTiming {
    pub(crate) bridge_us: u64,
    pub(crate) host_copy_us: u64,
    pub(crate) device_copy_us: u64,
}

impl LlamaStageHiddenBridgeTiming {
    /// Field-wise sum of `self` and `other`; each field saturates at
    /// `u64::MAX` instead of overflowing.
    pub(crate) fn add(self, other: Self) -> Self {
        let Self {
            bridge_us,
            host_copy_us,
            device_copy_us,
        } = other;
        Self {
            bridge_us: self.bridge_us.saturating_add(bridge_us),
            host_copy_us: self.host_copy_us.saturating_add(host_copy_us),
            device_copy_us: self.device_copy_us.saturating_add(device_copy_us),
        }
    }
}
77
78#[derive(Debug, Clone, PartialEq, Eq)]
79pub(crate) struct LlamaFamilyRuntimeEnv {
80    pub(crate) kv_capacity: Option<usize>,
81    pub(crate) metal_paged_kv: Option<bool>,
82    pub(crate) paged_max_seqs: usize,
83    pub(crate) decode_op_profile: bool,
84    pub(crate) prefill_op_profile: bool,
85    pub(crate) prefix_cache: bool,
86    pub(crate) cuda_graph: bool,
87    pub(crate) decode_layer_profile: bool,
88}
89
90impl LlamaFamilyRuntimeEnv {
91    /// Resolve from the process-wide snapshot installed at the composition root
92    /// (was a direct `std::env::vars()` read). The model holds this in a
93    /// `runtime_env` field, resolved once at construction.
94    fn from_env() -> Self {
95        Self::from_runtime_config_snapshot(&ferrum_types::active_runtime_snapshot())
96    }
97
98    fn from_runtime_config_snapshot(snapshot: &ferrum_types::RuntimeConfigSnapshot) -> Self {
99        Self::from_env_vars(
100            snapshot
101                .entries
102                .iter()
103                .map(|e| (e.key.as_str(), e.effective_value.as_str())),
104        )
105    }
106
107    fn from_env_vars<I, K, V>(vars: I) -> Self
108    where
109        I: IntoIterator<Item = (K, V)>,
110        K: AsRef<str>,
111        V: AsRef<str>,
112    {
113        let mut config = Self {
114            kv_capacity: None,
115            metal_paged_kv: None,
116            paged_max_seqs: 32,
117            decode_op_profile: false,
118            prefill_op_profile: false,
119            prefix_cache: false,
120            cuda_graph: false,
121            decode_layer_profile: false,
122        };
123        for (name, value) in vars {
124            let value = value.as_ref();
125            match name.as_ref() {
126                "FERRUM_KV_CAPACITY" => config.kv_capacity = value.parse::<usize>().ok(),
127                "FERRUM_METAL_PAGED_KV" => config.metal_paged_kv = Some(value != "0"),
128                "FERRUM_PAGED_MAX_SEQS" => {
129                    if let Ok(max_seqs) = value.parse::<usize>() {
130                        config.paged_max_seqs = max_seqs;
131                    }
132                }
133                "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
134                "FERRUM_PREFILL_OP_PROFILE" => config.prefill_op_profile = true,
135                "FERRUM_PREFIX_CACHE" => config.prefix_cache = value == "1",
136                "FERRUM_CUDA_GRAPH" => config.cuda_graph = true,
137                "FERRUM_DECODE_LAYER_PROFILE" => config.decode_layer_profile = true,
138                _ => {}
139            }
140        }
141        config
142    }
143
144    fn kv_capacity_for_model(&self, model_max: usize) -> usize {
145        self.kv_capacity
146            .map(|cap| cap.min(model_max))
147            .unwrap_or_else(|| model_max.min(DEFAULT_KV_CAPACITY))
148    }
149
150    fn paged_kv_enabled<B: BackendPagedKv>(&self) -> bool {
151        self.metal_paged_kv
152            .unwrap_or_else(|| B::supports_paged_kv())
153    }
154}
155
156/// Whether decode-op profiling is enabled, resolved from the runtime snapshot
157/// installed at the composition root. Used by pipeline paths that operate on a
158/// `LlamaFamilyPipelineModel` and so don't hold a `LlamaFamilyModel`'s
159/// `runtime_env` field. Called once per decode-batch, not per op.
160pub(crate) fn llama_family_decode_op_profile_enabled() -> bool {
161    LlamaFamilyRuntimeEnv::from_env().decode_op_profile
162}
163
/// RoPE frequency-scaling schemes understood by the Llama-family model.
#[derive(Clone, Debug, PartialEq)]
pub enum RopeScalingConfig {
    /// Meta Llama 3.1/3.2/3.3 long-context RoPE scaling
    /// (`rope_scaling.rope_type == "llama3"` in config.json).
    Llama3 {
        /// `factor` from config.json.
        factor: f64,
        /// `low_freq_factor` from config.json.
        low_freq_factor: f64,
        /// `high_freq_factor` from config.json.
        high_freq_factor: f64,
        /// Context length the scaling is relative to.
        original_max_position_embeddings: f64,
    },
}

impl RopeScalingConfig {
    /// Default `Llama3` parameters: factor 8, low/high frequency factors 1 and
    /// 4, original context 8192.
    pub fn llama3_default() -> Self {
        let (factor, low_freq_factor, high_freq_factor) = (8.0, 1.0, 4.0);
        let original_max_position_embeddings = 8192.0;
        Self::Llama3 {
            factor,
            low_freq_factor,
            high_freq_factor,
            original_max_position_embeddings,
        }
    }
}
185
186/// Full Qwen3 architecture config (everything the model code needs, not just
187/// the engine-facing subset in `LlmRuntimeConfig`).
188#[derive(Clone, Debug, PartialEq)]
189pub struct LlamaFamilyConfig {
190    pub hidden_size: usize,
191    pub intermediate_size: usize,
192    pub num_heads: usize,
193    pub num_kv_heads: usize,
194    pub head_dim: usize,
195    pub num_layers: usize,
196    pub vocab_size: usize,
197    pub max_seq_len: usize,
198    pub rms_norm_eps: f32,
199    pub rope_theta: f64,
200    pub rope_scaling: Option<RopeScalingConfig>,
201    /// GGUF LLaMA stores Q/K in the llama.cpp interleaved RoPE layout.
202    /// HF safetensors Qwen/LLaMA definitions keep the default half-split
203    /// GPT-NeoX layout used by existing Ferrum kernels.
204    pub rope_interleaved: bool,
205    /// Whether the checkpoint has `q_norm` / `k_norm` per layer. All known
206    /// Qwen3 checkpoints do; some derivatives may strip them.
207    pub has_qk_norm: bool,
208    /// Sliding-window attention size. `0` disables (full causal).
209    /// Mistral v0.1 sets 4096; Mistral v0.2+ removed the limitation (0).
210    pub sliding_window: usize,
211}
212
213impl LlamaFamilyConfig {
214    pub fn to_runtime(&self) -> LlmRuntimeConfig {
215        LlmRuntimeConfig {
216            hidden_size: self.hidden_size,
217            num_layers: self.num_layers,
218            num_kv_heads: self.num_kv_heads,
219            head_dim: self.head_dim,
220            vocab_size: self.vocab_size,
221            max_seq_len: self.max_seq_len,
222        }
223    }
224
225    /// Build config from a `ModelDefinition`, shared field extraction.
226    /// Variant-specific constructors below set `has_qk_norm` and fall back
227    /// to different `rope_theta` defaults when the checkpoint doesn't set one.
228    fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
229        let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
230        let head_dim = def
231            .extra_params
232            .get("head_dim")
233            .and_then(|v| v.as_u64())
234            .map(|v| v as usize)
235            .unwrap_or(def.hidden_size / def.num_attention_heads);
236        // Mistral / Gemma: "sliding_window" may be null (v0.2+) or a positive
237        // integer (v0.1). Non-null value passes through; missing/null → 0.
238        let sliding_window = def
239            .extra_params
240            .get("sliding_window")
241            .and_then(|v| v.as_u64())
242            .map(|v| v as usize)
243            .unwrap_or(0);
244
245        LlamaFamilyConfigBase {
246            hidden_size: def.hidden_size,
247            intermediate_size: def.intermediate_size,
248            num_heads: def.num_attention_heads,
249            num_kv_heads,
250            head_dim,
251            num_layers: def.num_hidden_layers,
252            vocab_size: def.vocab_size,
253            max_seq_len: def.max_position_embeddings,
254            rms_norm_eps: def.norm_eps as f32,
255            rope_theta_opt: def.rope_theta,
256            rope_scaling: rope_scaling_from_model_def(def),
257            sliding_window,
258        }
259    }
260
261    fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
262        Self {
263            hidden_size: b.hidden_size,
264            intermediate_size: b.intermediate_size,
265            num_heads: b.num_heads,
266            num_kv_heads: b.num_kv_heads,
267            head_dim: b.head_dim,
268            num_layers: b.num_layers,
269            vocab_size: b.vocab_size,
270            max_seq_len: b.max_seq_len,
271            rms_norm_eps: b.rms_norm_eps,
272            rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
273            rope_scaling: b.rope_scaling,
274            rope_interleaved: false,
275            has_qk_norm,
276            sliding_window: b.sliding_window,
277        }
278    }
279
280    /// Qwen3: QK-norm on, rope_theta default 1e6.
281    pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
282        Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
283    }
284
285    /// Llama / Llama-2 / Llama-3: no QK-norm; rope_theta varies by version
286    /// (10k for Llama-2, 500k for Llama-3.1+) — use the checkpoint value or
287    /// fall back to the most common modern value.
288    pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
289        Self::from_base(Self::from_def_base(def), 500_000.0, false)
290    }
291
292    /// Qwen2 / Qwen2.5: structurally Llama; no QK-norm; rope_theta default 1e6.
293    pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
294        Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
295    }
296
297    /// Mistral: no QK-norm; `rope_theta` commonly 10_000 (v0.1 / v0.2).
298    /// Picks up `sliding_window` from the checkpoint's config.json
299    /// (Mistral v0.1: 4096; Mistral v0.2+: 0 / null).
300    pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
301        Self::from_base(Self::from_def_base(def), 10_000.0, false)
302    }
303}
304
305struct LlamaFamilyConfigBase {
306    hidden_size: usize,
307    intermediate_size: usize,
308    num_heads: usize,
309    num_kv_heads: usize,
310    head_dim: usize,
311    num_layers: usize,
312    vocab_size: usize,
313    max_seq_len: usize,
314    rms_norm_eps: f32,
315    rope_theta_opt: Option<f64>,
316    rope_scaling: Option<RopeScalingConfig>,
317    sliding_window: usize,
318}
319
320fn rope_scaling_from_model_def(
321    def: &crate::definition::ModelDefinition,
322) -> Option<RopeScalingConfig> {
323    let value = def.extra_params.get("rope_scaling")?;
324    let obj = value.as_object()?;
325    let rope_type = obj
326        .get("rope_type")
327        .or_else(|| obj.get("type"))
328        .and_then(|v| v.as_str())?;
329    if rope_type != "llama3" {
330        return None;
331    }
332    let factor = json_f64(obj.get("factor"))?;
333    let low_freq_factor = json_f64(obj.get("low_freq_factor"))?;
334    let high_freq_factor = json_f64(obj.get("high_freq_factor"))?;
335    let original_max_position_embeddings = json_f64(obj.get("original_max_position_embeddings"))
336        .or_else(|| {
337            def.extra_params
338                .get("original_max_position_embeddings")
339                .and_then(|v| json_f64(Some(v)))
340        })
341        .unwrap_or(8192.0);
342    if factor <= 0.0
343        || low_freq_factor <= 0.0
344        || high_freq_factor <= low_freq_factor
345        || original_max_position_embeddings <= 0.0
346    {
347        return None;
348    }
349    Some(RopeScalingConfig::Llama3 {
350        factor,
351        low_freq_factor,
352        high_freq_factor,
353        original_max_position_embeddings,
354    })
355}
356
357fn json_f64(value: Option<&serde_json::Value>) -> Option<f64> {
358    match value? {
359        serde_json::Value::Number(n) => n.as_f64(),
360        _ => None,
361    }
362}
363
364/// Per-layer weights. `Box<dyn Linear<B>>` means each projection can be
365/// Dense / GPTQ / AWQ / GGUF without the surrounding code caring.
366pub struct LlamaFamilyLayer<B: QuantLlmBackend + BackendMoeFused> {
367    pub input_ln_w: B::Buffer,
368    pub qkv_proj: Box<dyn Linear<B>>,
369    /// QK-norm weight per head: `[head_dim]`. Optional for non-Qwen3 derivatives.
370    pub q_norm_w: Option<B::Buffer>,
371    pub k_norm_w: Option<B::Buffer>,
372    pub o_proj: Box<dyn Linear<B>>,
373    pub post_ln_w: B::Buffer,
374    pub gate_up_proj: Box<dyn Linear<B>>,
375    pub down_proj: Box<dyn Linear<B>>,
376}
377
/// Which decoder layers a model instance loads, and whether it also loads the
/// token embedding and/or the LM head.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlamaFamilyLayerStageConfig {
    /// Source-checkpoint layer indices to load.
    pub source_layers: Range<usize>,
    /// Load the token embedding (the first pipeline stage does).
    pub load_embedding: bool,
    /// Load the LM head (the last pipeline stage does).
    pub load_lm_head: bool,
}

impl LlamaFamilyLayerStageConfig {
    /// All `num_layers` layers plus embedding and LM head: a complete model.
    pub fn full(num_layers: usize) -> Self {
        Self::pipeline_stage(0..num_layers, true, true)
    }

    /// All `num_layers` layers with neither embedding nor LM head.
    pub fn backbone(num_layers: usize) -> Self {
        Self::pipeline_stage(0..num_layers, false, false)
    }

    /// One pipeline stage. Only the first stage owns the embedding and only
    /// the last stage owns the LM head.
    pub fn pipeline_stage(
        source_layers: Range<usize>,
        is_first_stage: bool,
        is_last_stage: bool,
    ) -> Self {
        Self {
            source_layers,
            load_embedding: is_first_stage,
            load_lm_head: is_last_stage,
        }
    }
}
414
415fn load_llama_family_layers<B: MoeLlmBackend>(
416    cfg: &LlamaFamilyConfig,
417    loader: &dyn WeightLoader<B>,
418    source_layers: Range<usize>,
419) -> Result<Vec<LlamaFamilyLayer<B>>> {
420    if source_layers.start > source_layers.end || source_layers.end > cfg.num_layers {
421        return Err(ferrum_types::FerrumError::model(format!(
422            "llama layer range {}..{} is outside model layer count {}",
423            source_layers.start, source_layers.end, cfg.num_layers
424        )));
425    }
426
427    let mut layers = Vec::with_capacity(source_layers.end.saturating_sub(source_layers.start));
428    for li in source_layers {
429        let prefix = format!("model.layers.{li}");
430        let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
431        let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
432        let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
433        let post_ln_w = loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
434        let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
435        let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
436
437        let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
438            let q = loader
439                .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
440                .ok();
441            let k = loader
442                .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
443                .ok();
444            (q, k)
445        } else {
446            (None, None)
447        };
448
449        layers.push(LlamaFamilyLayer {
450            input_ln_w,
451            qkv_proj,
452            q_norm_w,
453            k_norm_w,
454            o_proj,
455            post_ln_w,
456            gate_up_proj,
457            down_proj,
458        });
459    }
460    Ok(layers)
461}
462
463fn load_llama_family_lm_head<B: MoeLlmBackend>(
464    cfg: &LlamaFamilyConfig,
465    loader: &dyn WeightLoader<B>,
466) -> Result<Box<dyn Linear<B>>> {
467    // LM head: either dedicated `lm_head.weight` or tied to embedding.
468    // Many models (Qwen3-4B, Llama-3.2-1B, some Qwen2.5) use TIED
469    // embeddings — lm_head shares weights with model.embed_tokens. When
470    // no dedicated lm_head tensor exists, re-load the embed tensor as a
471    // DenseLinear. This duplicates the buffer (memory cost = vocab*h*2
472    // bytes, e.g. ~770MB for Qwen3-4B) but keeps the Linear trait's
473    // owned-weights invariant. Sharing via Arc is a future optimisation.
474    let lm_head = if loader.has_tensor("lm_head.weight") {
475        loader.load_linear("lm_head")?
476    } else {
477        tracing::info!(
478            "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
479        );
480        let as_linear = loader.load_linear("model.embed_tokens")?;
481        // Sanity check: shape must be [vocab, hidden].
482        if as_linear.out_features() != cfg.vocab_size || as_linear.in_features() != cfg.hidden_size
483        {
484            return Err(ferrum_types::FerrumError::model(format!(
485                "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
486                as_linear.out_features(),
487                as_linear.in_features(),
488                cfg.vocab_size,
489                cfg.hidden_size
490            )));
491        }
492        as_linear
493    };
494    Ok(lm_head)
495}
496
497/// Precomputed RoPE cos/sin tables (shape `[max_seq, head_dim / 2]` each).
498pub struct RopeCache<B: QuantLlmBackend + BackendMoeFused> {
499    pub cos: B::Buffer,
500    pub sin: B::Buffer,
501}
502
503/// Reusable per-layer scratch buffers sized for `max_tokens` tokens of a
504/// single forward pass (prefill or decode step).
505///
506/// Sized lazily on first use so tiny decode steps don't pay for prefill-sized
507/// buffers. Grows monotonically when a larger prefill arrives.
508pub struct LlamaFamilyScratch<B: QuantLlmBackend + BackendMoeFused> {
509    /// Residual stream — wrapped in Option so decode_internal can
510    /// `.take()` it without needing an alloc placeholder.
511    ///
512    /// Why this matters for graph capture: the old pattern was
513    /// `mem::replace(&mut scratch.residual, B::alloc(1))` which creates a
514    /// 1-element buffer at every decode step. When graph capture is on,
515    /// that alloc-during-capture + drop-after-capture pair surfaces as
516    /// cuMemFreeAsync(INVALID_VALUE) because the free tries to release a
517    /// pointer the captured graph may still reference. Option::take leaves
518    /// None and moves the real buffer into a local — no spurious alloc.
519    pub residual: Option<B::Buffer>,
520    pub norm_out: B::Buffer,
521    pub qkv_out: B::Buffer,
522    // ── Per-item scratch for batched decode path ──────────────────────
523    // decode_batch_internal runs tokens=M batched ops for the GEMM-heavy
524    // half (norm, qkv_proj, split_qkv, o_proj, post_norm, gate_up, silu,
525    // down, residual_add) but must loop per-item for rope + KV append +
526    // attention (each item has its own KV cache at a different kv_len).
527    // These single-item buffers hold item i's slice during that loop.
528    /// Item-scope q_buf slice, sized `q_dim`.
529    pub q_single: B::Buffer,
530    pub k_single: B::Buffer,
531    pub v_single: B::Buffer,
532    pub q_head_major_single: B::Buffer,
533    pub k_head_major_single: B::Buffer,
534    pub v_head_major_single: B::Buffer,
535    pub attn_head_major_single: B::Buffer,
536    pub attn_flat_single: B::Buffer,
537    /// Batched logits output, sized `max_tokens * vocab_size`. Used only
538    /// in decode_batch; prefill/single-decode use the regular `logits`.
539    pub batch_logits: B::Buffer,
540    /// Token-major Q/K/V right after `split_qkv`. Stride: heads * hd per row.
541    pub q_buf: B::Buffer,
542    pub k_buf: B::Buffer,
543    pub v_buf: B::Buffer,
544    /// Head-major Q produced by `qk_norm_rope` — fed into `flash_attention`.
545    pub q_head_major: B::Buffer,
546    /// Head-major K/V staging — produced by `qk_norm_rope`, consumed by
547    /// `kv_cache_append_head_major` (no reuse after append).
548    pub k_head_major: B::Buffer,
549    pub v_head_major: B::Buffer,
550    /// Head-major attention output from `flash_attention`.
551    pub attn_head_major_out: B::Buffer,
552    /// Token-major attention output after `transpose_head_to_token`.
553    pub attn_flat: B::Buffer,
554    pub o_proj_out: B::Buffer,
555    pub gate_up_out: B::Buffer,
556    pub silu_out: B::Buffer,
557    pub mlp_out: B::Buffer,
558    /// Paged batched dispatch scratch (Phase 4b). Sized for
559    /// `FERRUM_PAGED_MAX_SEQS × q_dim` so multi-seq decode can fan
560    /// in M items' Q into a single buffer for one batched
561    /// `paged_decode_attention(num_seqs=M)` call. `None` when paged
562    /// mode is off.
563    pub paged_batch_q: Option<B::Buffer>,
564    pub paged_batch_o: Option<B::Buffer>,
565    /// Stacked per-seq block tables for batched paged dispatch.
566    /// Layout: `[max_M, max_blocks_per_seq]` u32. Written
567    /// host-side per decode_batch step.
568    pub paged_batch_block_tables: Option<B::Buffer>,
569    /// Stacked per-seq context lengths for batched paged dispatch
570    /// (`[max_M]` u32).
571    pub paged_batch_context_lens: Option<B::Buffer>,
572    /// `max_blocks_per_seq` value baked into the stacked block_tables
573    /// stride. Set when `paged_batch_block_tables` is allocated.
574    pub paged_max_blocks_per_seq: usize,
575    /// Engine-side max concurrent sequences (= `FERRUM_PAGED_MAX_SEQS`).
576    /// Pinned at the first `enable_paged_batch` so the unified-forward
577    /// scratch sizes (`unified_cu_seqlens_q`, `unified_pos_offsets`,
578    /// `unified_block_tables`, `unified_packed_*`) are big enough for
579    /// any subsequent batch up to that bound.
580    pub paged_max_seqs: usize,
581    /// Per-item RoPE positions for the batched-decode path (`[max_M]`
582    /// i32 / u32). Written host-side once per batched-decode step from
583    /// each request's `pos` field, read by the batched
584    /// `qk_norm_rope_batched_per_item` CUDA kernel.
585    pub batch_positions: B::Buffer,
586    /// Per-item input token ids for the batched-decode path (`[max_M]`
587    /// u32). Written once per call before forward; read by the
588    /// graph-capture-friendly `embedding_lookup_batched_dyn` variant.
589    pub batch_tokens: B::Buffer,
590    /// Per-item KV-cache length BEFORE this step's kv_append
591    /// (`[max_M]` u32). Used by `kv_cache_append_batched_per_cache_dyn`
592    /// to write at the right slot for graph replay.
593    pub batch_kv_lens_pre: B::Buffer,
594    /// Per-item KV-cache length AFTER this step's kv_append
595    /// (`[max_M]` u32 = pre + 1). Used by
596    /// `flash_attention_batched_per_cache_dyn` for the attention
597    /// reduce window length.
598    pub batch_kv_lens_post: B::Buffer,
599    /// Output buffers for the batched per-item qk_norm_rope kernel.
600    /// Same shape as q_buf / k_buf / v_buf — separate so the kernel
601    /// API can take `&input` and `&mut output` without aliasing.
602    pub q_normed_batched: B::Buffer,
603    pub k_normed_batched: B::Buffer,
604    pub v_normed_batched: B::Buffer,
605
606    // ── Unified mixed-batch scratch (chunked-prefill path; Step 5b) ─────
607    // Buffers sized for `M_total = sum(items[i].q_tokens.len())`. Grown
608    // on demand by `ensure_unified_scratch(M_total_max)`. Used only by
609    // the new `unified_forward_internal`; legacy `forward_layer_batched_decode`
610    // continues to use the per-item-stride scratch above.
611    pub unified_capacity: usize, // current allocated M_total slots
612    pub unified_residual: Option<B::Buffer>,
613    pub unified_norm_out: Option<B::Buffer>,
614    pub unified_qkv_out: Option<B::Buffer>,
615    pub unified_packed_q: Option<B::Buffer>,
616    pub unified_attn_out: Option<B::Buffer>,
617    pub unified_o_proj_out: Option<B::Buffer>,
618    pub unified_gate_up_out: Option<B::Buffer>,
619    pub unified_silu_out: Option<B::Buffer>,
620    pub unified_mlp_out: Option<B::Buffer>,
621    /// Per-item index buffers (i32-stored-as-f16): cu_seqlens_q is
622    /// length `max_seqs+1`, pos_offsets is `max_seqs`, block_tables is
623    /// `max_seqs * max_blocks_per_seq`. Sized once at first use to
624    /// match `paged_batch_*` capacity since they share `max_seqs`.
625    pub unified_cu_seqlens_q: Option<B::Buffer>,
626    pub unified_pos_offsets: Option<B::Buffer>,
627    pub unified_block_tables: Option<B::Buffer>,
628    /// Packed last-token hidden states for is_final_chunk items
629    /// (`[num_sampled, h]`). Used as input to lm_head.
630    pub unified_packed_normed: Option<B::Buffer>,
631    /// Packed logits output (`[num_sampled, vocab]`).
632    pub unified_packed_logits: Option<B::Buffer>,
633    /// Last token's hidden state (`[h]`). For prefill this is populated via
634    /// `copy_slice(residual, (seq_len-1)*h, ..)`; for decode `residual` already
635    /// holds only 1 row so `last_hidden` is unused on that path.
636    pub last_hidden: B::Buffer,
637    /// Final-norm output for the last token (`[h]`).
638    pub last_normed: B::Buffer,
639    /// lm_head logits (`[vocab]`).
640    pub logits: B::Buffer,
641    /// The max tokens-per-step this scratch has been sized for.
642    pub max_tokens: usize,
643}
644
645impl<B: QuantLlmBackend + BackendMoeFused> LlamaFamilyScratch<B> {
646    fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
647        let h = cfg.hidden_size;
648        let im = cfg.intermediate_size;
649        let q_dim = cfg.num_heads * cfg.head_dim;
650        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
651        let qkv_dim = q_dim + 2 * kv_dim;
652        let t = max_tokens;
653        Self {
654            residual: Some(B::alloc(t * h)),
655            norm_out: B::alloc(t * h),
656            qkv_out: B::alloc(t * qkv_dim),
657            q_buf: B::alloc(t * q_dim),
658            k_buf: B::alloc(t * kv_dim),
659            v_buf: B::alloc(t * kv_dim),
660            q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
661            k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
662            v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
663            attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
664            attn_flat: B::alloc(t * q_dim),
665            o_proj_out: B::alloc(t * h),
666            gate_up_out: B::alloc(t * 2 * im),
667            silu_out: B::alloc(t * im),
668            mlp_out: B::alloc(t * h),
669            last_hidden: B::alloc(h),
670            last_normed: B::alloc(h),
671            logits: B::alloc(cfg.vocab_size),
672            q_single: B::alloc(q_dim),
673            k_single: B::alloc(kv_dim),
674            v_single: B::alloc(kv_dim),
675            q_head_major_single: B::alloc(q_dim),
676            k_head_major_single: B::alloc(kv_dim),
677            v_head_major_single: B::alloc(kv_dim),
678            attn_head_major_single: B::alloc(q_dim),
679            attn_flat_single: B::alloc(q_dim),
680            batch_logits: B::alloc(t * cfg.vocab_size),
681            // Paged batched dispatch scratch. None until `enable_paged_batch`
682            // is called from `ensure_kv` once the model knows max_seqs +
683            // max_blocks_per_seq. This avoids paying the alloc cost when
684            // paged mode is off.
685            paged_batch_q: None,
686            paged_batch_o: None,
687            paged_batch_block_tables: None,
688            paged_batch_context_lens: None,
689            paged_max_blocks_per_seq: 0,
690            paged_max_seqs: 0,
691            batch_positions: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
692            batch_tokens: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
693            batch_kv_lens_pre: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
694            batch_kv_lens_post: B::alloc_typed(ferrum_kernels::backend::Dtype::U32, t.max(1)),
695            q_normed_batched: B::alloc(t * q_dim),
696            k_normed_batched: B::alloc(t * kv_dim),
697            v_normed_batched: B::alloc(t * kv_dim),
698            unified_capacity: 0,
699            unified_residual: None,
700            unified_norm_out: None,
701            unified_qkv_out: None,
702            unified_packed_q: None,
703            unified_attn_out: None,
704            unified_o_proj_out: None,
705            unified_gate_up_out: None,
706            unified_silu_out: None,
707            unified_mlp_out: None,
708            unified_cu_seqlens_q: None,
709            unified_pos_offsets: None,
710            unified_block_tables: None,
711            unified_packed_normed: None,
712            unified_packed_logits: None,
713            max_tokens: t,
714        }
715    }
716
717    /// Grow unified-path scratch buffers to accommodate `m_total` query
718    /// tokens. Called lazily from `unified_forward_internal` so single-
719    /// path workloads (no chunked prefill) don't pay the alloc cost.
720    pub(crate) fn ensure_unified_scratch(
721        &mut self,
722        cfg: &LlamaFamilyConfig,
723        m_total: usize,
724        max_seqs: usize,
725        max_blocks_per_seq: usize,
726    ) {
727        if m_total <= self.unified_capacity
728            && self.unified_residual.is_some()
729            && self.unified_cu_seqlens_q.is_some()
730        {
731            return;
732        }
733        let cap = m_total.max(self.unified_capacity).max(1);
734        let h = cfg.hidden_size;
735        let q_dim = cfg.num_heads * cfg.head_dim;
736        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
737        let qkv_dim = q_dim + 2 * kv_dim;
738        let im = cfg.intermediate_size;
739        let v = cfg.vocab_size;
740        self.unified_residual = Some(B::alloc(cap * h));
741        self.unified_norm_out = Some(B::alloc(cap * h));
742        self.unified_qkv_out = Some(B::alloc(cap * qkv_dim));
743        self.unified_packed_q = Some(B::alloc(cap * q_dim));
744        self.unified_attn_out = Some(B::alloc(cap * q_dim));
745        self.unified_o_proj_out = Some(B::alloc(cap * h));
746        self.unified_gate_up_out = Some(B::alloc(cap * 2 * im));
747        self.unified_silu_out = Some(B::alloc(cap * im));
748        self.unified_mlp_out = Some(B::alloc(cap * h));
749        if self.unified_cu_seqlens_q.is_none() {
750            self.unified_cu_seqlens_q = Some(B::alloc_typed(
751                ferrum_kernels::backend::Dtype::U32,
752                max_seqs + 1,
753            ));
754            self.unified_pos_offsets = Some(B::alloc_typed(
755                ferrum_kernels::backend::Dtype::U32,
756                max_seqs,
757            ));
758            self.unified_block_tables = Some(B::alloc_typed(
759                ferrum_kernels::backend::Dtype::U32,
760                max_seqs * max_blocks_per_seq,
761            ));
762            self.unified_packed_normed = Some(B::alloc(max_seqs * h));
763            self.unified_packed_logits = Some(B::alloc(max_seqs * v));
764        }
765        self.unified_capacity = cap;
766    }
767
768    /// Allocate scratch for batched paged dispatch (Phase 4b). Called
769    /// lazily from `ensure_kv` once paged mode is enabled and we know
770    /// the pool dimensions. Idempotent.
771    fn enable_paged_batch(
772        &mut self,
773        cfg: &LlamaFamilyConfig,
774        max_seqs: usize,
775        max_blocks_per_seq: usize,
776    ) {
777        if self.paged_batch_q.is_some() {
778            return;
779        }
780        let q_dim = cfg.num_heads * cfg.head_dim;
781        self.paged_batch_q = Some(B::alloc(max_seqs * q_dim));
782        self.paged_batch_o = Some(B::alloc(max_seqs * q_dim));
783        self.paged_batch_block_tables = Some(B::alloc_typed(
784            ferrum_kernels::backend::Dtype::U32,
785            max_seqs * max_blocks_per_seq,
786        ));
787        self.paged_batch_context_lens = Some(B::alloc_typed(
788            ferrum_kernels::backend::Dtype::U32,
789            max_seqs,
790        ));
791        self.paged_max_blocks_per_seq = max_blocks_per_seq;
792        self.paged_max_seqs = max_seqs;
793    }
794}
795
796/// Qwen3 model — decoder-only LLM, one per (backend, weights) combination.
797///
798/// Holds all parameters, scratch space, RoPE cache, and per-sequence KV caches.
799///
800/// `B: BackendGraph + BackendQuantMarlin + BackendQuantGguf` because the decode hot path uses CUDA Graph capture/replay
801/// when the backend supports it; non-graph backends (Metal/CPU) inherit no-op
802/// defaults, so this bound is satisfied by every concrete `Backend`.
803///
804/// `K: KvDtypeKind = KvFp16` (Dim 5): selects the KV cache element type.
805/// `K = KvFp16` constructs only `LayerKvCache::Fp16(...)` variants;
806/// `K = KvInt8` constructs only `LayerKvCache::Int8(...)` variants. The
807/// `B: BackendKvDtype<KvInt8>` bound is structurally required by the
808/// enum and is satisfied by all backends via stub impls (Cpu/Metal) or
809/// the real impl (CUDA).
810pub struct LlamaFamilyModel<B: MoeLlmBackend, K: KvLayer<B> = KvFp16> {
811    pub cfg: LlamaFamilyConfig,
812    pub runtime_cfg: LlmRuntimeConfig,
813
814    /// Backend capabilities resolved once at construction (the GOAL's allowed
815    /// "构造期 capability 决策一次"). Decode/prefill hot paths branch on these
816    /// fields instead of calling `B::supports_*()` inline, so a capability gate
817    /// can't silently rot in the forward code.
818    pub(crate) supports_varlen_qkv: bool,
819    pub(crate) supports_batched_decode: bool,
820    /// Runtime knobs resolved once at construction from the snapshot installed
821    /// at the composition root (was a process-wide OnceLock reading env).
822    /// Forward/decode read these fields instead of a global accessor.
823    pub(crate) runtime_env: LlamaFamilyRuntimeEnv,
824    pub(crate) batched_cfg: super::llama_family_forward_batched::LlamaBatchedRuntimeConfig,
825
826    /// Token embedding table. `None` for backbone-only models (e.g. the
827    /// Qwen3-TTS Talker, which embeds inputs externally and feeds via
828    /// `prefill_from_embeds`).
829    pub embed: Option<B::Buffer>,
830    /// Source checkpoint layer range loaded into `layers`. Full LLM models
831    /// use `0..cfg.num_layers`; layer-split stages load a contiguous subset.
832    pub layer_source_start: usize,
833    pub layer_source_end: usize,
834    pub layers: Vec<LlamaFamilyLayer<B>>,
835    pub final_norm_w: B::Buffer,
836    /// LM output head. `None` for backbone-only models.
837    pub lm_head: Option<Box<dyn Linear<B>>>,
838
839    pub rope: RopeCache<B>,
840    pub scratch: LlamaFamilyScratch<B>,
841
842    /// Per-sequence KV caches, one `Vec<KvCache<B>>` of length `num_layers`.
843    ///
844    /// Two layouts overlay this same map:
845    /// - **Contiguous mode** (default): each cache holds its own
846    ///   `[num_kv_heads, capacity, head_dim]` k/v buffers.
847    /// - **Paged mode** (`FERRUM_METAL_PAGED_KV=1`): k/v are unused
848    ///   placeholders; the real K/V live in [`Self::paged_pools`] and
849    ///   the cache's `block_table` + `context_lens` index into them.
850    pub kv_caches: HashMap<String, Vec<K::Layer>>,
851    /// Free pool of pre-allocated KV cache slots. Released caches return
852    /// here instead of being dropped, so their device pointers stay valid
853    /// across requests — critical for graph capture (pointers baked into
854    /// the captured graph would otherwise dangle).
855    kv_free_pool: Vec<Vec<K::Layer>>,
856
857    // ── Paged-KV multi-seq state (Phase 4) ─────────────────────────────
858    //
859    // Only populated when `FERRUM_METAL_PAGED_KV=1`. When set, every
860    // `kv_caches` entry becomes a "view" into the shared pool: its
861    // `k` / `v` buffers are placeholders; reads / writes go through
862    // `paged_pools[layer].k` / `.v` indexed via the cache's
863    // `block_table`. Multiple cache_ids share the same pool, with
864    // disjoint physical block sets owned by `paged_block_alloc`.
865    //
866    /// Shared K/V pools, one per layer. Sized at model load time for the
867    /// configured `MAX_RUNNING_REQUESTS × max_blocks_per_seq` blocks.
868    pub paged_pools: Option<Vec<(B::Buffer, B::Buffer)>>,
869    /// Block allocator hands out physical block indices from the pool.
870    /// `Mutex` because the engine can call `ensure_kv` / `release_kv`
871    /// from multiple threads in concurrent serving.
872    pub paged_block_alloc: Option<std::sync::Mutex<crate::common::paged_pool::BlockAllocator>>,
873    /// Paged-batch dispatch dimensions `(max_seqs, max_blocks_per_seq)`,
874    /// pinned at the first `ensure_kv` when paged-KV is on. Stored on
875    /// the model (not on scratch) so `ensure_scratch`'s realloc can
876    /// re-call `enable_paged_batch` after wiping scratch's
877    /// `paged_batch_block_tables` / `paged_batch_q` etc.
878    pub paged_dims: Option<(usize, usize)>,
879
880    // ── Graph capture state (CUDA only; harmless no-op on other backends) ──
881    /// Count of eager decode steps run so far. After `GRAPH_WARMUP`, the
882    /// next step captures the decode flow as a graph.
883    pub(crate) graph_warmup: usize,
884    /// True if capture was attempted but failed (e.g. backend doesn't
885    /// support graph capture). Stops further attempts, falls back to eager.
886    pub(crate) graph_capture_failed: bool,
887    /// Same warmup counter for the batched-decode path.
888    pub(crate) batched_graph_warmup: usize,
889    /// True if batched graph capture failed.
890    pub(crate) batched_graph_failed: bool,
891    /// Set of `m_padded` values (as u64 graph keys) for which a batched
892    /// graph has been captured. Multi-slot via cuda.rs's HashMap-keyed
893    /// graph cache — different batch shapes don't thrash a single slot.
894    pub(crate) batched_graph_keys_seen: std::collections::HashSet<u64>,
895    /// Cache IDs for which device-pointer scratch is currently populated.
896    /// Populate only re-runs when the batch composition changes (new
897    /// requests joined / requests finished). Hot-path optimization:
898    /// avoids 3 sync cuMemcpyHtoD_v2's per decode token (~5% TPOT).
899    pub(crate) batched_pointers_for: Option<Vec<String>>,
900    /// CUDA-graph state for the unified_forward path. Mirrors the
901    /// `batched_graph_*` triple but keyed on `(m_total, num_seqs)`
902    /// so different concurrency levels each get their own cached
903    /// graph instead of thrashing a single slot.
904    pub(crate) unified_graph_warmup: usize,
905    pub(crate) unified_graph_failed: bool,
906    pub(crate) unified_graph_keys_seen: std::collections::HashSet<u64>,
907
908    // ── Real paged-KV prefix cache counters ─────────────────────────────
909    prefix_cache_hits: u64,
910    prefix_cache_misses: u64,
911    prefix_cache_saved_prefill_tokens: u64,
912
913    // ── Startup LoRA runtime state ──────────────────────────────────────
914    lora_adapters: HashMap<String, RuntimeLoraAdapter<B>>,
915    lora_cache_adapters: HashMap<String, String>,
916    lora_projection_applications: u64,
917}
918
919impl<B: MoeLlmBackend, K: KvLayer<B>> LlamaFamilyModel<B, K> {
920    /// Build a Qwen3 model from weights provided by the loader.
921    ///
922    /// The loader decides per-projection whether to instantiate DenseLinear,
923    /// GptqLinear, etc. — this code doesn't care.
924    pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
925        let num_layers = cfg.num_layers;
926        Self::new_with_stage_config(cfg, loader, LlamaFamilyLayerStageConfig::full(num_layers))
927    }
928
929    /// Build a backbone-only Qwen3 transformer stack (no embed, no lm_head).
930    ///
931    /// Intended for composing the transformer inside a larger model where
932    /// embedding and output-head logic differs from the standard LLM path —
933    /// e.g. Qwen3-TTS Talker uses dual text/codec embeddings with a projection
934    /// MLP, and a codec_head output. The caller drives forward via
935    /// `prefill_from_embeds` / `decode_from_embed`.
936    ///
937    /// Loader must provide: per-layer weights under `model.layers.{i}.*` and
938    /// the final `model.norm.weight`. `model.embed_tokens` and `lm_head`
939    /// are NOT read.
940    pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
941        let num_layers = cfg.num_layers;
942        Self::new_with_stage_config(
943            cfg,
944            loader,
945            LlamaFamilyLayerStageConfig::backbone(num_layers),
946        )
947    }
948
949    /// Build a layer-split pipeline stage. The stage always loads its local
950    /// transformer layers and final norm; embedding and lm_head are loaded only
951    /// for the first and last stages respectively.
952    pub fn new_layer_stage(
953        cfg: LlamaFamilyConfig,
954        loader: &dyn WeightLoader<B>,
955        stage: LlamaFamilyLayerStageConfig,
956    ) -> Result<Self> {
957        Self::new_with_stage_config(cfg, loader, stage)
958    }
959
960    fn new_with_stage_config(
961        cfg: LlamaFamilyConfig,
962        loader: &dyn WeightLoader<B>,
963        stage: LlamaFamilyLayerStageConfig,
964    ) -> Result<Self> {
965        if stage.source_layers.is_empty() {
966            return Err(ferrum_types::FerrumError::model(
967                "llama layer stage must include at least one source layer",
968            ));
969        }
970        if stage.source_layers.end > cfg.num_layers {
971            return Err(ferrum_types::FerrumError::model(format!(
972                "llama layer stage range {}..{} is outside model layer count {}",
973                stage.source_layers.start, stage.source_layers.end, cfg.num_layers
974            )));
975        }
976
977        // Invalidate any graph from a previously-loaded model. The captured
978        // graph references the old model's scratch buffers; a fresh model
979        // gets fresh scratch, so reusing the graph would read/write freed
980        // pointers. Matters for test suites where multiple models coexist.
981        {
982            let mut ctx = B::new_context();
983            B::reset_all_graphs(&mut ctx);
984        }
985        let rope = build_rope_cache::<B>(&cfg);
986        let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
987        let embed = if stage.load_embedding {
988            Some(loader.load_tensor("model.embed_tokens.weight")?)
989        } else {
990            None
991        };
992        let layers = load_llama_family_layers(&cfg, loader, stage.source_layers.clone())?;
993
994        let lm_head = if stage.load_lm_head {
995            Some(load_llama_family_lm_head(&cfg, loader)?)
996        } else {
997            None
998        };
999        let final_norm_w = loader.load_tensor("model.norm.weight")?;
1000
1001        let layer_source_start = stage.source_layers.start;
1002        let layer_source_end = stage.source_layers.end;
1003        let runtime_cfg = cfg.to_runtime();
1004        // Construction-time capability resolution (GOAL-allowed): read once here
1005        // so the forward hot paths read self.* instead of B::supports_*().
1006        let supports_varlen_qkv = B::supports_varlen_qkv();
1007        let supports_batched_decode = B::supports_llama_family_batched_decode();
1008        let runtime_env = LlamaFamilyRuntimeEnv::from_env();
1009        let batched_cfg =
1010            super::llama_family_forward_batched::LlamaBatchedRuntimeConfig::from_env();
1011        Ok(Self {
1012            cfg,
1013            runtime_cfg,
1014            supports_varlen_qkv,
1015            supports_batched_decode,
1016            runtime_env,
1017            batched_cfg,
1018            embed,
1019            layer_source_start,
1020            layer_source_end,
1021            layers,
1022            final_norm_w,
1023            lm_head,
1024            rope,
1025            scratch,
1026            kv_caches: HashMap::new(),
1027            kv_free_pool: Vec::new(),
1028            paged_pools: None,
1029            paged_block_alloc: None,
1030            paged_dims: None,
1031            graph_warmup: 0,
1032            graph_capture_failed: false,
1033            batched_graph_warmup: 0,
1034            batched_graph_failed: false,
1035            batched_graph_keys_seen: std::collections::HashSet::new(),
1036            batched_pointers_for: None,
1037            unified_graph_warmup: 0,
1038            unified_graph_failed: false,
1039            unified_graph_keys_seen: std::collections::HashSet::new(),
1040            prefix_cache_hits: 0,
1041            prefix_cache_misses: 0,
1042            prefix_cache_saved_prefill_tokens: 0,
1043            lora_adapters: HashMap::new(),
1044            lora_cache_adapters: HashMap::new(),
1045            lora_projection_applications: 0,
1046        })
1047    }
1048
1049    pub fn source_layer_range(&self) -> Range<usize> {
1050        self.layer_source_start..self.layer_source_end
1051    }
1052
1053    pub fn local_layer_count(&self) -> usize {
1054        self.layers.len()
1055    }
1056
1057    pub fn cache_len(&self, cache_id: &str) -> usize {
1058        self.kv_caches
1059            .get(cache_id)
1060            .and_then(|layers| layers.first())
1061            .map(K::len)
1062            .unwrap_or(0)
1063    }
1064
1065    fn local_layer_indices(&self) -> Range<usize> {
1066        0..self.local_layer_count()
1067    }
1068
1069    fn source_layer_index(&self, local_layer_index: usize) -> usize {
1070        self.layer_source_start + local_layer_index
1071    }
1072
1073    fn local_layer_index_for_source(&self, source_layer_index: usize) -> Option<usize> {
1074        if source_layer_index < self.layer_source_start
1075            || source_layer_index >= self.layer_source_end
1076        {
1077            return None;
1078        }
1079        Some(source_layer_index - self.layer_source_start)
1080    }
1081
1082    /// Grow scratch buffers if `tokens` exceeds the current sizing.
1083    pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
1084        if self.scratch.max_tokens < tokens {
1085            // Any captured decode graph holds pointers to the old scratch
1086            // buffers; those are about to be freed. Invalidate ALL captured
1087            // graphs (both single-item and per-m_padded batched) — every
1088            // captured kernel-arg pointer into scratch is stale.
1089            {
1090                let mut ctx = B::new_context();
1091                B::reset_all_graphs(&mut ctx);
1092            }
1093            self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
1094            self.graph_warmup = 0;
1095            self.graph_capture_failed = false;
1096            self.batched_graph_keys_seen.clear();
1097            self.batched_graph_warmup = 0;
1098            self.batched_graph_failed = false;
1099            self.unified_graph_keys_seen.clear();
1100            self.unified_graph_warmup = 0;
1101            self.unified_graph_failed = false;
1102            // Realloc wiped paged_batch_*. Re-enable using the dims
1103            // pinned at first ensure_kv. Without this, the next
1104            // `forward_layer_batched_decode` panics on
1105            // `paged_batch_block_tables missing`.
1106            if let Some((max_seqs, max_blocks_per_seq)) = self.paged_dims {
1107                self.scratch
1108                    .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
1109            }
1110        }
1111    }
1112
1113    /// Ensure per-layer KV caches exist for `cache_id`, pre-allocated to
1114    /// `max_seq_len` slots per head. Enables the in-place
1115    /// `kv_cache_append_head_major` path — no realloc per layer.
1116    pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
1117        if self.kv_caches.contains_key(cache_id) {
1118            return;
1119        }
1120        let nkv = self.cfg.num_kv_heads;
1121        let hd = self.cfg.head_dim;
1122        // KV capacity defaults to a chat-friendly 4096 to keep the working
1123        // set sane on a 32 GB Mac (a 48-layer / 4-kv-head / 128-head_dim
1124        // model spends ~786 MB on 4096 KV slots, vs ~6 GB on the model's
1125        // declared 32K which would push the 17 GB Qwen3-30B-A3B model into
1126        // swap). `FERRUM_KV_CAPACITY=N` overrides; clamp to the model's
1127        // declared max so we never lie to the model about its window.
1128        let model_max = self.cfg.max_seq_len;
1129        // 512 in 0.7.2 — matches the value used in
1130        // docs/bench/macos-2026-05-02 to get the published numbers.
1131        // pre-0.7.2 default of 4096 was safe only because paged-KV was
1132        // opt-in (pool wasn't allocated). With paged-KV now on by
1133        // default + MAX_SEQS=32, the pool occupies physical memory:
1134        // ~3 GB on Qwen3-30B-A3B Q4_K_M leaves 18 GB weights + 3 GB pool
1135        // = 21 GB, fits comfortably on a 32 GB Mac. Long-context users
1136        // can `FERRUM_KV_CAPACITY=4096` and accept lower max_seqs.
1137        let runtime_env = &self.runtime_env;
1138        let max = runtime_env.kv_capacity_for_model(model_max);
1139
1140        // Paged-KV mode: `FERRUM_METAL_PAGED_KV=1` switches every cache
1141        // for this model into block-table-indirect layout. Kernels from
1142        // PR #68 (decode read) + PR #69 (decode write) handle the
1143        // indirect addressing; the LlamaFamily decode path below picks
1144        // them up automatically by checking `cache.block_size > 0`.
1145        //
1146        // Pool sizing: round capacity up to a multiple of block_size,
1147        // identity-assign logical→physical block. Memory footprint is
1148        // the same as contiguous (within block_size rounding); the
1149        // benefit only shows up under multi-seq sharing in Phase 4+.
1150        // Default ON when the backend supports paged-KV (Metal). Users
1151        // can force off with `FERRUM_METAL_PAGED_KV=0`. The flag was
1152        // opt-in pre-0.7.2; flipping the default so default `ferrum
1153        // serve` matches the bench-quality numbers without requiring
1154        // env-var knowledge.
1155        let paged = runtime_env.paged_kv_enabled::<B>();
1156        const PAGED_BLOCK_SIZE: usize = 16;
1157
1158        // Phase 4 shared-pool sizing. The pool sees ALL concurrent
1159        // sequences; per-cache_id state just owns indices into it.
1160        // Default 32: covers c=16 burst with 2× headroom for the
1161        // fresh-cache-id-per-request pattern that bench/server harnesses
1162        // use. Pool memory is `max_seqs × max_blocks_per_seq` total
1163        // blocks — we lowered DEFAULT_KV_CAPACITY to 2048 so this 2× max_seqs
1164        // bump keeps the pool footprint identical to the pre-0.7.2 default.
1165        let max_seqs = runtime_env.paged_max_seqs;
1166        let max_blocks_per_seq = max.div_ceil(PAGED_BLOCK_SIZE);
1167        let total_pool_blocks = max_seqs * max_blocks_per_seq;
1168
1169        // Lazy-allocate the shared paged pools on the FIRST paged
1170        // ensure_kv call. Pools are big — for Llama-8B (8 kv_heads,
1171        // head_dim=128) at 16 seqs × 256 blocks × 16 slots = 65536 KV
1172        // slots: 65536 * 8 * 128 * 4 = 256 MB per layer × 32 layers
1173        // = 8 GB total. Sized this large only because `max_seqs=16`
1174        // is the default; lower it via env to shrink.
1175        if paged && self.paged_pools.is_none() {
1176            let mut pools = Vec::with_capacity(self.local_layer_count());
1177            for _ in self.local_layer_indices() {
1178                let pool_floats = total_pool_blocks * nkv * PAGED_BLOCK_SIZE * hd;
1179                pools.push((B::alloc(pool_floats), B::alloc(pool_floats)));
1180            }
1181            self.paged_pools = Some(pools);
1182            self.paged_block_alloc = Some(std::sync::Mutex::new(
1183                crate::common::paged_pool::BlockAllocator::new(total_pool_blocks as u32),
1184            ));
1185        }
1186        // Phase 4b: ensure batched-dispatch scratch is allocated whenever
1187        // paged is on. Idempotent — re-init is a no-op if already
1188        // sized. Has to live outside the `paged_pools.is_none()` branch
1189        // because `ensure_scratch` may have replaced the scratch struct
1190        // since the pools were first allocated.
1191        if paged {
1192            self.scratch
1193                .enable_paged_batch(&self.cfg, max_seqs, max_blocks_per_seq);
1194            // Pin dims on the model so `ensure_scratch`'s realloc can
1195            // re-init paged_batch_* after wiping scratch.
1196            self.paged_dims = Some((max_seqs, max_blocks_per_seq));
1197        }
1198
1199        // Try pool first — reused buffers have stable device pointers,
1200        // so a captured decode graph can be replayed for this request too.
1201        // K::NAME selects which `LayerKvCache` variant to construct:
1202        // K-aware allocation: K::alloc_paged / K::alloc_contig pick the
1203        // right cache layout (FP16 → KvCache, INT8 → KvCacheQuant) per the
1204        // model's K marker. INT8 supports paged mode only — KvInt8::alloc_contig
1205        // panics, surfacing the misconfiguration here at first ensure_kv.
1206        let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
1207            self.local_layer_indices()
1208                .map(|_| {
1209                    if paged {
1210                        K::alloc_paged(max_blocks_per_seq, PAGED_BLOCK_SIZE, nkv, hd)
1211                    } else {
1212                        K::alloc_contig(max, nkv, hd)
1213                    }
1214                })
1215                .collect()
1216        });
1217
1218        // Allocate physical blocks for THIS cache_id from the shared
1219        // pool. We allocate all `max_blocks_per_seq` upfront for
1220        // simplicity (matches contig's "pre-alloc to capacity"
1221        // semantics); a smarter Phase 4b can grow on demand to save
1222        // pool occupancy.
1223        if paged {
1224            let alloc_arc = self
1225                .paged_block_alloc
1226                .as_ref()
1227                .expect("paged_block_alloc must be initialised when paged=true");
1228            // Recover from a previously-poisoned mutex instead of panicking
1229            // (poison just means a prior holder panicked; the BlockAllocator
1230            // state is still intact since allocate_n is fail-safe).
1231            let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1232            let block_indices = match alloc.allocate_n(max_blocks_per_seq) {
1233                Ok(idx) => idx,
1234                Err(e) => {
1235                    // Pool exhaustion is a back-pressure signal, not a crash.
1236                    // Drop the lock, return the cache to the free pool, and
1237                    // bail before inserting it into kv_caches. The downstream
1238                    // call will then fail with a clean per-request error
1239                    // ("ensure_kv must be called before ...") instead of
1240                    // dragging every other in-flight request down with it.
1241                    drop(alloc);
1242                    self.kv_free_pool.push(caches);
1243                    eprintln!(
1244                        "[ferrum] paged KV pool exhausted on ensure_kv for \
1245                         cache_id={cache_id:?}: {e}. Increase \
1246                         FERRUM_PAGED_MAX_SEQS (currently {max_seqs}) or \
1247                         throttle concurrent requests.",
1248                    );
1249                    return;
1250                }
1251            };
1252            // Write the block table to each layer's cache. All layers
1253            // share the same logical→physical mapping for this seq.
1254            // Also stash the host-side index list so release_kv can
1255            // return them to the allocator without a device readback.
1256            let mut padded = block_indices.clone();
1257            padded.resize(max_blocks_per_seq, 0);
1258            let mut ctx_tmp = B::new_context();
1259            for c in caches.iter_mut() {
1260                if let Some(bt) = K::block_table_mut(c) {
1261                    B::write_typed::<u32>(&mut ctx_tmp, bt, &padded);
1262                }
1263                *K::paged_block_indices_mut(c) = block_indices.clone();
1264            }
1265            B::sync(&mut ctx_tmp);
1266        }
1267
1268        // Reset logical length; buffers stay. No need to zero the memory —
1269        // the kv_cache_append writes new K/V in place, and attention only
1270        // reads up to `cache_len`.
1271        for c in caches.iter_mut() {
1272            K::set_len(c, 0);
1273            if let Some(cl) = K::context_lens_mut(c) {
1274                let mut ctx_tmp = B::new_context();
1275                B::write_typed::<u32>(&mut ctx_tmp, cl, &[0u32]);
1276                B::sync(&mut ctx_tmp);
1277            }
1278        }
1279        self.kv_caches.insert(cache_id.to_string(), caches);
1280    }
1281
1282    fn record_prefix_cache_probe(&mut self, saved_tokens: usize) {
1283        if saved_tokens > 0 {
1284            self.prefix_cache_hits += 1;
1285            self.prefix_cache_saved_prefill_tokens += saved_tokens as u64;
1286        } else {
1287            self.prefix_cache_misses += 1;
1288        }
1289    }
1290
1291    fn try_acquire_prefix_cache(&mut self, cache_id: &str, tokens: &[u32]) -> usize {
1292        let Some(alloc_arc) = self.paged_block_alloc.as_ref() else {
1293            return 0;
1294        };
1295        let caches = match self.kv_caches.get(cache_id) {
1296            Some(caches) => caches,
1297            None => return 0,
1298        };
1299        let block_size = caches.first().map(K::block_size).unwrap_or(0);
1300        if block_size == 0 {
1301            return 0;
1302        }
1303
1304        let token_ids: Vec<ferrum_types::TokenId> = tokens
1305            .iter()
1306            .map(|&token| ferrum_types::TokenId::new(token))
1307            .collect();
1308        let hashes = block_hash_chain(&token_ids, block_size);
1309        if hashes.is_empty() {
1310            return 0;
1311        }
1312
1313        let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1314        let mut matched = Vec::with_capacity(hashes.len());
1315        for hash in hashes {
1316            match alloc.try_acquire_by_hash(hash) {
1317                Some(block) => matched.push(block),
1318                None => break,
1319            }
1320        }
1321        if matched.is_empty() {
1322            return 0;
1323        }
1324        let n_matched = matched.len();
1325
1326        let displaced = caches
1327            .first()
1328            .map(|cache| K::paged_block_indices(cache)[..n_matched].to_vec())
1329            .unwrap_or_default();
1330        alloc.free(&displaced);
1331        drop(alloc);
1332
1333        let caches_mut = self.kv_caches.get_mut(cache_id).expect("cache present");
1334        let max_blocks = caches_mut
1335            .first()
1336            .map(|cache| K::paged_block_indices(cache).len())
1337            .unwrap_or(0);
1338        let new_len = n_matched * block_size;
1339        let mut ctx = B::new_context();
1340        for cache in caches_mut.iter_mut() {
1341            {
1342                let indices = K::paged_block_indices_mut(cache);
1343                for (idx, &block) in matched.iter().enumerate() {
1344                    indices[idx] = block;
1345                }
1346            }
1347            K::set_len(cache, new_len);
1348            let padded = {
1349                let mut padded = K::paged_block_indices(cache).to_vec();
1350                padded.resize(max_blocks, 0);
1351                padded
1352            };
1353            if let Some(block_table) = K::block_table_mut(cache) {
1354                B::write_typed::<u32>(&mut ctx, block_table, &padded);
1355            }
1356            if let Some(context_lens) = K::context_lens_mut(cache) {
1357                B::write_typed::<u32>(&mut ctx, context_lens, &[new_len as u32]);
1358            }
1359        }
1360        B::sync(&mut ctx);
1361
1362        new_len
1363    }
1364
1365    fn register_prefix_cache(
1366        &mut self,
1367        cache_id: &str,
1368        all_tokens: &[u32],
1369        prior_cached_tokens: usize,
1370    ) {
1371        let Some(alloc_arc) = self.paged_block_alloc.as_ref() else {
1372            return;
1373        };
1374        let caches = match self.kv_caches.get(cache_id) {
1375            Some(caches) => caches,
1376            None => return,
1377        };
1378        let cache0 = match caches.first() {
1379            Some(cache) => cache,
1380            None => return,
1381        };
1382        let block_size = K::block_size(cache0);
1383        if block_size == 0 {
1384            return;
1385        }
1386
1387        let token_ids: Vec<ferrum_types::TokenId> = all_tokens
1388            .iter()
1389            .map(|&token| ferrum_types::TokenId::new(token))
1390            .collect();
1391        let hashes = block_hash_chain(&token_ids, block_size);
1392        if hashes.is_empty() {
1393            return;
1394        }
1395
1396        let start_block = prior_cached_tokens / block_size;
1397        let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
1398        for i in start_block..hashes.len().min(K::paged_block_indices(cache0).len()) {
1399            let block_end_token = (i + 1) * block_size;
1400            if block_end_token > K::len(cache0) {
1401                break;
1402            }
1403            alloc.register_block_hash(K::paged_block_indices(cache0)[i], hashes[i]);
1404        }
1405    }
1406
1407    fn prefix_cache_snapshot_json(&self) -> serde_json::Value {
1408        let (entries, block_size) = self
1409            .paged_block_alloc
1410            .as_ref()
1411            .and_then(|alloc| {
1412                let alloc = alloc.lock().ok()?;
1413                let block_size = self
1414                    .kv_caches
1415                    .values()
1416                    .find_map(|layers| layers.first().map(K::block_size))
1417                    .unwrap_or(16);
1418                Some((alloc.hash_table_size() as u64, block_size))
1419            })
1420            .unwrap_or((0, 16));
1421        let bytes_per_entry = (block_size
1422            * self.local_layer_count()
1423            * self.cfg.num_kv_heads
1424            * self.cfg.head_dim
1425            * K::BYTES_PER_ELEM
1426            * 2) as u64;
1427        serde_json::json!({
1428            "position": "real-kv-reuse",
1429            "source": "llama-family-paged-block-prefix-cache",
1430            "enabled": self.runtime_env.prefix_cache,
1431            "hits": self.prefix_cache_hits,
1432            "misses": self.prefix_cache_misses,
1433            "evictions": 0u64,
1434            "saved_prefill_tokens": self.prefix_cache_saved_prefill_tokens,
1435            "entries": entries,
1436            "bytes": entries.saturating_mul(bytes_per_entry),
1437            "block_size": block_size,
1438            "kv_dtype": K::NAME,
1439        })
1440    }
1441
1442    fn lora_projection_shape(
1443        &self,
1444        layer_index: usize,
1445        target_module: &str,
1446    ) -> Option<(usize, usize)> {
1447        let local_layer_index = self.local_layer_index_for_source(layer_index)?;
1448        let layer = self.layers.get(local_layer_index)?;
1449        match target_module {
1450            "qkv_proj" => Some((layer.qkv_proj.in_features(), layer.qkv_proj.out_features())),
1451            "o_proj" => Some((layer.o_proj.in_features(), layer.o_proj.out_features())),
1452            "gate_up_proj" => Some((
1453                layer.gate_up_proj.in_features(),
1454                layer.gate_up_proj.out_features(),
1455            )),
1456            "down_proj" => Some((
1457                layer.down_proj.in_features(),
1458                layer.down_proj.out_features(),
1459            )),
1460            _ => None,
1461        }
1462    }
1463
1464    fn validate_lora_adapter(&self, adapter: &RuntimeLoraAdapter<B>) -> Result<()> {
1465        if adapter.linears.is_empty() {
1466            return Err(ferrum_types::FerrumError::config(format!(
1467                "LoRA adapter {} has no runtime tensors",
1468                adapter.name
1469            )));
1470        }
1471        for linear in &adapter.linears {
1472            let layer_index = linear.layer_index.ok_or_else(|| {
1473                ferrum_types::FerrumError::config(format!(
1474                    "LoRA tensor for target {} must include model.layers.<N> in its tensor name",
1475                    linear.target_module
1476                ))
1477            })?;
1478            if layer_index < self.cfg.num_layers
1479                && !self.source_layer_range().contains(&layer_index)
1480            {
1481                continue;
1482            }
1483            let Some((expected_in, expected_out)) =
1484                self.lora_projection_shape(layer_index, &linear.target_module)
1485            else {
1486                return Err(ferrum_types::FerrumError::unsupported(format!(
1487                    "LoRA target {} is not supported by Llama-family runtime; supported targets: qkv_proj, o_proj, gate_up_proj, down_proj",
1488                    linear.target_module
1489                )));
1490            };
1491            if linear.in_features != expected_in || linear.out_features != expected_out {
1492                return Err(ferrum_types::FerrumError::config(format!(
1493                    "LoRA tensor shape mismatch for layer {} target {}: got out={} in={}, expected out={} in={}",
1494                    layer_index,
1495                    linear.target_module,
1496                    linear.out_features,
1497                    linear.in_features,
1498                    expected_out,
1499                    expected_in
1500                )));
1501            }
1502        }
1503        Ok(())
1504    }
1505
1506    fn ensure_lora_adapter_loaded(&mut self, adapter: ActiveLoraAdapter) -> Result<()> {
1507        if self.lora_adapters.contains_key(&adapter.name) {
1508            return Ok(());
1509        }
1510        let runtime = load_runtime_lora_adapter::<B>(&adapter)?;
1511        self.validate_lora_adapter(&runtime)?;
1512        self.lora_adapters.insert(adapter.name.clone(), runtime);
1513        Ok(())
1514    }
1515
1516    fn active_lora_adapter_for_cache(&self, cache_id: &str) -> Option<&RuntimeLoraAdapter<B>> {
1517        let adapter_name = self.lora_cache_adapters.get(cache_id)?;
1518        self.lora_adapters.get(adapter_name)
1519    }
1520
1521    fn active_lora_adapter_ptr_for_cache(
1522        &self,
1523        cache_id: &str,
1524    ) -> Option<*const RuntimeLoraAdapter<B>> {
1525        self.active_lora_adapter_for_cache(cache_id)
1526            .map(|adapter| adapter as *const RuntimeLoraAdapter<B>)
1527    }
1528
1529    fn lora_metrics_snapshot_json(&self) -> serde_json::Value {
1530        serde_json::json!({
1531            "enabled": !self.lora_adapters.is_empty(),
1532            "adapter_count": self.lora_adapters.len() as u64,
1533            "active_cache_bindings": self.lora_cache_adapters.len() as u64,
1534            "projection_applications": self.lora_projection_applications,
1535            "position": "real-inference",
1536            "source": "llama-family-runtime-lora",
1537        })
1538    }
1539
    /// Run one transformer layer (local index `li`) over `tokens` positions of
    /// `cache_id`, mutating `residual` in place.
    ///
    /// Order of work:
    /// 1. input RMSNorm;
    /// 2. fused QKV projection, plus the LoRA delta bound to `cache_id` if any;
    /// 3. the `K`-specific split / QK-norm / RoPE / KV-cache write;
    /// 4. attention into `scratch.attn_head_major_out`;
    /// 5. `forward_layer_post_attn` (O-proj, MLP, residual adds).
    ///
    /// The paged path (`block_size > 0`) and the contiguous path differ only in
    /// the cache-write and attention calls.
    ///
    /// `pos_offset` is the absolute position of token 0 in this batch. It is
    /// passed to the RoPE/cache-write kernels and to contiguous attention.
    /// `forward_verify` and `prefill_internal` pass the first layer's current
    /// cache length, so it is non-zero for chunked prefill and verify, not
    /// only for decode. `tokens` is the number of positions in the batch.
    ///
    /// # Panics
    /// - if `ensure_kv(cache_id)` was not called first;
    /// - if the write would exceed the layer's KV capacity;
    /// - if the cache is paged but `paged_pools` is unset;
    /// - if a `K::*` cache/attention op or a LoRA application returns an error.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn forward_layer(
        &mut self,
        ctx: &mut B::Context,
        li: usize,
        cache_id: &str,
        residual: &mut B::Buffer,
        pos_offset: usize,
        tokens: usize,
    ) {
        let source_li = self.source_layer_index(li);
        let layer = &self.layers[li];
        let cfg = &self.cfg;
        let h = cfg.hidden_size;
        let nh = cfg.num_heads;
        let nkv = cfg.num_kv_heads;
        let hd = cfg.head_dim;
        // NOTE(review): `im` is not read anywhere in this function.
        let im = cfg.intermediate_size;
        let eps = cfg.rms_norm_eps;
        let q_dim = nh * hd;
        let kv_dim = nkv * hd;

        // 1. Input RMSNorm
        //
        // Op timing (here and below) syncs before and after the op so the
        // interval covers only that op, and is gated on `decode_op_profile`.
        // NOTE(review): `prefill_internal` reports these same counters under
        // `prefill_op_profile`, so its per-op buckets stay empty unless
        // `decode_op_profile` is also set — confirm that is intended.
        let _t0 = if self.runtime_env.decode_op_profile {
            B::sync(ctx);
            Some(std::time::Instant::now())
        } else {
            None
        };
        B::rms_norm(
            ctx,
            residual,
            &layer.input_ln_w,
            eps,
            &mut self.scratch.norm_out,
            tokens,
            h,
        );
        if let Some(t0) = _t0 {
            B::sync(ctx);
            NORM_TIME_US.fetch_add(
                t0.elapsed().as_micros() as u64,
                std::sync::atomic::Ordering::Relaxed,
            );
            NORM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }

        // 2. Fused QKV projection (Linear dispatches to Dense/GPTQ/AWQ/GGUF)
        //    The MATMUL timer below also covers the LoRA delta.
        let _t0 = if self.runtime_env.decode_op_profile {
            B::sync(ctx);
            Some(std::time::Instant::now())
        } else {
            None
        };
        layer.qkv_proj.forward(
            ctx,
            &self.scratch.norm_out,
            &mut self.scratch.qkv_out,
            tokens,
        );
        // Apply the LoRA delta of the adapter bound to this cache, if any.
        // The raw pointer avoids keeping a whole-`self` borrow alive across
        // the `&mut self.scratch.qkv_out` argument and the counter update.
        if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
            // SAFETY: `lora_adapters` is not mutated during a forward pass;
            // the pointer is used only for this immediate read-only call.
            let applied = unsafe { &*adapter }
                .apply_projection(
                    ctx,
                    source_li,
                    "qkv_proj",
                    &self.scratch.norm_out,
                    &mut self.scratch.qkv_out,
                    tokens,
                )
                .expect("validated LoRA qkv_proj");
            self.lora_projection_applications += applied as u64;
        }
        if let Some(t0) = _t0 {
            B::sync(ctx);
            MATMUL_TIME_US.fetch_add(
                t0.elapsed().as_micros() as u64,
                std::sync::atomic::Ordering::Relaxed,
            );
            MATMUL_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }

        // 3-5. Fused split-QKV + QK-norm + RoPE + cache-write.
        //
        // Single Metal dispatch replaces the (split_qkv → 3× qk_norm_rope
        // → kv_cache_append_head_major) five-launch chain on the decode
        // hot path. Reads qkv_out once, writes Q to head-major scratch
        // and K/V straight into the pre-allocated KV cache slot at
        // `cache_len + tok`. Saves 4 dispatches per layer when the
        // backend implements the fused kernel; CPU and other backends
        // keep using the unfused chain via the Unsupported fallbacks.
        //
        // qk_mode: 1 = norm + half-split RoPE (Qwen3/Qwen HF);
        //          2 = half-split RoPE only;
        //          3 = interleaved RoPE only (GGUF LLaMA / llama.cpp layout).
        // V always passes apply_norm=0.
        // NOTE(review): when `has_qk_norm` is set, `rope_interleaved` is not
        // consulted (mode 1 is half-split) — confirm no model needs both.
        let qk_mode: i32 = if cfg.has_qk_norm {
            1
        } else if cfg.rope_interleaved {
            3
        } else {
            2
        };
        // Placeholder norm weights for models without QK-norm (`q_norm_w` /
        // `k_norm_w` are `None`). Presumably ignored by the kernels unless
        // qk_mode == 1 — TODO confirm per backend.
        let dummy = &layer.input_ln_w;
        let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
        let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);

        // Grab the per-layer KV cache up front so the deepest fusion can
        // write K/V straight into it.
        //
        // Paged mode: also need this layer's shared pool buffers
        // (self.paged_pools[li]). The pool is a separate field from
        // kv_caches, so we take a raw pointer to its (k, v) here while
        // we still hold &mut self, then deref via unsafe inside the
        // paged dispatch below. Safety: paged_pools is allocated once
        // and never resized; we don't touch self.paged_pools while the
        // pointer is in use.
        let paged_pool_ptr: Option<(*mut B::Buffer, *mut B::Buffer)> =
            if let Some(pools) = self.paged_pools.as_mut() {
                let pool = &mut pools[li];
                Some((&mut pool.0 as *mut _, &mut pool.1 as *mut _))
            } else {
                None
            };
        let caches = self
            .kv_caches
            .get_mut(cache_id)
            .expect("ensure_kv must be called before forward_layer");
        // Read shared metadata (variant-agnostic) once. The K-aware
        // attn section below re-borrows the right enum variant.
        let cache_len_before = K::len(&caches[li]);
        let cache_capacity = K::capacity(&caches[li]);
        let cache_block_size = K::block_size(&caches[li]);

        // Defense in depth: refuse to write past the KV buffer. The
        // graceful path is the caller pre-checking via `kv_capacity()`
        // and either compacting or refusing the request; this panic only
        // fires when that contract is broken (and silent overflow would
        // otherwise corrupt the cache + adjacent allocations).
        if cache_len_before + tokens > cache_capacity {
            panic!(
                "KV cache overflow on source layer {source_li} (local layer {li}): would write tokens [{cache_len_before}..{}) but capacity is {cache_capacity} (cache_id={cache_id:?}). Increase FERRUM_KV_CAPACITY or call /clear in the REPL.",
                cache_len_before + tokens
            );
        }

        // Paged path: K::paged_write fuses split_qkv_norm_rope + cache append
        // (FP16: into_paged_cache; INT8: split_qkv_norm_rope + int8_kv_append_paged).
        // K::paged_decode_attention reads from layer-local INT8 buffers or the
        // shared FP16 pool depending on K. Then K-agnostic post-attn tail.
        // NOTE(review): this path records no QKR/ATTN timings and passes no
        // sliding-window setting to attention (only the contig path forwards
        // `cfg.sliding_window`).
        if cache_block_size > 0 {
            let (pool_k_ptr, pool_v_ptr) =
                paged_pool_ptr.expect("paged_pools must be allocated when block_size > 0");
            // SAFETY: paged_pools is allocated once and never resized; the
            // raw pointers don't outlive this method scope.
            let pool_k = unsafe { &mut *pool_k_ptr };
            let pool_v = unsafe { &mut *pool_v_ptr };

            K::paged_write(
                ctx,
                &mut caches[li],
                &self.scratch.qkv_out,
                q_norm_w,
                k_norm_w,
                &self.rope.cos,
                &self.rope.sin,
                &mut self.scratch.q_head_major,
                &mut self.scratch.k_head_major,
                &mut self.scratch.v_head_major,
                pool_k,
                pool_v,
                tokens,
                nh,
                nkv,
                hd,
                pos_offset,
                eps,
                qk_mode,
            )
            .expect("K::paged_write");

            // The logical length covers the new tokens before attention runs,
            // so attention sees `new_len` positions.
            let new_len = cache_len_before + tokens;
            K::set_len(&mut caches[li], new_len);

            // SAFETY: same pointers as above; the `&mut` views are not used
            // after `paged_write`, so shared views are re-derived here for the
            // read-only attention call.
            let pool_k_imm = unsafe { &*pool_k_ptr };
            let pool_v_imm = unsafe { &*pool_v_ptr };
            K::paged_decode_attention(
                ctx,
                &mut caches[li],
                &self.scratch.q_head_major,
                pool_k_imm,
                pool_v_imm,
                &mut self.scratch.attn_head_major_out,
                nh,
                nkv,
                hd,
                new_len,
                tokens,
            )
            .expect("K::paged_decode_attention");

            return self.forward_layer_post_attn(ctx, li, cache_id, residual, tokens);
        }

        // Non-paged (contig) path. INT8 path doesn't reach here:
        // KvInt8::alloc_contig panics in ensure_kv.
        let _qkr_t0 = if self.runtime_env.decode_op_profile {
            B::sync(ctx);
            Some(std::time::Instant::now())
        } else {
            None
        };
        K::contig_write(
            ctx,
            &mut caches[li],
            &self.scratch.qkv_out,
            q_norm_w,
            k_norm_w,
            &self.rope.cos,
            &self.rope.sin,
            &mut self.scratch.q_head_major,
            &mut self.scratch.k_head_major,
            &mut self.scratch.v_head_major,
            &mut self.scratch.q_buf,
            &mut self.scratch.k_buf,
            &mut self.scratch.v_buf,
            tokens,
            nh,
            nkv,
            hd,
            pos_offset,
            eps,
            qk_mode,
        )
        .expect("K::contig_write");
        if let Some(t0) = _qkr_t0 {
            B::sync(ctx);
            QKR_TIME_US.fetch_add(
                t0.elapsed().as_micros() as u64,
                std::sync::atomic::Ordering::Relaxed,
            );
            QKR_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }
        let new_len = cache_len_before + tokens;
        K::set_len(&mut caches[li], new_len);
        // The layer's capacity is used as the KV sequence stride, presumably
        // the per-head slot count of the contiguous buffer — TODO confirm.
        let kv_stride = cache_capacity;

        let _attn_t0 = if self.runtime_env.decode_op_profile {
            B::sync(ctx);
            Some(std::time::Instant::now())
        } else {
            None
        };
        // Causal attention with the usual 1/sqrt(head_dim) score scaling.
        let attn_cfg = ferrum_kernels::backend::AttnConfig {
            num_heads: nh,
            num_kv_heads: nkv,
            head_dim: hd,
            causal: true,
            scale: 1.0 / (hd as f32).sqrt(),
            kv_seq_stride: kv_stride,
            sliding_window: cfg.sliding_window,
        };
        K::contig_decode_attention(
            ctx,
            &caches[li],
            &self.scratch.q_head_major,
            &mut self.scratch.attn_head_major_out,
            attn_cfg,
            tokens,
            pos_offset,
        )
        .expect("K::contig_decode_attention");
        // These discards only silence unused-variable lints; `q_dim` and
        // `kv_dim` are not read anywhere else in this function.
        let _ = q_dim;
        let _ = kv_dim;
        let _ = dummy;
        if let Some(t0) = _attn_t0 {
            B::sync(ctx);
            ATTN_TIME_US.fetch_add(
                t0.elapsed().as_micros() as u64,
                std::sync::atomic::Ordering::Relaxed,
            );
            ATTN_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }

        self.forward_layer_post_attn(ctx, li, cache_id, residual, tokens);
    }
1832
    /// Post-attention tail of `forward_layer`. In order:
    /// 1. untranspose the attention output to token-major (multi-token only);
    /// 2. O-proj;
    /// 3. fused residual-add + post-attention RMSNorm;
    /// 4. gate_up_proj;
    /// 5. SwiGLU;
    /// 6. down_proj;
    /// 7. final residual add into `residual`.
    ///
    /// K-agnostic: it reads `self.scratch.attn_head_major_out`, which both the
    /// paged and the contiguous attention paths of `forward_layer` write.
    /// Each projection is followed by the LoRA delta of the adapter bound to
    /// `cache_id`, if any. `lora_projection_applications` is incremented by
    /// the value each `apply_projection` call returns.
    ///
    /// NOTE(review): none of these ops are timed, so profiling output that
    /// relies on `MATMUL_TIME_US` only reflects the QKV projection.
    ///
    /// # Panics
    /// If a LoRA `apply_projection` call returns an error.
    pub(crate) fn forward_layer_post_attn(
        &mut self,
        ctx: &mut B::Context,
        li: usize,
        cache_id: &str,
        residual: &mut B::Buffer,
        tokens: usize,
    ) {
        let source_li = self.source_layer_index(li);
        let layer = &self.layers[li];
        let cfg = &self.cfg;
        let h = cfg.hidden_size;
        let nh = cfg.num_heads;
        let hd = cfg.head_dim;
        let im = cfg.intermediate_size;
        let eps = cfg.rms_norm_eps;

        // 7. Untranspose head-major → token-major for O-proj input.
        //    A single token is used as-is, presumably because the two
        //    layouts coincide when there is only one position.
        let attn_token_major = if tokens == 1 {
            &self.scratch.attn_head_major_out
        } else {
            B::transpose_head_to_token(
                ctx,
                &self.scratch.attn_head_major_out,
                &mut self.scratch.attn_flat,
                tokens,
                nh,
                hd,
            );
            &self.scratch.attn_flat
        };

        // 8. O projection.
        layer
            .o_proj
            .forward(ctx, attn_token_major, &mut self.scratch.o_proj_out, tokens);
        if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
            // SAFETY: see qkv_proj LoRA application above.
            let applied = unsafe { &*adapter }
                .apply_projection(
                    ctx,
                    source_li,
                    "o_proj",
                    attn_token_major,
                    &mut self.scratch.o_proj_out,
                    tokens,
                )
                .expect("validated LoRA o_proj");
            self.lora_projection_applications += applied as u64;
        }

        // 9. Fused residual-add + post-attention RMSNorm.
        B::fused_add_rms_norm(
            ctx,
            residual,
            &self.scratch.o_proj_out,
            &layer.post_ln_w,
            eps,
            &mut self.scratch.norm_out,
            tokens,
            h,
        );

        // 10. Fused gate+up projection.
        layer.gate_up_proj.forward(
            ctx,
            &self.scratch.norm_out,
            &mut self.scratch.gate_up_out,
            tokens,
        );
        if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
            // SAFETY: see qkv_proj LoRA application above.
            let applied = unsafe { &*adapter }
                .apply_projection(
                    ctx,
                    source_li,
                    "gate_up_proj",
                    &self.scratch.norm_out,
                    &mut self.scratch.gate_up_out,
                    tokens,
                )
                .expect("validated LoRA gate_up_proj");
            self.lora_projection_applications += applied as u64;
        }

        // 11. SwiGLU: silu(gate) * up.
        B::fused_silu_mul_split(
            ctx,
            &self.scratch.gate_up_out,
            &mut self.scratch.silu_out,
            tokens,
            im,
        );

        // 12. Down projection.
        layer.down_proj.forward(
            ctx,
            &self.scratch.silu_out,
            &mut self.scratch.mlp_out,
            tokens,
        );
        if let Some(adapter) = self.active_lora_adapter_ptr_for_cache(cache_id) {
            // SAFETY: see qkv_proj LoRA application above.
            let applied = unsafe { &*adapter }
                .apply_projection(
                    ctx,
                    source_li,
                    "down_proj",
                    &self.scratch.silu_out,
                    &mut self.scratch.mlp_out,
                    tokens,
                )
                .expect("validated LoRA down_proj");
            self.lora_projection_applications += applied as u64;
        }

        // 13. Final residual add.
        B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
    }
1956
1957    /// Multi-position decode-verify: run one forward pass over `tokens`
1958    /// starting at the cache's current end position, write their K/V
1959    /// into the KV cache, and return logits for ALL `tokens.len()`
1960    /// positions as a flat `Vec<f32>` of length `seq_len * vocab_size`.
1961    ///
1962    /// Used by speculative decoding: target receives
1963    /// `[last_token, draft_0, ..., draft_{N-1}]` (N+1 inputs) and produces
1964    /// N+1 logit rows in a single forward instead of N+1 sequential
1965    /// decode() calls. Positions are implicit — the model looks up
1966    /// `pos_offset = cache.len` the same way prefill_internal does, so
1967    /// chunked prefill semantics carry over for free.
1968    pub fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1969        let seq_len = tokens.len();
1970        assert!(seq_len > 0, "forward_verify called with empty tokens");
1971        self.ensure_scratch(seq_len);
1972        self.ensure_kv(cache_id);
1973
1974        let h = self.cfg.hidden_size;
1975        let vocab = self.cfg.vocab_size;
1976
1977        let pos_offset = self
1978            .kv_caches
1979            .get(cache_id)
1980            .and_then(|layers| layers.first())
1981            .map(|c| K::len(c))
1982            .unwrap_or(0);
1983
1984        let mut ctx = B::new_context();
1985        let mut residual = self
1986            .scratch
1987            .residual
1988            .take()
1989            .expect("scratch residual missing (previous call didn't restore)");
1990
1991        let embed = self
1992            .embed
1993            .as_ref()
1994            .expect("forward_verify called on backbone-only model (no embed)");
1995        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
1996
1997        for li in self.local_layer_indices() {
1998            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1999        }
2000
2001        // RMSNorm on ALL seq_len positions (prefill_internal only norms
2002        // the last one; verify needs the full grid).
2003        B::rms_norm(
2004            &mut ctx,
2005            &residual,
2006            &self.final_norm_w,
2007            self.cfg.rms_norm_eps,
2008            &mut self.scratch.norm_out,
2009            seq_len,
2010            h,
2011        );
2012
2013        // LM head applied to all positions → `seq_len * vocab` logits.
2014        // Reuses the existing `batch_logits` scratch (sized max_tokens *
2015        // vocab) so no extra allocation.
2016        let lm_head = self
2017            .lm_head
2018            .as_ref()
2019            .expect("forward_verify called on backbone-only model (no lm_head)");
2020        lm_head.forward(
2021            &mut ctx,
2022            &self.scratch.norm_out,
2023            &mut self.scratch.batch_logits,
2024            seq_len,
2025        );
2026
2027        B::sync(&mut ctx);
2028        self.scratch.residual = Some(residual);
2029        B::to_vec(&self.scratch.batch_logits, seq_len * vocab)
2030    }
2031
2032    /// Prefill: process `tokens` prompt tokens in a single batch, return
2033    /// `[vocab_size]` logits for the last position.
2034    ///
2035    /// Supports incremental prefill: if the KV cache for `cache_id` already
2036    /// contains earlier tokens, the new chunk's positions are computed as
2037    /// `[kv_len, kv_len + tokens.len())` so RoPE and causal masking stay
2038    /// aligned. Used by the engine's chunked-prefill path.
2039    pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2040        assert!(!tokens.is_empty(), "prefill called with empty token list");
2041        self.ensure_kv(cache_id);
2042
2043        let cache_len_before = self
2044            .kv_caches
2045            .get(cache_id)
2046            .and_then(|layers| layers.first())
2047            .map(K::len)
2048            .unwrap_or(0);
2049        let mut cached_prefix_tokens = if self.runtime_env.prefix_cache && cache_len_before == 0 {
2050            self.try_acquire_prefix_cache(cache_id, tokens)
2051        } else {
2052            0
2053        };
2054        if cached_prefix_tokens >= tokens.len() {
2055            let block_size = self
2056                .kv_caches
2057                .get(cache_id)
2058                .and_then(|layers| layers.first())
2059                .map(K::block_size)
2060                .unwrap_or(16);
2061            cached_prefix_tokens = cached_prefix_tokens
2062                .saturating_sub(block_size)
2063                .min(tokens.len() - 1);
2064        }
2065        if self.runtime_env.prefix_cache && cache_len_before == 0 {
2066            self.record_prefix_cache_probe(cached_prefix_tokens);
2067        }
2068
2069        if cached_prefix_tokens > 0 {
2070            let caches_mut = self.kv_caches.get_mut(cache_id).expect("cache present");
2071            let mut ctx_tmp = B::new_context();
2072            for cache in caches_mut.iter_mut() {
2073                if K::len(cache) != cached_prefix_tokens {
2074                    K::set_len(cache, cached_prefix_tokens);
2075                    if let Some(context_lens) = K::context_lens_mut(cache) {
2076                        B::write_typed::<u32>(
2077                            &mut ctx_tmp,
2078                            context_lens,
2079                            &[cached_prefix_tokens as u32],
2080                        );
2081                    }
2082                }
2083            }
2084            B::sync(&mut ctx_tmp);
2085        }
2086
2087        let suffix_tokens = &tokens[cached_prefix_tokens..];
2088        let seq_len = suffix_tokens.len();
2089        assert!(
2090            seq_len > 0,
2091            "prefix cache must leave at least one suffix token"
2092        );
2093        self.ensure_scratch(seq_len);
2094
2095        // Starting position for this chunk — 0 for a fresh prefill, kv_len
2096        // for the second+ chunk of a split prefill.
2097        let pos_offset = self
2098            .kv_caches
2099            .get(cache_id)
2100            .and_then(|layers| layers.first())
2101            .map(|c| K::len(c))
2102            .unwrap_or(0);
2103
2104        let h = self.cfg.hidden_size;
2105        let vocab = self.cfg.vocab_size;
2106        let mut ctx = B::new_context();
2107
2108        // Move `residual` out of `scratch` to work around the borrow checker:
2109        // `forward_layer` re-borrows `&mut self` to reach `self.layers` /
2110        // `self.kv_caches`, which would conflict with an outstanding
2111        // `&mut self.scratch.residual`. Use Option::take to move it out
2112        // (no placeholder alloc → no transient cuMemFreeAsync that could
2113        // corrupt stream pool state after graph ops on Blackwell).
2114        let mut residual = self
2115            .scratch
2116            .residual
2117            .take()
2118            .expect("scratch residual missing (previous call didn't restore)");
2119        let embed = self
2120            .embed
2121            .as_ref()
2122            .expect("prefill_internal called on backbone-only model (no embed)");
2123        B::embedding_lookup(&mut ctx, embed, suffix_tokens, &mut residual, h);
2124
2125        let prefill_profile = self.runtime_env.prefill_op_profile;
2126        let prefill_t0 = if prefill_profile {
2127            B::sync(&mut ctx);
2128            Some(std::time::Instant::now())
2129        } else {
2130            None
2131        };
2132
2133        for li in self.local_layer_indices() {
2134            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2135        }
2136
2137        if let Some(t0) = prefill_t0 {
2138            B::sync(&mut ctx);
2139            let total_us = t0.elapsed().as_micros() as u64;
2140            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2141            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2142            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2143            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2144            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2145            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2146            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2147            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2148            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2149            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2150            eprintln!(
2151                "[prefill-profile] tokens={} layers total={} ms",
2152                seq_len,
2153                total_us / 1000
2154            );
2155            let bucket = |label: &str, n: u64, us: u64| {
2156                if n > 0 {
2157                    eprintln!(
2158                        "[prefill-profile] {label}: {} calls {} ms (avg {} us)",
2159                        n,
2160                        us / 1000,
2161                        us / n
2162                    );
2163                }
2164            };
2165            bucket("flash_attn", attn_n, attn_us);
2166            bucket("qk_norm_rope", qkr_n, qkr_us);
2167            bucket("matmuls", mm_n, mm_us);
2168            bucket("norms", norm_n, norm_us);
2169            bucket("other", other_n, other_us);
2170        }
2171
2172        // Take the last token's hidden state: residual[(seq_len-1)*h .. seq_len*h]
2173        B::copy_slice(
2174            &mut ctx,
2175            &residual,
2176            (seq_len - 1) * h,
2177            &mut self.scratch.last_hidden,
2178            0,
2179            h,
2180        );
2181
2182        // Final RMSNorm on the last hidden.
2183        B::rms_norm(
2184            &mut ctx,
2185            &self.scratch.last_hidden,
2186            &self.final_norm_w,
2187            self.cfg.rms_norm_eps,
2188            &mut self.scratch.last_normed,
2189            1,
2190            h,
2191        );
2192
2193        // LM head (m=1 — triggers GEMV on MetalBackend).
2194        let lm_head = self
2195            .lm_head
2196            .as_ref()
2197            .expect("prefill_internal called on backbone-only model (no lm_head)");
2198        lm_head.forward(
2199            &mut ctx,
2200            &self.scratch.last_normed,
2201            &mut self.scratch.logits,
2202            1,
2203        );
2204
2205        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
2206        // buffer's CPU pointer without flushing the command buffer, so the
2207        // GPU must complete all pending work first or we read stale/random
2208        // data. CUDA's to_vec does an internal stream.synchronize, making
2209        // the call redundant there (~50µs/step cost), but correctness on
2210        // Metal requires the explicit flush here.
2211        B::sync(&mut ctx);
2212
2213        // Restore residual into scratch for reuse on the next call.
2214        self.scratch.residual = Some(residual);
2215        if self.runtime_env.prefix_cache && cache_len_before == 0 {
2216            self.register_prefix_cache(cache_id, tokens, cached_prefix_tokens);
2217        }
2218
2219        B::to_vec(&self.scratch.logits, vocab)
2220    }
2221
2222    /// Decode: process 1 token at position `pos`, return `[vocab_size]` logits.
2223    pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2224        self.ensure_scratch(1);
2225        self.ensure_kv(cache_id);
2226
2227        let h = self.cfg.hidden_size;
2228        let vocab = self.cfg.vocab_size;
2229
2230        // Context creation is cheap (CUDA reuses the process-global stream).
2231        // The captured graph lives in a process-global slot, not on ctx.
2232        let mut ctx = B::new_context();
2233
2234        // Graph capture is opt-in via FERRUM_CUDA_GRAPH=1. Replay is currently
2235        // single-request-only on Blackwell + CUDA 12.8 (see
2236        // docs/phase-e-cuda-status.md). In pure eager mode, we skip the
2237        // per-step device-state memcpy_htod trio entirely.
2238        const GRAPH_WARMUP: usize = 3;
2239        let graph_enabled = self.runtime_env.cuda_graph;
2240
2241        if graph_enabled {
2242            // Refresh device-side dynamic state (token/pos/kv_len) before
2243            // replay — captured graph reads these from device buffers.
2244            B::set_decode_state(&mut ctx, token, pos);
2245
2246            // Fast path: graph replay (if available). Single-item path
2247            // uses key=SINGLE_ITEM_GRAPH_KEY (0) — separate from the
2248            // batched path's m_padded keys.
2249            match B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY) {
2250                Ok(true) => {
2251                    B::sync(&mut ctx);
2252                    return B::to_vec(&self.scratch.logits, vocab);
2253                }
2254                Ok(false) => { /* no graph yet, fall through to eager */ }
2255                Err(_) => { /* backend error or unsupported, eager */ }
2256            }
2257        }
2258
2259        let should_capture =
2260            graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
2261
2262        if should_capture {
2263            B::set_dev_state_mode(&mut ctx, true);
2264            if B::begin_graph_capture(&mut ctx).is_err() {
2265                self.graph_capture_failed = true;
2266                B::set_dev_state_mode(&mut ctx, false);
2267            }
2268        }
2269
2270        // Eager forward (records into graph if capture is active).
2271        // mem::replace needs a placeholder. B::alloc(0) was our choice but
2272        // cuMemAllocFromPoolAsync(stream, 0) can return CUDA_ERROR_INVALID_VALUE
2273        // on Blackwell after graph replay corrupts the pool state. Size-1 is
2274        // always valid and costs 2 bytes of transient VRAM per decode step.
2275        let mut residual = self
2276            .scratch
2277            .residual
2278            .take()
2279            .expect("scratch residual missing (previous call didn't restore)");
2280        let embed = self
2281            .embed
2282            .as_ref()
2283            .expect("decode_internal called on backbone-only model (no embed)");
2284        B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
2285
2286        // Per-layer wall-time profile (env-gated, off by default — adds
2287        // a B::sync between layers which serializes the pipeline). Helps
2288        // localise non-matmul bottlenecks during perf work.
2289        let layer_profile = self.runtime_env.decode_layer_profile;
2290        let mut layer_times = if layer_profile {
2291            Some(Vec::with_capacity(self.local_layer_count()))
2292        } else {
2293            None
2294        };
2295
2296        for li in self.local_layer_indices() {
2297            if layer_profile {
2298                let t0 = std::time::Instant::now();
2299                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2300                B::sync(&mut ctx);
2301                let elapsed_us = t0.elapsed().as_micros() as u64;
2302                if let Some(v) = layer_times.as_mut() {
2303                    v.push(elapsed_us);
2304                }
2305            } else {
2306                self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2307            }
2308        }
2309        if let Some(times) = layer_times.take() {
2310            let sum: u64 = times.iter().sum();
2311            let avg = sum / times.len() as u64;
2312            let mn = *times.iter().min().unwrap_or(&0);
2313            let mx = *times.iter().max().unwrap_or(&0);
2314            eprintln!(
2315                "[layer-profile] {} layers total={} ms avg={} us min={} us max={} us",
2316                times.len(),
2317                sum / 1000,
2318                avg,
2319                mn,
2320                mx
2321            );
2322            for (i, t) in times.iter().enumerate() {
2323                eprint!("L{i}={}ms ", t / 1000);
2324                if (i + 1) % 6 == 0 {
2325                    eprintln!();
2326                }
2327            }
2328            eprintln!();
2329            let attn_us = ATTN_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2330            let attn_n = ATTN_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2331            let qkr_us = QKR_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2332            let qkr_n = QKR_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2333            let mm_us = MATMUL_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2334            let mm_n = MATMUL_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2335            let norm_us = NORM_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2336            let norm_n = NORM_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2337            let other_us = OTHER_TIME_US.swap(0, std::sync::atomic::Ordering::Relaxed);
2338            let other_n = OTHER_CALLS.swap(0, std::sync::atomic::Ordering::Relaxed);
2339            eprintln!(
2340                "[op-profile] flash_attn: {} calls {} ms (avg {} us)",
2341                attn_n,
2342                attn_us / 1000,
2343                if attn_n > 0 { attn_us / attn_n } else { 0 }
2344            );
2345            eprintln!(
2346                "[op-profile] qk_norm_rope: {} calls {} ms (avg {} us)",
2347                qkr_n,
2348                qkr_us / 1000,
2349                if qkr_n > 0 { qkr_us / qkr_n } else { 0 }
2350            );
2351            eprintln!(
2352                "[op-profile] matmuls (Linear::forward): {} calls {} ms (avg {} us)",
2353                mm_n,
2354                mm_us / 1000,
2355                if mm_n > 0 { mm_us / mm_n } else { 0 }
2356            );
2357            eprintln!(
2358                "[op-profile] norms (rms+fused_add_rms): {} calls {} ms (avg {} us)",
2359                norm_n,
2360                norm_us / 1000,
2361                if norm_n > 0 { norm_us / norm_n } else { 0 }
2362            );
2363            eprintln!(
2364                "[op-profile] other (split_qkv, kv_append, transpose, silu, add): {} calls {} ms (avg {} us)",
2365                other_n, other_us / 1000, if other_n > 0 { other_us / other_n } else { 0 }
2366            );
2367        }
2368
2369        B::rms_norm(
2370            &mut ctx,
2371            &residual,
2372            &self.final_norm_w,
2373            self.cfg.rms_norm_eps,
2374            &mut self.scratch.last_normed,
2375            1,
2376            h,
2377        );
2378
2379        let lm_head = self
2380            .lm_head
2381            .as_ref()
2382            .expect("decode_internal called on backbone-only model (no lm_head)");
2383        lm_head.forward(
2384            &mut ctx,
2385            &self.scratch.last_normed,
2386            &mut self.scratch.logits,
2387            1,
2388        );
2389
2390        if should_capture && !self.graph_capture_failed {
2391            if B::end_graph_capture(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
2392                self.graph_capture_failed = true;
2393            } else {
2394                // Stream capture mode RECORDS ops into the graph without
2395                // executing them. scratch.logits still holds the previous
2396                // step's value. Replay the just-captured graph once to
2397                // actually execute and produce this step's logits. Without
2398                // this, the capture step's to_vec returns stale logits,
2399                // yielding a 1-token offset in the generated sequence.
2400                if B::replay_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY).is_err() {
2401                    self.graph_capture_failed = true;
2402                }
2403            }
2404            B::set_dev_state_mode(&mut ctx, false);
2405        } else {
2406            self.graph_warmup += 1;
2407        }
2408
2409        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
2410        // buffer's CPU pointer without flushing the command buffer, so the
2411        // GPU must complete all pending work first or we read stale/random
2412        // data. CUDA's to_vec does an internal stream.synchronize, making
2413        // the call redundant there (~50µs/step cost), but correctness on
2414        // Metal requires the explicit flush here.
2415        B::sync(&mut ctx);
2416        self.scratch.residual = Some(residual);
2417
2418        B::to_vec(&self.scratch.logits, vocab)
2419    }
2420
2421    fn stage_tokens_to_hidden(
2422        &mut self,
2423        cache_id: &str,
2424        tokens: &[u32],
2425        pos_offset: usize,
2426    ) -> Vec<f32> {
2427        let seq_len = tokens.len();
2428        assert!(seq_len > 0, "stage token forward called with zero length");
2429        self.ensure_scratch(seq_len);
2430        self.ensure_kv(cache_id);
2431
2432        let h = self.cfg.hidden_size;
2433        let mut ctx = B::new_context();
2434        let mut residual = self
2435            .scratch
2436            .residual
2437            .take()
2438            .expect("scratch residual missing (previous call didn't restore)");
2439        let embed = self
2440            .embed
2441            .as_ref()
2442            .expect("stage token forward called on stage without embedding");
2443        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
2444
2445        for li in self.local_layer_indices() {
2446            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2447        }
2448
2449        B::sync(&mut ctx);
2450        let out = B::to_vec(&residual, seq_len * h);
2451        self.scratch.residual = Some(residual);
2452        out
2453    }
2454
2455    /// Run the first layer-split stage over prompt tokens and return all
2456    /// output hidden states for the next stage.
2457    pub fn prefill_stage_tokens_to_hidden(
2458        &mut self,
2459        cache_id: &str,
2460        tokens: &[u32],
2461        pos_offset: usize,
2462    ) -> Vec<f32> {
2463        self.stage_tokens_to_hidden(cache_id, tokens, pos_offset)
2464    }
2465
2466    /// Decode-side first-stage token forward. Returns one hidden row for the
2467    /// next stage.
2468    pub fn decode_stage_token_to_hidden(
2469        &mut self,
2470        cache_id: &str,
2471        token: u32,
2472        pos: u32,
2473    ) -> Vec<f32> {
2474        self.stage_tokens_to_hidden(cache_id, &[token], pos as usize)
2475    }
2476
2477    fn stage_hidden_from_host(
2478        &mut self,
2479        cache_id: &str,
2480        hidden: &[f32],
2481        seq_len: usize,
2482        pos_offset: usize,
2483    ) -> Vec<f32> {
2484        self.stage_hidden_from_host_with_timing(cache_id, hidden, seq_len, pos_offset)
2485            .0
2486    }
2487
2488    fn stage_hidden_from_host_with_timing(
2489        &mut self,
2490        cache_id: &str,
2491        hidden: &[f32],
2492        seq_len: usize,
2493        pos_offset: usize,
2494    ) -> (Vec<f32>, LlamaStageHiddenBridgeTiming) {
2495        let h = self.cfg.hidden_size;
2496        assert_eq!(
2497            hidden.len(),
2498            seq_len * h,
2499            "hidden length {} != seq_len * hidden_size {}",
2500            hidden.len(),
2501            seq_len * h
2502        );
2503        assert!(seq_len > 0, "stage hidden forward called with zero length");
2504
2505        self.ensure_scratch(seq_len);
2506        self.ensure_kv(cache_id);
2507
2508        let mut ctx = B::new_context();
2509        let mut residual = self
2510            .scratch
2511            .residual
2512            .take()
2513            .expect("scratch residual missing (previous call didn't restore)");
2514
2515        let bridge_t0 = std::time::Instant::now();
2516        let host_copy_t0 = std::time::Instant::now();
2517        let hidden_buf = B::from_slice(hidden);
2518        let host_copy_us = elapsed_micros_u64_floor1(host_copy_t0);
2519        let device_copy_t0 = std::time::Instant::now();
2520        B::copy_slice(&mut ctx, &hidden_buf, 0, &mut residual, 0, seq_len * h);
2521        let device_copy_us = elapsed_micros_u64_floor1(device_copy_t0);
2522        let bridge_timing = LlamaStageHiddenBridgeTiming {
2523            bridge_us: elapsed_micros_u64_floor1(bridge_t0),
2524            host_copy_us,
2525            device_copy_us,
2526        };
2527
2528        for li in self.local_layer_indices() {
2529            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2530        }
2531
2532        B::sync(&mut ctx);
2533        let out = B::to_vec(&residual, seq_len * h);
2534        self.scratch.residual = Some(residual);
2535        (out, bridge_timing)
2536    }
2537
2538    /// Run this layer-split stage over prefill hidden states supplied by the
2539    /// previous stage. Returns all output hidden states in row-major
2540    /// `[seq_len, hidden_size]` form for the next stage.
2541    pub fn prefill_stage_hidden_from_host(
2542        &mut self,
2543        cache_id: &str,
2544        hidden: &[f32],
2545        seq_len: usize,
2546        pos_offset: usize,
2547    ) -> Vec<f32> {
2548        self.stage_hidden_from_host(cache_id, hidden, seq_len, pos_offset)
2549    }
2550
2551    /// Decode-side companion for one hidden state row supplied by the previous
2552    /// stage. Returns this stage's output hidden row.
2553    pub fn decode_stage_hidden_from_host(
2554        &mut self,
2555        cache_id: &str,
2556        hidden: &[f32],
2557        pos: u32,
2558    ) -> Vec<f32> {
2559        self.stage_hidden_from_host(cache_id, hidden, 1, pos as usize)
2560    }
2561
2562    pub(crate) fn decode_stage_hidden_from_host_with_timing(
2563        &mut self,
2564        cache_id: &str,
2565        hidden: &[f32],
2566        pos: u32,
2567    ) -> (Vec<f32>, LlamaStageHiddenBridgeTiming) {
2568        self.stage_hidden_from_host_with_timing(cache_id, hidden, 1, pos as usize)
2569    }
2570
2571    /// Apply final norm + lm_head to one hidden row. Used by the last pipeline
2572    /// stage after local transformer layers have produced the final hidden
2573    /// state for sampling.
2574    pub fn logits_from_hidden(&mut self, hidden: &[f32]) -> Vec<f32> {
2575        let h = self.cfg.hidden_size;
2576        let vocab = self.cfg.vocab_size;
2577        assert_eq!(
2578            hidden.len(),
2579            h,
2580            "hidden length {} != hidden_size {}",
2581            hidden.len(),
2582            h
2583        );
2584        self.ensure_scratch(1);
2585
2586        let mut ctx = B::new_context();
2587        let hidden_buf = B::from_slice(hidden);
2588        B::rms_norm(
2589            &mut ctx,
2590            &hidden_buf,
2591            &self.final_norm_w,
2592            self.cfg.rms_norm_eps,
2593            &mut self.scratch.last_normed,
2594            1,
2595            h,
2596        );
2597        let lm_head = self
2598            .lm_head
2599            .as_ref()
2600            .expect("logits_from_hidden called on stage without lm_head");
2601        lm_head.forward(
2602            &mut ctx,
2603            &self.scratch.last_normed,
2604            &mut self.scratch.logits,
2605            1,
2606        );
2607        B::sync(&mut ctx);
2608        B::to_vec(&self.scratch.logits, vocab)
2609    }
2610
2611    /// Prefill with pre-computed embeddings instead of token IDs.
2612    ///
2613    /// Used by models that embed inputs outside the LLM (e.g. Qwen3-TTS
2614    /// mixes text-embedding + codec-embedding before feeding the LM).
2615    /// Skips `final_norm` + `lm_head`; returns the last position's pre-norm
2616    /// hidden state. Caller applies its own output head.
2617    ///
2618    /// `embeds` is row-major `[seq_len * hidden_size]`, f32.
2619    pub fn prefill_from_embeds(
2620        &mut self,
2621        cache_id: &str,
2622        embeds: &[f32],
2623        seq_len: usize,
2624    ) -> Vec<f32> {
2625        let h = self.cfg.hidden_size;
2626        assert_eq!(
2627            embeds.len(),
2628            seq_len * h,
2629            "embeds length {} != seq_len * hidden_size {}",
2630            embeds.len(),
2631            seq_len * h
2632        );
2633        assert!(seq_len > 0, "prefill_from_embeds called with zero length");
2634
2635        self.ensure_scratch(seq_len);
2636        self.ensure_kv(cache_id);
2637
2638        let mut ctx = B::new_context();
2639        let mut residual = self
2640            .scratch
2641            .residual
2642            .take()
2643            .expect("scratch residual missing (previous call didn't restore)");
2644
2645        // Upload embeds → residual[0 .. seq_len*h].
2646        let embed_buf = B::from_slice(embeds);
2647        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
2648
2649        for li in self.local_layer_indices() {
2650            self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
2651        }
2652
2653        B::copy_slice(
2654            &mut ctx,
2655            &residual,
2656            (seq_len - 1) * h,
2657            &mut self.scratch.last_hidden,
2658            0,
2659            h,
2660        );
2661        B::sync(&mut ctx);
2662        self.scratch.residual = Some(residual);
2663        B::to_vec(&self.scratch.last_hidden, h)
2664    }
2665
2666    /// Decode with a single pre-computed embedding (shape `[hidden]`).
2667    /// Returns the pre-norm hidden state for the position `pos`. Caller
2668    /// applies final norm + its own output head.
2669    pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
2670        let h = self.cfg.hidden_size;
2671        assert_eq!(
2672            embed.len(),
2673            h,
2674            "embed length {} != hidden_size {}",
2675            embed.len(),
2676            h
2677        );
2678
2679        self.ensure_scratch(1);
2680        self.ensure_kv(cache_id);
2681
2682        let mut ctx = B::new_context();
2683        let mut residual = self
2684            .scratch
2685            .residual
2686            .take()
2687            .expect("scratch residual missing (previous call didn't restore)");
2688
2689        let embed_buf = B::from_slice(embed);
2690        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
2691
2692        for li in self.local_layer_indices() {
2693            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2694        }
2695
2696        B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
2697        B::sync(&mut ctx);
2698        self.scratch.residual = Some(residual);
2699        B::to_vec(&self.scratch.last_hidden, h)
2700    }
2701
2702    /// Variant of `prefill_from_embeds` that applies `final_norm` to every
2703    /// position and returns the whole `[seq_len * hidden_size]` vector.
2704    /// Accepts `pos_offset` so callers can continue an existing sequence
2705    /// (e.g. Qwen3-TTS voice-clone: one prefill for the role prefix, a
2706    /// follow-up prefill for the reference-audio ICL block, then
2707    /// autoregressive decoding — all against the same KV cache).
2708    ///
2709    /// Used by TTS where `forward_step` in the candle-based wrapper is
2710    /// expected to return **post-norm all-positions** hidden state so
2711    /// `codec_head` can be applied on candle side.
2712    pub fn prefill_all_post_norm(
2713        &mut self,
2714        cache_id: &str,
2715        embeds: &[f32],
2716        seq_len: usize,
2717        pos_offset: usize,
2718    ) -> Vec<f32> {
2719        let h = self.cfg.hidden_size;
2720        assert_eq!(
2721            embeds.len(),
2722            seq_len * h,
2723            "embeds length {} != seq_len * hidden_size {}",
2724            embeds.len(),
2725            seq_len * h
2726        );
2727        assert!(seq_len > 0);
2728
2729        self.ensure_scratch(seq_len);
2730        self.ensure_kv(cache_id);
2731
2732        let mut ctx = B::new_context();
2733        let mut residual = self
2734            .scratch
2735            .residual
2736            .take()
2737            .expect("scratch residual missing (previous call didn't restore)");
2738
2739        let embed_buf = B::from_slice(embeds);
2740        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
2741
2742        for li in self.local_layer_indices() {
2743            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
2744        }
2745
2746        // Apply final_norm over all seq_len positions → scratch.norm_out.
2747        B::rms_norm(
2748            &mut ctx,
2749            &residual,
2750            &self.final_norm_w,
2751            self.cfg.rms_norm_eps,
2752            &mut self.scratch.norm_out,
2753            seq_len,
2754            h,
2755        );
2756        B::sync(&mut ctx);
2757        self.scratch.residual = Some(residual);
2758        B::to_vec(&self.scratch.norm_out, seq_len * h)
2759    }
2760
2761    /// Decode-side companion to `prefill_all_post_norm`. Runs a single-token
2762    /// decode step at `pos`, applies `final_norm`, and returns the post-norm
2763    /// hidden state `[hidden_size]`.
2764    pub fn decode_post_norm_from_embed(
2765        &mut self,
2766        cache_id: &str,
2767        embed: &[f32],
2768        pos: u32,
2769    ) -> Vec<f32> {
2770        let h = self.cfg.hidden_size;
2771        assert_eq!(embed.len(), h);
2772
2773        self.ensure_scratch(1);
2774        self.ensure_kv(cache_id);
2775
2776        let mut ctx = B::new_context();
2777        let mut residual = self
2778            .scratch
2779            .residual
2780            .take()
2781            .expect("scratch residual missing (previous call didn't restore)");
2782
2783        let embed_buf = B::from_slice(embed);
2784        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
2785
2786        for li in self.local_layer_indices() {
2787            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
2788        }
2789
2790        B::rms_norm(
2791            &mut ctx,
2792            &residual,
2793            &self.final_norm_w,
2794            self.cfg.rms_norm_eps,
2795            &mut self.scratch.last_normed,
2796            1,
2797            h,
2798        );
2799        B::sync(&mut ctx);
2800        self.scratch.residual = Some(residual);
2801        B::to_vec(&self.scratch.last_normed, h)
2802    }
2803}
2804
2805// FP16 DecoderOnlyLLM impl — full path with batched + unified-forward overrides.
2806impl<B: MoeLlmBackend> DecoderOnlyLLM for LlamaFamilyModel<B, KvFp16> {
2807    fn config(&self) -> &LlmRuntimeConfig {
2808        &self.runtime_cfg
2809    }
2810
2811    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
2812        Some(self.prefix_cache_snapshot_json())
2813    }
2814
2815    fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
2816        Some(self.lora_metrics_snapshot_json())
2817    }
2818
2819    fn set_lora_adapter_for_cache(
2820        &mut self,
2821        cache_id: &str,
2822        adapter: Option<ActiveLoraAdapter>,
2823    ) -> std::result::Result<(), ferrum_types::FerrumError> {
2824        if let Some(adapter) = adapter {
2825            self.ensure_lora_adapter_loaded(adapter.clone())?;
2826            self.lora_cache_adapters
2827                .insert(cache_id.to_string(), adapter.name);
2828        } else {
2829            self.lora_cache_adapters.remove(cache_id);
2830        }
2831        Ok(())
2832    }
2833
2834    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
2835        self.ensure_scratch(max_tokens);
2836        self.ensure_kv(cache_id);
2837
2838        const WARMUP_CACHE: &str = "__ferrum_warmup__";
2839        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
2840        if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
2841            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2842                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2843                if let Some(c0) = caches.first() {
2844                    if !c0.paged_block_indices.is_empty() {
2845                        alloc.free(&c0.paged_block_indices);
2846                    }
2847                }
2848                for c in caches.iter_mut() {
2849                    c.paged_block_indices.clear();
2850                }
2851            }
2852            self.kv_free_pool.push(caches);
2853        }
2854    }
2855
2856    fn kv_capacity(&self) -> usize {
2857        self.runtime_env.kv_capacity_for_model(self.cfg.max_seq_len)
2858    }
2859
2860    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2861        self.prefill_internal(cache_id, tokens)
2862    }
2863
2864    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
2865        self.decode_internal(cache_id, token, pos)
2866    }
2867
2868    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
2869        self.decode_batch_with_full_logits(batch, false)
2870    }
2871
2872    fn decode_batch_with_full_logits(
2873        &mut self,
2874        batch: &[(String, u32, u32)],
2875        force_full_logits: bool,
2876    ) -> Vec<Vec<f32>> {
2877        self.decode_batch_internal_with_full_logits(batch, force_full_logits)
2878    }
2879
2880    fn unified_forward(
2881        &mut self,
2882        items: &[(String, Vec<u32>, usize, bool)],
2883    ) -> std::result::Result<Vec<Option<Vec<f32>>>, ferrum_types::FerrumError> {
2884        if items.is_empty() {
2885            return Ok(Vec::new());
2886        }
2887        if self.runtime_env.prefix_cache
2888            && items
2889                .iter()
2890                .any(|(_, tokens, pos_offset, _)| *pos_offset == 0 && tokens.len() > 1)
2891        {
2892            return Err(ferrum_types::FerrumError::unsupported(
2893                "LlamaFamilyModel::unified_forward: fresh prefill with prefix cache enabled \
2894                 routes through prefill_internal so real paged-block KV reuse can probe and \
2895                 register block hashes",
2896            ));
2897        }
2898        if !self.supports_varlen_qkv {
2899            return Err(ferrum_types::FerrumError::unsupported(
2900                "LlamaFamilyModel::unified_forward: backend lacks varlen \
2901                 QKV kernels. Engine will fall back to per-item dispatch.",
2902            ));
2903        }
2904        if items
2905            .iter()
2906            .any(|(cache_id, _, _, _)| self.active_lora_adapter_for_cache(cache_id).is_some())
2907        {
2908            return Err(ferrum_types::FerrumError::unsupported(
2909                "LlamaFamilyModel::unified_forward: active LoRA adapter routes through \
2910                 per-item dispatch until unified LoRA supports row-selective adapters.",
2911            ));
2912        }
2913        self.ensure_kv(&items[0].0);
2914        if self.paged_pools.is_none() {
2915            return Err(ferrum_types::FerrumError::unsupported(
2916                "LlamaFamilyModel::unified_forward: paged KV required; \
2917                 enable via FERRUM_METAL_PAGED_KV=1 (cross-backend env). \
2918                 Engine will fall back to per-item dispatch.",
2919            ));
2920        }
2921        for (cid, _, _, _) in items {
2922            self.ensure_kv(cid);
2923            if !self.kv_caches.contains_key(cid) {
2924                return Err(ferrum_types::FerrumError::resource_exhausted(format!(
2925                    "paged KV pool exhausted for cache_id={cid:?}; back off"
2926                )));
2927            }
2928        }
2929        Ok(self.unified_forward_internal(items))
2930    }
2931
2932    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
2933        LlamaFamilyModel::<B, KvFp16>::forward_verify(self, cache_id, tokens)
2934    }
2935
2936    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
2937        if let Some(caches) = self.kv_caches.get_mut(cache_id) {
2938            for c in caches.iter_mut() {
2939                if new_len < c.len {
2940                    c.len = new_len;
2941                }
2942            }
2943        }
2944        let mut ctx = B::new_context();
2945        B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
2946        self.graph_warmup = 0;
2947        self.graph_capture_failed = false;
2948    }
2949
2950    fn release(&mut self, cache_id: &str) {
2951        let mut ctx = B::new_context();
2952        B::sync(&mut ctx);
2953        self.graph_warmup = 0;
2954        self.graph_capture_failed = false;
2955        self.lora_cache_adapters.remove(cache_id);
2956        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
2957            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
2958                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
2959                if let Some(c0) = caches.first() {
2960                    if !c0.paged_block_indices.is_empty() {
2961                        alloc.free(&c0.paged_block_indices);
2962                    }
2963                }
2964                for c in caches.iter_mut() {
2965                    c.paged_block_indices.clear();
2966                }
2967            }
2968            self.kv_free_pool.push(caches);
2969        }
2970    }
2971
2972    fn reset(&mut self) {
2973        let mut ctx = B::new_context();
2974        B::sync(&mut ctx);
2975        B::reset_all_graphs(&mut ctx);
2976        B::sync(&mut ctx);
2977        self.graph_warmup = 0;
2978        self.graph_capture_failed = false;
2979        self.batched_graph_keys_seen.clear();
2980        self.batched_graph_warmup = 0;
2981        self.batched_graph_failed = false;
2982        self.kv_caches.clear();
2983        self.kv_free_pool.clear();
2984        self.lora_cache_adapters.clear();
2985    }
2986}
2987
2988// INT8 DecoderOnlyLLM impl — minimal: no batched / unified-forward overrides
2989// (default trait impl falls back to per-item decode). PR D will add INT8
2990// batched paths once the kernels stabilize.
2991impl<B: MoeLlmBackend + BackendInt8KvOps> DecoderOnlyLLM for LlamaFamilyModel<B, KvInt8> {
2992    fn config(&self) -> &LlmRuntimeConfig {
2993        &self.runtime_cfg
2994    }
2995
2996    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
2997        Some(self.prefix_cache_snapshot_json())
2998    }
2999
3000    fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
3001        Some(self.lora_metrics_snapshot_json())
3002    }
3003
3004    fn set_lora_adapter_for_cache(
3005        &mut self,
3006        cache_id: &str,
3007        adapter: Option<ActiveLoraAdapter>,
3008    ) -> std::result::Result<(), ferrum_types::FerrumError> {
3009        if let Some(adapter) = adapter {
3010            self.ensure_lora_adapter_loaded(adapter.clone())?;
3011            self.lora_cache_adapters
3012                .insert(cache_id.to_string(), adapter.name);
3013        } else {
3014            self.lora_cache_adapters.remove(cache_id);
3015        }
3016        Ok(())
3017    }
3018
3019    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
3020        self.ensure_scratch(max_tokens);
3021        self.ensure_kv(cache_id);
3022
3023        const WARMUP_CACHE: &str = "__ferrum_warmup__";
3024        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
3025        if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
3026            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
3027                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
3028                if let Some(c0) = caches.first() {
3029                    if !c0.paged_block_indices.is_empty() {
3030                        alloc.free(&c0.paged_block_indices);
3031                    }
3032                }
3033                for c in caches.iter_mut() {
3034                    c.paged_block_indices.clear();
3035                }
3036            }
3037            self.kv_free_pool.push(caches);
3038        }
3039    }
3040
3041    fn kv_capacity(&self) -> usize {
3042        self.runtime_env.kv_capacity_for_model(self.cfg.max_seq_len)
3043    }
3044
3045    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
3046        self.prefill_internal(cache_id, tokens)
3047    }
3048
3049    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
3050        self.decode_internal(cache_id, token, pos)
3051    }
3052
3053    // decode_batch + unified_forward use trait defaults (per-item fallback).
3054
3055    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
3056        LlamaFamilyModel::<B, KvInt8>::forward_verify(self, cache_id, tokens)
3057    }
3058
3059    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
3060        if let Some(caches) = self.kv_caches.get_mut(cache_id) {
3061            for c in caches.iter_mut() {
3062                if new_len < c.len {
3063                    c.len = new_len;
3064                }
3065            }
3066        }
3067        let mut ctx = B::new_context();
3068        B::reset_graph(&mut ctx, SINGLE_ITEM_GRAPH_KEY);
3069        self.graph_warmup = 0;
3070        self.graph_capture_failed = false;
3071    }
3072
3073    fn release(&mut self, cache_id: &str) {
3074        let mut ctx = B::new_context();
3075        B::sync(&mut ctx);
3076        self.graph_warmup = 0;
3077        self.graph_capture_failed = false;
3078        self.lora_cache_adapters.remove(cache_id);
3079        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
3080            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
3081                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
3082                if let Some(c0) = caches.first() {
3083                    if !c0.paged_block_indices.is_empty() {
3084                        alloc.free(&c0.paged_block_indices);
3085                    }
3086                }
3087                for c in caches.iter_mut() {
3088                    c.paged_block_indices.clear();
3089                }
3090            }
3091            self.kv_free_pool.push(caches);
3092        }
3093    }
3094
3095    fn reset(&mut self) {
3096        let mut ctx = B::new_context();
3097        B::sync(&mut ctx);
3098        B::reset_all_graphs(&mut ctx);
3099        B::sync(&mut ctx);
3100        self.graph_warmup = 0;
3101        self.graph_capture_failed = false;
3102        self.kv_caches.clear();
3103        self.kv_free_pool.clear();
3104        self.lora_cache_adapters.clear();
3105    }
3106}
3107
3108fn build_rope_cache<B: QuantLlmBackend + BackendMoeFused>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
3109    let hd = cfg.head_dim;
3110    let half = hd / 2;
3111    let max = cfg.max_seq_len;
3112    let mut cos = vec![0.0f32; max * half];
3113    let mut sin = vec![0.0f32; max * half];
3114    for pos in 0..max {
3115        for i in 0..half {
3116            let freq = rope_freq(cfg, i);
3117            let angle = pos as f64 * freq;
3118            cos[pos * half + i] = angle.cos() as f32;
3119            sin[pos * half + i] = angle.sin() as f32;
3120        }
3121    }
3122    RopeCache {
3123        cos: B::from_slice(&cos),
3124        sin: B::from_slice(&sin),
3125    }
3126}
3127
3128fn rope_freq(cfg: &LlamaFamilyConfig, pair_idx: usize) -> f64 {
3129    let base_freq = 1.0f64
3130        / cfg
3131            .rope_theta
3132            .powf((2 * pair_idx) as f64 / cfg.head_dim as f64);
3133    match &cfg.rope_scaling {
3134        Some(RopeScalingConfig::Llama3 {
3135            factor,
3136            low_freq_factor,
3137            high_freq_factor,
3138            original_max_position_embeddings,
3139        }) => scale_llama3_rope_freq(
3140            base_freq,
3141            *factor,
3142            *low_freq_factor,
3143            *high_freq_factor,
3144            *original_max_position_embeddings,
3145        ),
3146        None => base_freq,
3147    }
3148}
3149
/// Apply Llama-3 style frequency scaling to a single RoPE inverse frequency.
///
/// The wavelength `2π / freq` is compared with two thresholds derived from the
/// original training context length:
/// - below `original_max_position_embeddings / high_freq_factor`: `freq` is
///   returned unchanged;
/// - above `original_max_position_embeddings / low_freq_factor`: `freq / factor`;
/// - otherwise a linear blend of `freq / factor` and `freq`, weighted by
///   `(original_max_position_embeddings / wavelen - low_freq_factor)
///    / (high_freq_factor - low_freq_factor)`.
fn scale_llama3_rope_freq(
    freq: f64,
    factor: f64,
    low_freq_factor: f64,
    high_freq_factor: f64,
    original_max_position_embeddings: f64,
) -> f64 {
    let wavelen = 2.0 * std::f64::consts::PI / freq;
    if wavelen < original_max_position_embeddings / high_freq_factor {
        return freq;
    }
    if wavelen > original_max_position_embeddings / low_freq_factor {
        return freq / factor;
    }
    let blend = (original_max_position_embeddings / wavelen - low_freq_factor)
        / (high_freq_factor - low_freq_factor);
    (1.0 - blend) * freq / factor + blend * freq
}
3170
#[cfg(test)]
mod tests {
    // Unit tests for Llama-family layer loading, pipeline-stage construction
    // and `LlamaFamilyRuntimeEnv` parsing, driven by a recording stub loader on
    // `CpuBackend`.
    use std::sync::Mutex;

    use ferrum_kernels::backend::cpu::CpuBackend;
    use ferrum_quantization::{DenseLinear, QuantConfig, WeightLoader};
    use ferrum_types::Result;

    use super::{
        load_llama_family_layers, LlamaFamilyConfig, LlamaFamilyLayerStageConfig, LlamaFamilyModel,
        LlamaFamilyRuntimeEnv, DEFAULT_KV_CAPACITY,
    };

    /// Stub `WeightLoader` that records, in call order, every tensor name and
    /// linear name requested, and returns 1-element tensors / 1x1 dense linears.
    #[derive(Default)]
    struct RecordingLoader {
        tensors: Mutex<Vec<String>>,
        linears: Mutex<Vec<String>>,
    }

    impl WeightLoader<CpuBackend> for RecordingLoader {
        fn load_tensor(&self, name: &str) -> Result<Vec<f32>> {
            self.tensors.lock().unwrap().push(name.to_string());
            Ok(vec![1.0])
        }

        fn load_linear(
            &self,
            name: &str,
        ) -> Result<Box<dyn ferrum_quantization::Linear<CpuBackend>>> {
            self.linears.lock().unwrap().push(name.to_string());
            Ok(Box::new(DenseLinear::<CpuBackend>::from_rows(&[1.0], 1, 1)))
        }

        // Reports every tensor as absent, so loaders never take an
        // "if present" branch.
        fn has_tensor(&self, _name: &str) -> bool {
            false
        }

        fn quant_config(&self) -> Option<&QuantConfig> {
            None
        }
    }

    /// Smallest possible config (all widths 1, `max_seq_len` 1) with a
    /// configurable layer count and QK-norm flag.
    fn test_llama_config(num_layers: usize, has_qk_norm: bool) -> LlamaFamilyConfig {
        LlamaFamilyConfig {
            hidden_size: 1,
            intermediate_size: 1,
            num_heads: 1,
            num_kv_heads: 1,
            head_dim: 1,
            num_layers,
            vocab_size: 1,
            max_seq_len: 1,
            rms_norm_eps: 1e-5,
            rope_theta: 10_000.0,
            rope_scaling: None,
            rope_interleaved: false,
            has_qk_norm,
            sliding_window: 0,
        }
    }

    // Loading the source range 2..4 touches only layers 2 and 3, uses source
    // (not local) layer indices in weight names, and loads q/k norms when
    // `has_qk_norm` is set.
    #[test]
    fn llama_family_layer_loader_uses_source_layer_range() {
        let cfg = test_llama_config(5, true);
        let loader = RecordingLoader::default();

        let layers = load_llama_family_layers(&cfg, &loader, 2..4).unwrap();

        assert_eq!(layers.len(), 2);
        assert!(layers.iter().all(|layer| layer.q_norm_w.is_some()));
        assert!(layers.iter().all(|layer| layer.k_norm_w.is_some()));
        assert_eq!(
            loader.tensors.into_inner().unwrap(),
            vec![
                "model.layers.2.input_layernorm.weight",
                "model.layers.2.post_attention_layernorm.weight",
                "model.layers.2.self_attn.q_norm.weight",
                "model.layers.2.self_attn.k_norm.weight",
                "model.layers.3.input_layernorm.weight",
                "model.layers.3.post_attention_layernorm.weight",
                "model.layers.3.self_attn.q_norm.weight",
                "model.layers.3.self_attn.k_norm.weight",
            ]
        );
        assert_eq!(
            loader.linears.into_inner().unwrap(),
            vec![
                "model.layers.2.self_attn.qkv_proj",
                "model.layers.2.self_attn.o_proj",
                "model.layers.2.mlp.gate_up_proj",
                "model.layers.2.mlp.down_proj",
                "model.layers.3.self_attn.qkv_proj",
                "model.layers.3.self_attn.o_proj",
                "model.layers.3.mlp.gate_up_proj",
                "model.layers.3.mlp.down_proj",
            ]
        );
    }

    // A range that runs past `num_layers` is rejected with an error rather
    // than panicking.
    #[test]
    fn llama_family_layer_loader_rejects_out_of_bounds_range() {
        let cfg = test_llama_config(2, false);
        let loader = RecordingLoader::default();

        let err = match load_llama_family_layers(&cfg, &loader, 1..3) {
            Ok(_) => panic!("expected out-of-bounds layer range to fail"),
            Err(err) => err,
        };

        assert!(
            err.to_string().contains("outside model layer count"),
            "{err}"
        );
    }

    // A full model covers all source layers, and `ensure_kv` allocates one KV
    // cache entry per local layer.
    #[test]
    fn llama_family_full_model_records_source_layer_range() {
        let cfg = test_llama_config(3, false);
        let loader = RecordingLoader::default();
        let mut model = LlamaFamilyModel::<CpuBackend>::new(cfg, &loader).unwrap();

        assert_eq!(model.source_layer_range(), 0..3);
        assert_eq!(model.local_layer_count(), 3);
        assert_eq!(model.source_layer_index(2), 2);

        model.ensure_kv("test-cache");
        assert_eq!(
            model.kv_caches["test-cache"].len(),
            model.local_layer_count()
        );
    }

    // A pipeline stage for layers 2..5 with (false, true) flags loads no
    // embedding, builds an lm_head, and requests only its own layers' weights
    // plus `model.norm.weight`. The trailing `model.embed_tokens` linear is
    // presumably the tied-lm_head fallback taken because `has_tensor` always
    // returns false — confirm against `new_layer_stage`.
    #[test]
    fn llama_family_layer_stage_loads_only_requested_weights() {
        let cfg = test_llama_config(5, false);
        let loader = RecordingLoader::default();

        let model = LlamaFamilyModel::<CpuBackend>::new_layer_stage(
            cfg,
            &loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(2..5, false, true),
        )
        .unwrap();

        assert_eq!(model.source_layer_range(), 2..5);
        assert_eq!(model.local_layer_count(), 3);
        assert!(model.embed.is_none());
        assert!(model.lm_head.is_some());
        assert_eq!(
            loader.tensors.into_inner().unwrap(),
            vec![
                "model.layers.2.input_layernorm.weight",
                "model.layers.2.post_attention_layernorm.weight",
                "model.layers.3.input_layernorm.weight",
                "model.layers.3.post_attention_layernorm.weight",
                "model.layers.4.input_layernorm.weight",
                "model.layers.4.post_attention_layernorm.weight",
                "model.norm.weight",
            ]
        );
        assert_eq!(
            loader.linears.into_inner().unwrap(),
            vec![
                "model.layers.2.self_attn.qkv_proj",
                "model.layers.2.self_attn.o_proj",
                "model.layers.2.mlp.gate_up_proj",
                "model.layers.2.mlp.down_proj",
                "model.layers.3.self_attn.qkv_proj",
                "model.layers.3.self_attn.o_proj",
                "model.layers.3.mlp.gate_up_proj",
                "model.layers.3.mlp.down_proj",
                "model.layers.4.self_attn.qkv_proj",
                "model.layers.4.self_attn.o_proj",
                "model.layers.4.mlp.gate_up_proj",
                "model.layers.4.mlp.down_proj",
                "model.embed_tokens",
            ]
        );
    }

    // Host hidden state in, host hidden state out: a one-layer stage produces a
    // finite 1-element hidden vector and allocates KV for its local layers.
    #[test]
    fn llama_family_layer_stage_runs_hidden_forward_bridge() {
        let cfg = test_llama_config(1, false);
        let loader = RecordingLoader::default();
        let mut model = LlamaFamilyModel::<CpuBackend>::new_layer_stage(
            cfg,
            &loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(0..1, false, false),
        )
        .unwrap();

        let hidden = model.prefill_stage_hidden_from_host("stage-cache", &[1.0], 1, 0);

        assert_eq!(hidden.len(), 1);
        assert!(hidden[0].is_finite());
        assert_eq!(
            model.kv_caches["stage-cache"].len(),
            model.local_layer_count()
        );
    }

    // A stage built with an lm_head turns a host hidden vector into finite
    // logits of vocab size (1 here).
    #[test]
    fn llama_family_last_stage_projects_hidden_to_logits() {
        let cfg = test_llama_config(1, false);
        let loader = RecordingLoader::default();
        let mut model = LlamaFamilyModel::<CpuBackend>::new_layer_stage(
            cfg,
            &loader,
            LlamaFamilyLayerStageConfig::pipeline_stage(0..1, false, true),
        )
        .unwrap();

        let logits = model.logits_from_hidden(&[2.0]);

        assert_eq!(logits.len(), 1);
        assert!(logits[0].is_finite());
    }

    // Per these asserts, the profile / CUDA-graph knobs are enabled by the
    // variable's presence alone — "0", "" and "false" all parse as `true` —
    // while FERRUM_METAL_PAGED_KV honours "0". A KV capacity override (4096)
    // larger than the model's max length (2048) is clamped to the model length.
    #[test]
    fn llama_family_runtime_env_parses_startup_knobs() {
        let env = LlamaFamilyRuntimeEnv::from_env_vars([
            ("FERRUM_KV_CAPACITY", "4096"),
            ("FERRUM_METAL_PAGED_KV", "0"),
            ("FERRUM_PAGED_MAX_SEQS", "64"),
            ("FERRUM_DECODE_OP_PROFILE", "0"),
            ("FERRUM_PREFILL_OP_PROFILE", ""),
            ("FERRUM_CUDA_GRAPH", ""),
            ("FERRUM_DECODE_LAYER_PROFILE", "false"),
        ]);

        assert_eq!(env.kv_capacity, Some(4096));
        assert_eq!(env.metal_paged_kv, Some(false));
        assert_eq!(env.paged_max_seqs, 64);
        assert!(env.decode_op_profile);
        assert!(env.prefill_op_profile);
        assert!(env.cuda_graph);
        assert!(env.decode_layer_profile);
        assert_eq!(env.kv_capacity_for_model(2048), 2048);
    }

    // Unparseable numeric values fall back to defaults (no KV override,
    // 32 paged sequences); without an override the KV capacity is capped at
    // DEFAULT_KV_CAPACITY even for a larger model length.
    #[test]
    fn llama_family_runtime_env_uses_defaults_for_invalid_values() {
        let env = LlamaFamilyRuntimeEnv::from_env_vars([
            ("FERRUM_KV_CAPACITY", "bad"),
            ("FERRUM_PAGED_MAX_SEQS", "bad"),
            ("FERRUM_METAL_PAGED_KV", "1"),
        ]);

        assert_eq!(env.kv_capacity, None);
        assert_eq!(env.metal_paged_kv, Some(true));
        assert_eq!(env.paged_max_seqs, 32);
        assert_eq!(
            env.kv_capacity_for_model(DEFAULT_KV_CAPACITY * 2),
            DEFAULT_KV_CAPACITY
        );
    }
}