Skip to main content

ferrum_models/models/
llama_family.rs

1//! Llama-family decoder model as explicit code.
2//!
3//! Covers all "standard Llama-style GQA + SwiGLU + RoPE" decoders:
4//!   - Llama / Llama-2 / Llama-3  (no QK-norm)
5//!   - Qwen2 / Qwen2.5            (no QK-norm, structurally Llama)
6//!   - Qwen3                      (QK-norm per head, larger rope_theta)
//!   - Mistral                    (sliding-window attention — the window
//!                                 size is read from the checkpoint config
//!                                 and handed to the attention kernel via
//!                                 `AttnConfig.sliding_window`; `0` = full causal)
11//!
12//! Variant differences are controlled by `LlamaFamilyConfig::has_qk_norm`
13//! and RoPE theta. Weight loading accepts both fused (`qkv_proj`,
14//! `gate_up_proj`) and split (`q_proj`+`k_proj`+`v_proj`,
15//! `gate_proj`+`up_proj`) projection layouts — the loader (e.g.
16//! `CandleVarBuilderLoader`) fuses split weights on load so model code
17//! sees a uniform `qkv_proj` / `gate_up_proj` Linear.
18
19use std::collections::HashMap;
20
21use ferrum_kernels::backend::{Backend, KvCache};
22use ferrum_quantization::{Linear, WeightLoader};
23use ferrum_types::Result;
24
25use crate::common::{DecoderOnlyLLM, LlmRuntimeConfig};
26
/// Full architecture config for a Llama-family decoder (everything the model
/// code needs, not just the engine-facing subset in `LlmRuntimeConfig`).
///
/// A single struct serves every supported variant (Llama, Qwen2, Qwen3,
/// Mistral); the variants differ only in `has_qk_norm`, `rope_theta` and
/// `sliding_window`.
#[derive(Clone, Debug, PartialEq)]
pub struct LlamaFamilyConfig {
    /// Width of the residual stream (elements per token).
    pub hidden_size: usize,
    /// Width of the MLP's inner activation; the fused gate+up projection
    /// produces `2 * intermediate_size` values per token.
    pub intermediate_size: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads (GQA when smaller than `num_heads`).
    pub num_kv_heads: usize,
    /// Elements per attention head.
    pub head_dim: usize,
    /// Number of decoder layers.
    pub num_layers: usize,
    /// Vocabulary size, i.e. the length of the logits vector.
    pub vocab_size: usize,
    /// Maximum sequence length; also the per-head capacity of each KV cache.
    pub max_seq_len: usize,
    /// Epsilon used by every RMSNorm in the model.
    pub rms_norm_eps: f32,
    /// RoPE base frequency.
    pub rope_theta: f64,
    /// Whether the checkpoint has `q_norm` / `k_norm` per layer. All known
    /// Qwen3 checkpoints do; some derivatives may strip them.
    pub has_qk_norm: bool,
    /// Sliding-window attention size. `0` disables (full causal).
    /// Mistral v0.1 sets 4096; Mistral v0.2+ removed the limitation (0).
    pub sliding_window: usize,
}
48
49impl LlamaFamilyConfig {
50    pub fn to_runtime(&self) -> LlmRuntimeConfig {
51        LlmRuntimeConfig {
52            hidden_size: self.hidden_size,
53            num_layers: self.num_layers,
54            num_kv_heads: self.num_kv_heads,
55            head_dim: self.head_dim,
56            vocab_size: self.vocab_size,
57            max_seq_len: self.max_seq_len,
58        }
59    }
60
61    /// Build config from a `ModelDefinition`, shared field extraction.
62    /// Variant-specific constructors below set `has_qk_norm` and fall back
63    /// to different `rope_theta` defaults when the checkpoint doesn't set one.
64    fn from_def_base(def: &crate::definition::ModelDefinition) -> LlamaFamilyConfigBase {
65        let num_kv_heads = def.num_key_value_heads.unwrap_or(def.num_attention_heads);
66        let head_dim = def
67            .extra_params
68            .get("head_dim")
69            .and_then(|v| v.as_u64())
70            .map(|v| v as usize)
71            .unwrap_or(def.hidden_size / def.num_attention_heads);
72        // Mistral / Gemma: "sliding_window" may be null (v0.2+) or a positive
73        // integer (v0.1). Non-null value passes through; missing/null → 0.
74        let sliding_window = def
75            .extra_params
76            .get("sliding_window")
77            .and_then(|v| v.as_u64())
78            .map(|v| v as usize)
79            .unwrap_or(0);
80
81        LlamaFamilyConfigBase {
82            hidden_size: def.hidden_size,
83            intermediate_size: def.intermediate_size,
84            num_heads: def.num_attention_heads,
85            num_kv_heads,
86            head_dim,
87            num_layers: def.num_hidden_layers,
88            vocab_size: def.vocab_size,
89            max_seq_len: def.max_position_embeddings,
90            rms_norm_eps: def.norm_eps as f32,
91            rope_theta_opt: def.rope_theta,
92            sliding_window,
93        }
94    }
95
96    fn from_base(b: LlamaFamilyConfigBase, rope_default: f64, has_qk_norm: bool) -> Self {
97        Self {
98            hidden_size: b.hidden_size,
99            intermediate_size: b.intermediate_size,
100            num_heads: b.num_heads,
101            num_kv_heads: b.num_kv_heads,
102            head_dim: b.head_dim,
103            num_layers: b.num_layers,
104            vocab_size: b.vocab_size,
105            max_seq_len: b.max_seq_len,
106            rms_norm_eps: b.rms_norm_eps,
107            rope_theta: b.rope_theta_opt.unwrap_or(rope_default),
108            has_qk_norm,
109            sliding_window: b.sliding_window,
110        }
111    }
112
113    /// Qwen3: QK-norm on, rope_theta default 1e6.
114    pub fn qwen3_from_def(def: &crate::definition::ModelDefinition) -> Self {
115        Self::from_base(Self::from_def_base(def), 1_000_000.0, true)
116    }
117
118    /// Llama / Llama-2 / Llama-3: no QK-norm; rope_theta varies by version
119    /// (10k for Llama-2, 500k for Llama-3.1+) — use the checkpoint value or
120    /// fall back to the most common modern value.
121    pub fn llama_from_def(def: &crate::definition::ModelDefinition) -> Self {
122        Self::from_base(Self::from_def_base(def), 500_000.0, false)
123    }
124
125    /// Qwen2 / Qwen2.5: structurally Llama; no QK-norm; rope_theta default 1e6.
126    pub fn qwen2_from_def(def: &crate::definition::ModelDefinition) -> Self {
127        Self::from_base(Self::from_def_base(def), 1_000_000.0, false)
128    }
129
130    /// Mistral: no QK-norm; `rope_theta` commonly 10_000 (v0.1 / v0.2).
131    /// Picks up `sliding_window` from the checkpoint's config.json
132    /// (Mistral v0.1: 4096; Mistral v0.2+: 0 / null).
133    pub fn mistral_from_def(def: &crate::definition::ModelDefinition) -> Self {
134        Self::from_base(Self::from_def_base(def), 10_000.0, false)
135    }
136}
137
/// Variant-independent fields extracted from a `ModelDefinition`.
///
/// Intermediate value between `LlamaFamilyConfig::from_def_base` and
/// `LlamaFamilyConfig::from_base`: it carries everything except
/// `has_qk_norm`, and keeps `rope_theta` optional so each variant
/// constructor can supply its own default.
struct LlamaFamilyConfigBase {
    hidden_size: usize,
    intermediate_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    num_layers: usize,
    vocab_size: usize,
    max_seq_len: usize,
    rms_norm_eps: f32,
    /// `rope_theta` as given by the checkpoint; `None` when it is unset.
    rope_theta_opt: Option<f64>,
    /// Sliding-window size; `0` when the definition has none.
    sliding_window: usize,
}
151
/// Per-layer weights. `Box<dyn Linear<B>>` means each projection can be
/// Dense / GPTQ / AWQ / GGUF without the surrounding code caring.
pub struct LlamaFamilyLayer<B: Backend> {
    /// RMSNorm weight applied to the residual before attention.
    pub input_ln_w: B::Buffer,
    /// Fused Q/K/V projection.
    pub qkv_proj: Box<dyn Linear<B>>,
    /// QK-norm weight per head: `[head_dim]`. Optional for non-Qwen3 derivatives.
    pub q_norm_w: Option<B::Buffer>,
    /// Key-side counterpart of `q_norm_w`.
    pub k_norm_w: Option<B::Buffer>,
    /// Attention output projection.
    pub o_proj: Box<dyn Linear<B>>,
    /// RMSNorm weight applied after the attention residual add, before the MLP.
    pub post_ln_w: B::Buffer,
    /// Fused gate+up projection of the SwiGLU MLP.
    pub gate_up_proj: Box<dyn Linear<B>>,
    /// Down projection of the SwiGLU MLP.
    pub down_proj: Box<dyn Linear<B>>,
}
165
/// Precomputed RoPE cos/sin tables (shape `[max_seq, head_dim / 2]` each).
///
/// Both tables are passed together to `Backend::qk_norm_rope` on every layer.
pub struct RopeCache<B: Backend> {
    /// Cosine table.
    pub cos: B::Buffer,
    /// Sine table.
    pub sin: B::Buffer,
}
171
/// Reusable per-layer scratch buffers sized for `max_tokens` tokens of a
/// single forward pass (prefill or decode step).
///
/// Sized lazily on first use so tiny decode steps don't pay for prefill-sized
/// buffers. Grows monotonically when a larger prefill arrives.
///
/// In the size notes below, `q_dim = num_heads * head_dim` and
/// `kv_dim = num_kv_heads * head_dim`.
pub struct LlamaFamilyScratch<B: Backend> {
    /// Residual stream — wrapped in Option so decode_internal can
    /// `.take()` it without needing an alloc placeholder.
    ///
    /// Why this matters for graph capture: the old pattern was
    /// `mem::replace(&mut scratch.residual, B::alloc(1))` which creates a
    /// 1-element buffer at every decode step. When graph capture is on,
    /// that alloc-during-capture + drop-after-capture pair surfaces as
    /// cuMemFreeAsync(INVALID_VALUE) because the free tries to release a
    /// pointer the captured graph may still reference. Option::take leaves
    /// None and moves the real buffer into a local — no spurious alloc.
    pub residual: Option<B::Buffer>,
    /// RMSNorm output (input norm, then reused for the post-attention norm);
    /// `max_tokens * hidden_size` elements.
    pub norm_out: B::Buffer,
    /// Fused QKV projection output; `max_tokens * (q_dim + 2 * kv_dim)` elements.
    pub qkv_out: B::Buffer,
    // ── Per-item scratch for batched decode path ──────────────────────
    // decode_batch_internal runs tokens=M batched ops for the GEMM-heavy
    // half (norm, qkv_proj, split_qkv, o_proj, post_norm, gate_up, silu,
    // down, residual_add) but must loop per-item for rope + KV append +
    // attention (each item has its own KV cache at a different kv_len).
    // These single-item buffers hold item i's slice during that loop.
    /// Item-scope q_buf slice, sized `q_dim`.
    pub q_single: B::Buffer,
    /// Item-scope k_buf slice, sized `kv_dim`.
    pub k_single: B::Buffer,
    /// Item-scope v_buf slice, sized `kv_dim`.
    pub v_single: B::Buffer,
    /// Item-scope head-major Q, sized `q_dim`.
    pub q_head_major_single: B::Buffer,
    /// Item-scope head-major K, sized `kv_dim`.
    pub k_head_major_single: B::Buffer,
    /// Item-scope head-major V, sized `kv_dim`.
    pub v_head_major_single: B::Buffer,
    /// Item-scope head-major attention output, sized `q_dim`.
    pub attn_head_major_single: B::Buffer,
    /// Item-scope token-major attention output, sized `q_dim`.
    pub attn_flat_single: B::Buffer,
    /// Batched logits output, sized `max_tokens * vocab_size`. Used only
    /// in decode_batch; prefill/single-decode use the regular `logits`.
    pub batch_logits: B::Buffer,
    /// Token-major Q/K/V right after `split_qkv`. Stride: heads * hd per row.
    pub q_buf: B::Buffer,
    /// Token-major K (see `q_buf`); `max_tokens * kv_dim` elements.
    pub k_buf: B::Buffer,
    /// Token-major V (see `q_buf`); `max_tokens * kv_dim` elements.
    pub v_buf: B::Buffer,
    /// Head-major Q produced by `qk_norm_rope` — fed into `flash_attention`.
    pub q_head_major: B::Buffer,
    /// Head-major K/V staging — produced by `qk_norm_rope`, consumed by
    /// `kv_cache_append_head_major` (no reuse after append).
    pub k_head_major: B::Buffer,
    /// Head-major V staging (see `k_head_major`).
    pub v_head_major: B::Buffer,
    /// Head-major attention output from `flash_attention`.
    pub attn_head_major_out: B::Buffer,
    /// Token-major attention output after `transpose_head_to_token`.
    pub attn_flat: B::Buffer,
    /// O-projection output; `max_tokens * hidden_size` elements.
    pub o_proj_out: B::Buffer,
    /// Fused gate+up projection output; `max_tokens * 2 * intermediate_size` elements.
    pub gate_up_out: B::Buffer,
    /// SwiGLU activation output; `max_tokens * intermediate_size` elements.
    pub silu_out: B::Buffer,
    /// Down-projection (MLP) output; `max_tokens * hidden_size` elements.
    pub mlp_out: B::Buffer,
    /// Last token's hidden state (`[h]`). For prefill this is populated via
    /// `copy_slice(residual, (seq_len-1)*h, ..)`; for decode `residual` already
    /// holds only 1 row so `last_hidden` is unused on that path.
    pub last_hidden: B::Buffer,
    /// Final-norm output for the last token (`[h]`).
    pub last_normed: B::Buffer,
    /// lm_head logits (`[vocab]`).
    pub logits: B::Buffer,
    /// The max tokens-per-step this scratch has been sized for.
    pub max_tokens: usize,
}
238
239impl<B: Backend> LlamaFamilyScratch<B> {
240    fn alloc(cfg: &LlamaFamilyConfig, max_tokens: usize) -> Self {
241        let h = cfg.hidden_size;
242        let im = cfg.intermediate_size;
243        let q_dim = cfg.num_heads * cfg.head_dim;
244        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
245        let qkv_dim = q_dim + 2 * kv_dim;
246        let t = max_tokens;
247        Self {
248            residual: Some(B::alloc(t * h)),
249            norm_out: B::alloc(t * h),
250            qkv_out: B::alloc(t * qkv_dim),
251            q_buf: B::alloc(t * q_dim),
252            k_buf: B::alloc(t * kv_dim),
253            v_buf: B::alloc(t * kv_dim),
254            q_head_major: B::alloc(cfg.num_heads * t * cfg.head_dim),
255            k_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
256            v_head_major: B::alloc(cfg.num_kv_heads * t * cfg.head_dim),
257            attn_head_major_out: B::alloc(cfg.num_heads * t * cfg.head_dim),
258            attn_flat: B::alloc(t * q_dim),
259            o_proj_out: B::alloc(t * h),
260            gate_up_out: B::alloc(t * 2 * im),
261            silu_out: B::alloc(t * im),
262            mlp_out: B::alloc(t * h),
263            last_hidden: B::alloc(h),
264            last_normed: B::alloc(h),
265            logits: B::alloc(cfg.vocab_size),
266            q_single: B::alloc(q_dim),
267            k_single: B::alloc(kv_dim),
268            v_single: B::alloc(kv_dim),
269            q_head_major_single: B::alloc(q_dim),
270            k_head_major_single: B::alloc(kv_dim),
271            v_head_major_single: B::alloc(kv_dim),
272            attn_head_major_single: B::alloc(q_dim),
273            attn_flat_single: B::alloc(q_dim),
274            batch_logits: B::alloc(t * cfg.vocab_size),
275            max_tokens: t,
276        }
277    }
278}
279
/// Llama-family model — decoder-only LLM, one per (backend, weights) combination.
///
/// Holds all parameters, scratch space, RoPE cache, and per-sequence KV caches.
pub struct LlamaFamilyModel<B: Backend> {
    /// Full architecture config the model was built from.
    pub cfg: LlamaFamilyConfig,
    /// Engine-facing subset of `cfg` (see `LlamaFamilyConfig::to_runtime`).
    pub runtime_cfg: LlmRuntimeConfig,

    /// Token embedding table. `None` for backbone-only models (e.g. the
    /// Qwen3-TTS Talker, which embeds inputs externally and feeds via
    /// `prefill_from_embeds`).
    pub embed: Option<B::Buffer>,
    /// Per-layer weights, `num_layers` entries in layer order.
    pub layers: Vec<LlamaFamilyLayer<B>>,
    /// RMSNorm weight applied to the final hidden state before the LM head.
    pub final_norm_w: B::Buffer,
    /// LM output head. `None` for backbone-only models.
    pub lm_head: Option<Box<dyn Linear<B>>>,

    /// Precomputed RoPE cos/sin tables.
    pub rope: RopeCache<B>,
    /// Reusable forward-pass scratch buffers; regrown by `ensure_scratch`.
    pub scratch: LlamaFamilyScratch<B>,

    /// Per-sequence KV caches, one `Vec<KvCache<B>>` of length `num_layers`.
    pub kv_caches: HashMap<String, Vec<KvCache<B>>>,
    /// Free pool of pre-allocated KV cache slots. Released caches return
    /// here instead of being dropped, so their device pointers stay valid
    /// across requests — critical for graph capture (pointers baked into
    /// the captured graph would otherwise dangle).
    kv_free_pool: Vec<Vec<KvCache<B>>>,

    // ── Graph capture state (CUDA only; harmless no-op on other backends) ──
    /// Count of eager decode steps run so far. After `GRAPH_WARMUP`, the
    /// next step captures the decode flow as a graph.
    graph_warmup: usize,
    /// True if capture was attempted but failed (e.g. backend doesn't
    /// support graph capture). Stops further attempts, falls back to eager.
    graph_capture_failed: bool,
}
315
316impl<B: Backend> LlamaFamilyModel<B> {
317    /// Build a Qwen3 model from weights provided by the loader.
318    ///
319    /// The loader decides per-projection whether to instantiate DenseLinear,
320    /// GptqLinear, etc. — this code doesn't care.
321    pub fn new(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
322        // Invalidate any graph from a previously-loaded model. The captured
323        // graph references the old model's scratch buffers; a fresh model
324        // gets fresh scratch, so reusing the graph would read/write freed
325        // pointers. Matters for test suites where multiple models coexist.
326        {
327            let mut ctx = B::new_context();
328            B::reset_graph(&mut ctx);
329        }
330        let rope = build_rope_cache::<B>(&cfg);
331        let scratch = LlamaFamilyScratch::alloc(&cfg, 1); // decode-sized; prefill resizes
332
333        // Embedding: plain tensor (no projection math, just lookup).
334        let embed = loader.load_tensor("model.embed_tokens.weight")?;
335
336        // Per-layer weights.
337        let mut layers = Vec::with_capacity(cfg.num_layers);
338        for li in 0..cfg.num_layers {
339            let prefix = format!("model.layers.{li}");
340            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
341            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
342            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
343            let post_ln_w =
344                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
345            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
346            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
347
348            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
349                let q = loader
350                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
351                    .ok();
352                let k = loader
353                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
354                    .ok();
355                (q, k)
356            } else {
357                (None, None)
358            };
359
360            layers.push(LlamaFamilyLayer {
361                input_ln_w,
362                qkv_proj,
363                q_norm_w,
364                k_norm_w,
365                o_proj,
366                post_ln_w,
367                gate_up_proj,
368                down_proj,
369            });
370        }
371
372        let final_norm_w = loader.load_tensor("model.norm.weight")?;
373
374        // LM head: either dedicated `lm_head.weight` or tied to embedding.
375        // Many models (Qwen3-4B, Llama-3.2-1B, some Qwen2.5) use TIED
376        // embeddings — lm_head shares weights with model.embed_tokens. When
377        // no dedicated lm_head tensor exists, re-load the embed tensor as a
378        // DenseLinear. This duplicates the buffer (memory cost = vocab*h*2
379        // bytes, e.g. ~770MB for Qwen3-4B) but keeps the Linear trait's
380        // owned-weights invariant. Sharing via Arc is a future optimisation.
381        let lm_head = if loader.has_tensor("lm_head.weight") {
382            loader.load_linear("lm_head")?
383        } else {
384            tracing::info!(
385                "LlamaFamilyModel: tied embeddings — loading model.embed_tokens.weight as lm_head"
386            );
387            let as_linear = loader.load_linear("model.embed_tokens")?;
388            // Sanity check: shape must be [vocab, hidden].
389            if as_linear.out_features() != cfg.vocab_size
390                || as_linear.in_features() != cfg.hidden_size
391            {
392                return Err(ferrum_types::FerrumError::model(format!(
393                    "tied embed shape mismatch: got [{}, {}], expected [{}, {}]",
394                    as_linear.out_features(),
395                    as_linear.in_features(),
396                    cfg.vocab_size,
397                    cfg.hidden_size
398                )));
399            }
400            as_linear
401        };
402
403        let runtime_cfg = cfg.to_runtime();
404        Ok(Self {
405            cfg,
406            runtime_cfg,
407            embed: Some(embed),
408            layers,
409            final_norm_w,
410            lm_head: Some(lm_head),
411            rope,
412            scratch,
413            kv_caches: HashMap::new(),
414            kv_free_pool: Vec::new(),
415            graph_warmup: 0,
416            graph_capture_failed: false,
417        })
418    }
419
420    /// Build a backbone-only Qwen3 transformer stack (no embed, no lm_head).
421    ///
422    /// Intended for composing the transformer inside a larger model where
423    /// embedding and output-head logic differs from the standard LLM path —
424    /// e.g. Qwen3-TTS Talker uses dual text/codec embeddings with a projection
425    /// MLP, and a codec_head output. The caller drives forward via
426    /// `prefill_from_embeds` / `decode_from_embed`.
427    ///
428    /// Loader must provide: per-layer weights under `model.layers.{i}.*` and
429    /// the final `model.norm.weight`. `model.embed_tokens` and `lm_head`
430    /// are NOT read.
431    pub fn new_backbone_only(cfg: LlamaFamilyConfig, loader: &dyn WeightLoader<B>) -> Result<Self> {
432        // See `new` — invalidate stale graph referring to prior model's scratch.
433        {
434            let mut ctx = B::new_context();
435            B::reset_graph(&mut ctx);
436        }
437        let rope = build_rope_cache::<B>(&cfg);
438        let scratch = LlamaFamilyScratch::alloc(&cfg, 1);
439
440        let mut layers = Vec::with_capacity(cfg.num_layers);
441        for li in 0..cfg.num_layers {
442            let prefix = format!("model.layers.{li}");
443            let input_ln_w = loader.load_tensor(&format!("{prefix}.input_layernorm.weight"))?;
444            let qkv_proj = loader.load_linear(&format!("{prefix}.self_attn.qkv_proj"))?;
445            let o_proj = loader.load_linear(&format!("{prefix}.self_attn.o_proj"))?;
446            let post_ln_w =
447                loader.load_tensor(&format!("{prefix}.post_attention_layernorm.weight"))?;
448            let gate_up_proj = loader.load_linear(&format!("{prefix}.mlp.gate_up_proj"))?;
449            let down_proj = loader.load_linear(&format!("{prefix}.mlp.down_proj"))?;
450
451            let (q_norm_w, k_norm_w) = if cfg.has_qk_norm {
452                let q = loader
453                    .load_tensor(&format!("{prefix}.self_attn.q_norm.weight"))
454                    .ok();
455                let k = loader
456                    .load_tensor(&format!("{prefix}.self_attn.k_norm.weight"))
457                    .ok();
458                (q, k)
459            } else {
460                (None, None)
461            };
462
463            layers.push(LlamaFamilyLayer {
464                input_ln_w,
465                qkv_proj,
466                q_norm_w,
467                k_norm_w,
468                o_proj,
469                post_ln_w,
470                gate_up_proj,
471                down_proj,
472            });
473        }
474
475        let final_norm_w = loader.load_tensor("model.norm.weight")?;
476
477        let runtime_cfg = cfg.to_runtime();
478        Ok(Self {
479            cfg,
480            runtime_cfg,
481            embed: None,
482            layers,
483            final_norm_w,
484            lm_head: None,
485            rope,
486            scratch,
487            kv_caches: HashMap::new(),
488            kv_free_pool: Vec::new(),
489            graph_warmup: 0,
490            graph_capture_failed: false,
491        })
492    }
493
494    /// Grow scratch buffers if `tokens` exceeds the current sizing.
495    pub(crate) fn ensure_scratch(&mut self, tokens: usize) {
496        if self.scratch.max_tokens < tokens {
497            // Any captured decode graph holds pointers to the old scratch
498            // buffers; those are about to be freed. Invalidate first so the
499            // next decode falls back to eager + re-captures with fresh ptrs.
500            // Critical for multi-turn chat (turn N+1's prefill may grow scratch).
501            {
502                let mut ctx = B::new_context();
503                B::reset_graph(&mut ctx);
504            }
505            self.scratch = LlamaFamilyScratch::alloc(&self.cfg, tokens);
506            self.graph_warmup = 0;
507            self.graph_capture_failed = false;
508        }
509    }
510
511    /// Ensure per-layer KV caches exist for `cache_id`, pre-allocated to
512    /// `max_seq_len` slots per head. Enables the in-place
513    /// `kv_cache_append_head_major` path — no realloc per layer.
514    pub(crate) fn ensure_kv(&mut self, cache_id: &str) {
515        if self.kv_caches.contains_key(cache_id) {
516            return;
517        }
518        let nkv = self.cfg.num_kv_heads;
519        let hd = self.cfg.head_dim;
520        let max = self.cfg.max_seq_len;
521
522        // Try pool first — reused buffers have stable device pointers,
523        // so a captured decode graph can be replayed for this request too.
524        let mut caches = self.kv_free_pool.pop().unwrap_or_else(|| {
525            (0..self.cfg.num_layers)
526                .map(|_| KvCache {
527                    k: B::alloc(nkv * max * hd),
528                    v: B::alloc(nkv * max * hd),
529                    len: 0,
530                    capacity: max,
531                    num_kv_heads: nkv,
532                    head_dim: hd,
533                })
534                .collect()
535        });
536        // Reset logical length; buffers stay. No need to zero the memory —
537        // the kv_cache_append writes new K/V in place, and attention only
538        // reads up to `cache_len`.
539        for c in caches.iter_mut() {
540            c.len = 0;
541        }
542        self.kv_caches.insert(cache_id.to_string(), caches);
543    }
544
545    /// Run one transformer layer. Mutates `residual` in place.
546    ///
547    /// `pos_offset` is the absolute position of token 0 in this batch
548    /// (decode: `pos`; prefill: 0). `tokens` is the batch size.
549    #[allow(clippy::too_many_arguments)]
550    pub(crate) fn forward_layer(
551        &mut self,
552        ctx: &mut B::Context,
553        li: usize,
554        cache_id: &str,
555        residual: &mut B::Buffer,
556        pos_offset: usize,
557        tokens: usize,
558    ) {
559        let layer = &self.layers[li];
560        let cfg = &self.cfg;
561        let h = cfg.hidden_size;
562        let nh = cfg.num_heads;
563        let nkv = cfg.num_kv_heads;
564        let hd = cfg.head_dim;
565        let im = cfg.intermediate_size;
566        let eps = cfg.rms_norm_eps;
567        let q_dim = nh * hd;
568        let kv_dim = nkv * hd;
569
570        // 1. Input RMSNorm
571        B::rms_norm(
572            ctx,
573            residual,
574            &layer.input_ln_w,
575            eps,
576            &mut self.scratch.norm_out,
577            tokens,
578            h,
579        );
580
581        // 2. Fused QKV projection (Linear dispatches to Dense/GPTQ/AWQ/GGUF)
582        layer.qkv_proj.forward(
583            ctx,
584            &self.scratch.norm_out,
585            &mut self.scratch.qkv_out,
586            tokens,
587        );
588
589        // 3. Split fused QKV → token-major Q/K/V
590        B::split_qkv(
591            ctx,
592            &self.scratch.qkv_out,
593            &mut self.scratch.q_buf,
594            &mut self.scratch.k_buf,
595            &mut self.scratch.v_buf,
596            tokens,
597            q_dim,
598            kv_dim,
599        );
600
601        // 4. Fused QK-norm + RoPE + transpose to head-major
602        //    Qwen3: mode=1 (norm + rope). Non-QK-norm variants: mode=2 (rope only).
603        //    V always uses mode=0 (transpose only).
604        let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
605        let dummy = &layer.input_ln_w;
606        let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy);
607        let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy);
608
609        B::qk_norm_rope(
610            ctx,
611            &self.scratch.q_buf,
612            q_norm_w,
613            &self.rope.cos,
614            &self.rope.sin,
615            &mut self.scratch.q_head_major,
616            tokens,
617            nh,
618            hd,
619            pos_offset,
620            eps,
621            qk_mode,
622        );
623        B::qk_norm_rope(
624            ctx,
625            &self.scratch.k_buf,
626            k_norm_w,
627            &self.rope.cos,
628            &self.rope.sin,
629            &mut self.scratch.k_head_major,
630            tokens,
631            nkv,
632            hd,
633            pos_offset,
634            eps,
635            qk_mode,
636        );
637        B::qk_norm_rope(
638            ctx,
639            &self.scratch.v_buf,
640            dummy, // unused in mode 0
641            &self.rope.cos,
642            &self.rope.sin,
643            &mut self.scratch.v_head_major,
644            tokens,
645            nkv,
646            hd,
647            pos_offset,
648            eps,
649            0, // transpose only
650        );
651
652        // 5. Append K/V to pre-allocated head-major cache
653        let caches = self
654            .kv_caches
655            .get_mut(cache_id)
656            .expect("ensure_kv must be called before forward_layer");
657        let cache = &mut caches[li];
658        B::kv_cache_append_head_major(
659            ctx,
660            &mut cache.k,
661            &mut cache.v,
662            cache.len,
663            cache.capacity,
664            &self.scratch.k_head_major,
665            &self.scratch.v_head_major,
666            tokens,
667            nkv,
668            hd,
669        );
670        cache.len += tokens;
671        let kv_len = cache.len;
672        let kv_stride = cache.capacity;
673
674        // 6. Flash attention over strided cache.
675        //    `causal` is always true for decoder-only LLMs — every query must
676        //    mask out future tokens. (The `tokens > 1` heuristic from the old
677        //    path only worked because single-token decode trivially "attends"
678        //    to one position.) Sliding-window models (Mistral v0.1) narrow
679        //    the lower bound via `sliding_window`.
680        let attn_cfg = ferrum_kernels::backend::AttnConfig {
681            num_heads: nh,
682            num_kv_heads: nkv,
683            head_dim: hd,
684            causal: true,
685            scale: 1.0 / (hd as f32).sqrt(),
686            kv_seq_stride: kv_stride,
687            sliding_window: cfg.sliding_window,
688        };
689        B::flash_attention(
690            ctx,
691            &self.scratch.q_head_major,
692            &cache.k,
693            &cache.v,
694            &mut self.scratch.attn_head_major_out,
695            1,
696            tokens,
697            kv_len,
698            pos_offset,
699            &attn_cfg,
700        );
701
702        // 7. Untranspose head-major → token-major for O-proj input
703        B::transpose_head_to_token(
704            ctx,
705            &self.scratch.attn_head_major_out,
706            &mut self.scratch.attn_flat,
707            tokens,
708            nh,
709            hd,
710        );
711
712        // 8. O projection
713        layer.o_proj.forward(
714            ctx,
715            &self.scratch.attn_flat,
716            &mut self.scratch.o_proj_out,
717            tokens,
718        );
719
720        // 9. Fused residual-add + post-attention RMSNorm.
721        //    Writes the new residual back into `residual` and the normed
722        //    value into `norm_out`.
723        B::fused_add_rms_norm(
724            ctx,
725            residual,
726            &self.scratch.o_proj_out,
727            &layer.post_ln_w,
728            eps,
729            &mut self.scratch.norm_out,
730            tokens,
731            h,
732        );
733
734        // 10. Fused gate+up projection
735        layer.gate_up_proj.forward(
736            ctx,
737            &self.scratch.norm_out,
738            &mut self.scratch.gate_up_out,
739            tokens,
740        );
741
742        // 11. SwiGLU: silu(gate) * up
743        B::fused_silu_mul_split(
744            ctx,
745            &self.scratch.gate_up_out,
746            &mut self.scratch.silu_out,
747            tokens,
748            im,
749        );
750
751        // 12. Down projection
752        layer.down_proj.forward(
753            ctx,
754            &self.scratch.silu_out,
755            &mut self.scratch.mlp_out,
756            tokens,
757        );
758
759        // 13. Final residual add
760        B::add_inplace(ctx, residual, &self.scratch.mlp_out, tokens * h);
761    }
762
763    /// Prefill: process `tokens` prompt tokens in a single batch, return
764    /// `[vocab_size]` logits for the last position.
765    pub fn prefill_internal(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
766        let seq_len = tokens.len();
767        assert!(seq_len > 0, "prefill called with empty token list");
768        self.ensure_scratch(seq_len);
769        self.ensure_kv(cache_id);
770
771        let h = self.cfg.hidden_size;
772        let vocab = self.cfg.vocab_size;
773        let mut ctx = B::new_context();
774
775        // Move `residual` out of `scratch` to work around the borrow checker:
776        // `forward_layer` re-borrows `&mut self` to reach `self.layers` /
777        // `self.kv_caches`, which would conflict with an outstanding
778        // `&mut self.scratch.residual`. Use Option::take to move it out
779        // (no placeholder alloc → no transient cuMemFreeAsync that could
780        // corrupt stream pool state after graph ops on Blackwell).
781        let mut residual = self
782            .scratch
783            .residual
784            .take()
785            .expect("scratch residual missing (previous call didn't restore)");
786        let embed = self
787            .embed
788            .as_ref()
789            .expect("prefill_internal called on backbone-only model (no embed)");
790        B::embedding_lookup(&mut ctx, embed, tokens, &mut residual, h);
791
792        for li in 0..self.cfg.num_layers {
793            self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
794        }
795
796        // Take the last token's hidden state: residual[(seq_len-1)*h .. seq_len*h]
797        B::copy_slice(
798            &mut ctx,
799            &residual,
800            (seq_len - 1) * h,
801            &mut self.scratch.last_hidden,
802            0,
803            h,
804        );
805
806        // Final RMSNorm on the last hidden.
807        B::rms_norm(
808            &mut ctx,
809            &self.scratch.last_hidden,
810            &self.final_norm_w,
811            self.cfg.rms_norm_eps,
812            &mut self.scratch.last_normed,
813            1,
814            h,
815        );
816
817        // LM head (m=1 — triggers GEMV on MetalBackend).
818        let lm_head = self
819            .lm_head
820            .as_ref()
821            .expect("prefill_internal called on backbone-only model (no lm_head)");
822        lm_head.forward(
823            &mut ctx,
824            &self.scratch.last_normed,
825            &mut self.scratch.logits,
826            1,
827        );
828
829        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
830        // buffer's CPU pointer without flushing the command buffer, so the
831        // GPU must complete all pending work first or we read stale/random
832        // data. CUDA's to_vec does an internal stream.synchronize, making
833        // the call redundant there (~50µs/step cost), but correctness on
834        // Metal requires the explicit flush here.
835        B::sync(&mut ctx);
836
837        // Restore residual into scratch for reuse on the next call.
838        self.scratch.residual = Some(residual);
839
840        B::to_vec(&self.scratch.logits, vocab)
841    }
842
843    /// Decode: process 1 token at position `pos`, return `[vocab_size]` logits.
844    pub fn decode_internal(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
845        self.ensure_scratch(1);
846        self.ensure_kv(cache_id);
847
848        let h = self.cfg.hidden_size;
849        let vocab = self.cfg.vocab_size;
850
851        // Context creation is cheap (CUDA reuses the process-global stream).
852        // The captured graph lives in a process-global slot, not on ctx.
853        let mut ctx = B::new_context();
854
855        // Graph capture is opt-in via FERRUM_CUDA_GRAPH=1. Replay is currently
856        // single-request-only on Blackwell + CUDA 12.8 (see
857        // docs/phase-e-cuda-status.md). In pure eager mode, we skip the
858        // per-step device-state memcpy_htod trio entirely.
859        const GRAPH_WARMUP: usize = 3;
860        let graph_enabled = std::env::var("FERRUM_CUDA_GRAPH").is_ok();
861
862        if graph_enabled {
863            // Refresh device-side dynamic state (token/pos/kv_len) before
864            // replay — captured graph reads these from device buffers.
865            B::set_decode_state(&mut ctx, token, pos);
866
867            // Fast path: graph replay (if available).
868            match B::replay_last_graph(&mut ctx) {
869                Ok(true) => {
870                    B::sync(&mut ctx);
871                    return B::to_vec(&self.scratch.logits, vocab);
872                }
873                Ok(false) => { /* no graph yet, fall through to eager */ }
874                Err(_) => { /* backend error or unsupported, eager */ }
875            }
876        }
877
878        let should_capture =
879            graph_enabled && !self.graph_capture_failed && self.graph_warmup >= GRAPH_WARMUP;
880
881        if should_capture {
882            B::set_dev_state_mode(&mut ctx, true);
883            if B::begin_graph_capture(&mut ctx).is_err() {
884                self.graph_capture_failed = true;
885                B::set_dev_state_mode(&mut ctx, false);
886            }
887        }
888
889        // Eager forward (records into graph if capture is active).
890        // mem::replace needs a placeholder. B::alloc(0) was our choice but
891        // cuMemAllocFromPoolAsync(stream, 0) can return CUDA_ERROR_INVALID_VALUE
892        // on Blackwell after graph replay corrupts the pool state. Size-1 is
893        // always valid and costs 2 bytes of transient VRAM per decode step.
894        let mut residual = self
895            .scratch
896            .residual
897            .take()
898            .expect("scratch residual missing (previous call didn't restore)");
899        let embed = self
900            .embed
901            .as_ref()
902            .expect("decode_internal called on backbone-only model (no embed)");
903        B::embedding_lookup(&mut ctx, embed, &[token], &mut residual, h);
904
905        for li in 0..self.cfg.num_layers {
906            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
907        }
908
909        B::rms_norm(
910            &mut ctx,
911            &residual,
912            &self.final_norm_w,
913            self.cfg.rms_norm_eps,
914            &mut self.scratch.last_normed,
915            1,
916            h,
917        );
918
919        let lm_head = self
920            .lm_head
921            .as_ref()
922            .expect("decode_internal called on backbone-only model (no lm_head)");
923        lm_head.forward(
924            &mut ctx,
925            &self.scratch.last_normed,
926            &mut self.scratch.logits,
927            1,
928        );
929
930        if should_capture && !self.graph_capture_failed {
931            if B::end_graph_capture(&mut ctx).is_err() {
932                self.graph_capture_failed = true;
933            } else {
934                // Stream capture mode RECORDS ops into the graph without
935                // executing them. scratch.logits still holds the previous
936                // step's value. Replay the just-captured graph once to
937                // actually execute and produce this step's logits. Without
938                // this, the capture step's to_vec returns stale logits,
939                // yielding a 1-token offset in the generated sequence.
940                if B::replay_last_graph(&mut ctx).is_err() {
941                    self.graph_capture_failed = true;
942                }
943            }
944            B::set_dev_state_mode(&mut ctx, false);
945        } else {
946            self.graph_warmup += 1;
947        }
948
949        // Sync ctx before to_vec: on Metal, `to_vec` just reads the shared
950        // buffer's CPU pointer without flushing the command buffer, so the
951        // GPU must complete all pending work first or we read stale/random
952        // data. CUDA's to_vec does an internal stream.synchronize, making
953        // the call redundant there (~50µs/step cost), but correctness on
954        // Metal requires the explicit flush here.
955        B::sync(&mut ctx);
956        self.scratch.residual = Some(residual);
957
958        B::to_vec(&self.scratch.logits, vocab)
959    }
960
961    /// Prefill with pre-computed embeddings instead of token IDs.
962    ///
963    /// Used by models that embed inputs outside the LLM (e.g. Qwen3-TTS
964    /// mixes text-embedding + codec-embedding before feeding the LM).
965    /// Skips `final_norm` + `lm_head`; returns the last position's pre-norm
966    /// hidden state. Caller applies its own output head.
967    ///
968    /// `embeds` is row-major `[seq_len * hidden_size]`, f32.
969    pub fn prefill_from_embeds(
970        &mut self,
971        cache_id: &str,
972        embeds: &[f32],
973        seq_len: usize,
974    ) -> Vec<f32> {
975        let h = self.cfg.hidden_size;
976        assert_eq!(
977            embeds.len(),
978            seq_len * h,
979            "embeds length {} != seq_len * hidden_size {}",
980            embeds.len(),
981            seq_len * h
982        );
983        assert!(seq_len > 0, "prefill_from_embeds called with zero length");
984
985        self.ensure_scratch(seq_len);
986        self.ensure_kv(cache_id);
987
988        let mut ctx = B::new_context();
989        let mut residual = self
990            .scratch
991            .residual
992            .take()
993            .expect("scratch residual missing (previous call didn't restore)");
994
995        // Upload embeds → residual[0 .. seq_len*h].
996        let embed_buf = B::from_slice(embeds);
997        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
998
999        for li in 0..self.cfg.num_layers {
1000            self.forward_layer(&mut ctx, li, cache_id, &mut residual, 0, seq_len);
1001        }
1002
1003        B::copy_slice(
1004            &mut ctx,
1005            &residual,
1006            (seq_len - 1) * h,
1007            &mut self.scratch.last_hidden,
1008            0,
1009            h,
1010        );
1011        B::sync(&mut ctx);
1012        self.scratch.residual = Some(residual);
1013        B::to_vec(&self.scratch.last_hidden, h)
1014    }
1015
1016    /// Decode with a single pre-computed embedding (shape `[hidden]`).
1017    /// Returns the pre-norm hidden state for the position `pos`. Caller
1018    /// applies final norm + its own output head.
1019    pub fn decode_from_embed(&mut self, cache_id: &str, embed: &[f32], pos: u32) -> Vec<f32> {
1020        let h = self.cfg.hidden_size;
1021        assert_eq!(
1022            embed.len(),
1023            h,
1024            "embed length {} != hidden_size {}",
1025            embed.len(),
1026            h
1027        );
1028
1029        self.ensure_scratch(1);
1030        self.ensure_kv(cache_id);
1031
1032        let mut ctx = B::new_context();
1033        let mut residual = self
1034            .scratch
1035            .residual
1036            .take()
1037            .expect("scratch residual missing (previous call didn't restore)");
1038
1039        let embed_buf = B::from_slice(embed);
1040        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1041
1042        for li in 0..self.cfg.num_layers {
1043            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1044        }
1045
1046        B::copy_slice(&mut ctx, &residual, 0, &mut self.scratch.last_hidden, 0, h);
1047        B::sync(&mut ctx);
1048        self.scratch.residual = Some(residual);
1049        B::to_vec(&self.scratch.last_hidden, h)
1050    }
1051
1052    /// Variant of `prefill_from_embeds` that applies `final_norm` to every
1053    /// position and returns the whole `[seq_len * hidden_size]` vector.
1054    /// Accepts `pos_offset` so callers can continue an existing sequence
1055    /// (e.g. Qwen3-TTS voice-clone: one prefill for the role prefix, a
1056    /// follow-up prefill for the reference-audio ICL block, then
1057    /// autoregressive decoding — all against the same KV cache).
1058    ///
1059    /// Used by TTS where `forward_step` in the candle-based wrapper is
1060    /// expected to return **post-norm all-positions** hidden state so
1061    /// `codec_head` can be applied on candle side.
1062    pub fn prefill_all_post_norm(
1063        &mut self,
1064        cache_id: &str,
1065        embeds: &[f32],
1066        seq_len: usize,
1067        pos_offset: usize,
1068    ) -> Vec<f32> {
1069        let h = self.cfg.hidden_size;
1070        assert_eq!(
1071            embeds.len(),
1072            seq_len * h,
1073            "embeds length {} != seq_len * hidden_size {}",
1074            embeds.len(),
1075            seq_len * h
1076        );
1077        assert!(seq_len > 0);
1078
1079        self.ensure_scratch(seq_len);
1080        self.ensure_kv(cache_id);
1081
1082        let mut ctx = B::new_context();
1083        let mut residual = self
1084            .scratch
1085            .residual
1086            .take()
1087            .expect("scratch residual missing (previous call didn't restore)");
1088
1089        let embed_buf = B::from_slice(embeds);
1090        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, seq_len * h);
1091
1092        for li in 0..self.cfg.num_layers {
1093            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos_offset, seq_len);
1094        }
1095
1096        // Apply final_norm over all seq_len positions → scratch.norm_out.
1097        B::rms_norm(
1098            &mut ctx,
1099            &residual,
1100            &self.final_norm_w,
1101            self.cfg.rms_norm_eps,
1102            &mut self.scratch.norm_out,
1103            seq_len,
1104            h,
1105        );
1106        B::sync(&mut ctx);
1107        self.scratch.residual = Some(residual);
1108        B::to_vec(&self.scratch.norm_out, seq_len * h)
1109    }
1110
1111    /// Decode-side companion to `prefill_all_post_norm`. Runs a single-token
1112    /// decode step at `pos`, applies `final_norm`, and returns the post-norm
1113    /// hidden state `[hidden_size]`.
1114    pub fn decode_post_norm_from_embed(
1115        &mut self,
1116        cache_id: &str,
1117        embed: &[f32],
1118        pos: u32,
1119    ) -> Vec<f32> {
1120        let h = self.cfg.hidden_size;
1121        assert_eq!(embed.len(), h);
1122
1123        self.ensure_scratch(1);
1124        self.ensure_kv(cache_id);
1125
1126        let mut ctx = B::new_context();
1127        let mut residual = self
1128            .scratch
1129            .residual
1130            .take()
1131            .expect("scratch residual missing (previous call didn't restore)");
1132
1133        let embed_buf = B::from_slice(embed);
1134        B::copy_slice(&mut ctx, &embed_buf, 0, &mut residual, 0, h);
1135
1136        for li in 0..self.cfg.num_layers {
1137            self.forward_layer(&mut ctx, li, cache_id, &mut residual, pos as usize, 1);
1138        }
1139
1140        B::rms_norm(
1141            &mut ctx,
1142            &residual,
1143            &self.final_norm_w,
1144            self.cfg.rms_norm_eps,
1145            &mut self.scratch.last_normed,
1146            1,
1147            h,
1148        );
1149        B::sync(&mut ctx);
1150        self.scratch.residual = Some(residual);
1151        B::to_vec(&self.scratch.last_normed, h)
1152    }
1153
1154    /// Batched decode: process M concurrent requests at potentially different
1155    /// positions in one forward pass. GEMM-heavy ops (qkv_proj, o_proj,
1156    /// gate_up, down) run with m=M for natural batching; rope + KV append +
1157    /// attention loop per-item (each has its own KV cache at a different
1158    /// kv_len, and potentially different pos).
1159    ///
1160    /// Returns M logit vectors in the same order as `batch`.
1161    pub fn decode_batch_internal(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1162        let m = batch.len();
1163        if m == 0 {
1164            return Vec::new();
1165        }
1166        if m == 1 {
1167            let (cid, tok, pos) = &batch[0];
1168            return vec![self.decode_internal(cid, *tok, *pos)];
1169        }
1170
1171        // Ensure all caches exist and scratch is sized for M tokens.
1172        for (cid, _, _) in batch {
1173            self.ensure_kv(cid);
1174        }
1175        self.ensure_scratch(m);
1176
1177        let h = self.cfg.hidden_size;
1178        let vocab = self.cfg.vocab_size;
1179        let mut ctx = B::new_context();
1180
1181        // 0. Embed all M tokens into residual [M, H]
1182        let tokens: Vec<u32> = batch.iter().map(|(_, t, _)| *t).collect();
1183        let mut residual = self
1184            .scratch
1185            .residual
1186            .take()
1187            .expect("scratch residual missing (previous call didn't restore)");
1188        let embed = self
1189            .embed
1190            .as_ref()
1191            .expect("decode_batch_internal called on backbone-only model (no embed)");
1192        B::embedding_lookup(&mut ctx, embed, &tokens, &mut residual, h);
1193
1194        // 1..num_layers: batched forward for each layer
1195        for li in 0..self.cfg.num_layers {
1196            self.forward_layer_batched_decode(&mut ctx, li, batch, &mut residual, m);
1197        }
1198
1199        // Final RMSNorm on [M, H] → norm_out [M, H]
1200        B::rms_norm(
1201            &mut ctx,
1202            &residual,
1203            &self.final_norm_w,
1204            self.cfg.rms_norm_eps,
1205            &mut self.scratch.norm_out,
1206            m,
1207            h,
1208        );
1209
1210        // LM head with m=M → batch_logits [M, vocab]
1211        let lm_head = self
1212            .lm_head
1213            .as_ref()
1214            .expect("decode_batch_internal called on backbone-only model (no lm_head)");
1215        lm_head.forward(
1216            &mut ctx,
1217            &self.scratch.norm_out,
1218            &mut self.scratch.batch_logits,
1219            m,
1220        );
1221
1222        // Sync before to_vec (Metal: no internal sync on buffer read).
1223        B::sync(&mut ctx);
1224        self.scratch.residual = Some(residual);
1225
1226        // Extract M logit vectors from the flat buffer.
1227        let all = B::to_vec(&self.scratch.batch_logits, m * vocab);
1228        (0..m)
1229            .map(|i| all[i * vocab..(i + 1) * vocab].to_vec())
1230            .collect()
1231    }
1232
1233    /// One transformer layer over M items, GEMMs batched + per-item attention.
1234    fn forward_layer_batched_decode(
1235        &mut self,
1236        ctx: &mut B::Context,
1237        li: usize,
1238        batch: &[(String, u32, u32)],
1239        residual: &mut B::Buffer,
1240        m: usize,
1241    ) {
1242        let cfg = &self.cfg;
1243        let h = cfg.hidden_size;
1244        let nh = cfg.num_heads;
1245        let nkv = cfg.num_kv_heads;
1246        let hd = cfg.head_dim;
1247        let im = cfg.intermediate_size;
1248        let eps = cfg.rms_norm_eps;
1249        let q_dim = nh * hd;
1250        let kv_dim = nkv * hd;
1251
1252        let layer = &self.layers[li];
1253        let qk_mode: i32 = if cfg.has_qk_norm { 1 } else { 2 };
1254        let dummy_w = &layer.input_ln_w;
1255        let q_norm_w = layer.q_norm_w.as_ref().unwrap_or(dummy_w);
1256        let k_norm_w = layer.k_norm_w.as_ref().unwrap_or(dummy_w);
1257
1258        // 1. rms_norm [M, H]  → norm_out
1259        B::rms_norm(
1260            ctx,
1261            residual,
1262            &layer.input_ln_w,
1263            eps,
1264            &mut self.scratch.norm_out,
1265            m,
1266            h,
1267        );
1268
1269        // 2. qkv_proj (GEMM m=M): norm_out [M, H] → qkv_out [M, QKV]
1270        layer
1271            .qkv_proj
1272            .forward(ctx, &self.scratch.norm_out, &mut self.scratch.qkv_out, m);
1273
1274        // 3. split_qkv [M, QKV] → q_buf [M, Q], k_buf [M, KV], v_buf [M, KV]
1275        B::split_qkv(
1276            ctx,
1277            &self.scratch.qkv_out,
1278            &mut self.scratch.q_buf,
1279            &mut self.scratch.k_buf,
1280            &mut self.scratch.v_buf,
1281            m,
1282            q_dim,
1283            kv_dim,
1284        );
1285
1286        // 4-6. Per-item loop for rope + kv_append + attention.
1287        //      Each item has its own cache_id + pos + kv_len.
1288        for (i, (cache_id, _token, pos)) in batch.iter().enumerate() {
1289            let pos_i = *pos as usize;
1290
1291            // Extract item i's Q/K/V from batched buffers.
1292            B::copy_slice(
1293                ctx,
1294                &self.scratch.q_buf,
1295                i * q_dim,
1296                &mut self.scratch.q_single,
1297                0,
1298                q_dim,
1299            );
1300            B::copy_slice(
1301                ctx,
1302                &self.scratch.k_buf,
1303                i * kv_dim,
1304                &mut self.scratch.k_single,
1305                0,
1306                kv_dim,
1307            );
1308            B::copy_slice(
1309                ctx,
1310                &self.scratch.v_buf,
1311                i * kv_dim,
1312                &mut self.scratch.v_single,
1313                0,
1314                kv_dim,
1315            );
1316
1317            // qk_norm_rope with tokens=1, per-item pos.
1318            B::qk_norm_rope(
1319                ctx,
1320                &self.scratch.q_single,
1321                q_norm_w,
1322                &self.rope.cos,
1323                &self.rope.sin,
1324                &mut self.scratch.q_head_major_single,
1325                1,
1326                nh,
1327                hd,
1328                pos_i,
1329                eps,
1330                qk_mode,
1331            );
1332            B::qk_norm_rope(
1333                ctx,
1334                &self.scratch.k_single,
1335                k_norm_w,
1336                &self.rope.cos,
1337                &self.rope.sin,
1338                &mut self.scratch.k_head_major_single,
1339                1,
1340                nkv,
1341                hd,
1342                pos_i,
1343                eps,
1344                qk_mode,
1345            );
1346            B::qk_norm_rope(
1347                ctx,
1348                &self.scratch.v_single,
1349                dummy_w,
1350                &self.rope.cos,
1351                &self.rope.sin,
1352                &mut self.scratch.v_head_major_single,
1353                1,
1354                nkv,
1355                hd,
1356                pos_i,
1357                eps,
1358                0,
1359            );
1360
1361            // KV append + attention for item i's cache.
1362            let caches = self
1363                .kv_caches
1364                .get_mut(cache_id)
1365                .expect("ensure_kv must be called before forward_layer_batched");
1366            let cache = &mut caches[li];
1367            B::kv_cache_append_head_major(
1368                ctx,
1369                &mut cache.k,
1370                &mut cache.v,
1371                cache.len,
1372                cache.capacity,
1373                &self.scratch.k_head_major_single,
1374                &self.scratch.v_head_major_single,
1375                1,
1376                nkv,
1377                hd,
1378            );
1379            cache.len += 1;
1380            let kv_len = cache.len;
1381            let kv_stride = cache.capacity;
1382
1383            let attn_cfg = ferrum_kernels::backend::AttnConfig {
1384                num_heads: nh,
1385                num_kv_heads: nkv,
1386                head_dim: hd,
1387                causal: true,
1388                scale: 1.0 / (hd as f32).sqrt(),
1389                kv_seq_stride: kv_stride,
1390                sliding_window: cfg.sliding_window,
1391            };
1392            B::flash_attention(
1393                ctx,
1394                &self.scratch.q_head_major_single,
1395                &cache.k,
1396                &cache.v,
1397                &mut self.scratch.attn_head_major_single,
1398                1,
1399                1,
1400                kv_len,
1401                pos_i,
1402                &attn_cfg,
1403            );
1404
1405            // Untranspose head-major → token-major (tokens=1 → just contiguous).
1406            B::transpose_head_to_token(
1407                ctx,
1408                &self.scratch.attn_head_major_single,
1409                &mut self.scratch.attn_flat_single,
1410                1,
1411                nh,
1412                hd,
1413            );
1414
1415            // Inject item i's attn output into batched attn_flat [M, Q].
1416            B::copy_slice(
1417                ctx,
1418                &self.scratch.attn_flat_single,
1419                0,
1420                &mut self.scratch.attn_flat,
1421                i * q_dim,
1422                q_dim,
1423            );
1424        }
1425
1426        // 7. o_proj (GEMM m=M): attn_flat [M, Q] → o_proj_out [M, H]
1427        layer.o_proj.forward(
1428            ctx,
1429            &self.scratch.attn_flat,
1430            &mut self.scratch.o_proj_out,
1431            m,
1432        );
1433
1434        // 8. Fused residual add + post-attention RMSNorm.
1435        B::fused_add_rms_norm(
1436            ctx,
1437            residual,
1438            &self.scratch.o_proj_out,
1439            &layer.post_ln_w,
1440            eps,
1441            &mut self.scratch.norm_out,
1442            m,
1443            h,
1444        );
1445
1446        // 9. gate_up_proj (GEMM m=M)
1447        layer.gate_up_proj.forward(
1448            ctx,
1449            &self.scratch.norm_out,
1450            &mut self.scratch.gate_up_out,
1451            m,
1452        );
1453
1454        // 10. SwiGLU
1455        B::fused_silu_mul_split(
1456            ctx,
1457            &self.scratch.gate_up_out,
1458            &mut self.scratch.silu_out,
1459            m,
1460            im,
1461        );
1462
1463        // 11. down_proj (GEMM m=M)
1464        layer
1465            .down_proj
1466            .forward(ctx, &self.scratch.silu_out, &mut self.scratch.mlp_out, m);
1467
1468        // 12. Residual add
1469        B::add_inplace(ctx, residual, &self.scratch.mlp_out, m * h);
1470    }
1471}
1472
1473impl<B: Backend> DecoderOnlyLLM for LlamaFamilyModel<B> {
1474    fn config(&self) -> &LlmRuntimeConfig {
1475        &self.runtime_cfg
1476    }
1477
1478    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
1479        self.prefill_internal(cache_id, tokens)
1480    }
1481
1482    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
1483        self.decode_internal(cache_id, token, pos)
1484    }
1485
1486    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
1487        self.decode_batch_internal(batch)
1488    }
1489
1490    fn release(&mut self, cache_id: &str) {
1491        // Sync + drop graph BEFORE touching cache buffers. The graph was
1492        // actively running replays up to this point; destroying the graph
1493        // while the allocator pool still has in-flight references from the
1494        // graph's kernels corrupts stream state. Sync first to drain, then
1495        // destroy graph, then sync again to ensure cleanup completes.
1496        let mut ctx = B::new_context();
1497        B::sync(&mut ctx);
1498        B::reset_graph(&mut ctx);
1499        B::sync(&mut ctx);
1500        self.graph_warmup = 0;
1501        self.graph_capture_failed = false;
1502
1503        // Return the cache's buffers to the free pool instead of dropping.
1504        // Pointers stay stable for the next request's captured graph.
1505        if let Some(caches) = self.kv_caches.remove(cache_id) {
1506            self.kv_free_pool.push(caches);
1507        }
1508    }
1509
1510    fn reset(&mut self) {
1511        // Hard reset: drop all caches AND the pool, invalidate graph.
1512        let mut ctx = B::new_context();
1513        B::sync(&mut ctx);
1514        B::reset_graph(&mut ctx);
1515        B::sync(&mut ctx);
1516        self.graph_warmup = 0;
1517        self.graph_capture_failed = false;
1518        self.kv_caches.clear();
1519        self.kv_free_pool.clear();
1520    }
1521}
1522
1523fn build_rope_cache<B: Backend>(cfg: &LlamaFamilyConfig) -> RopeCache<B> {
1524    let hd = cfg.head_dim;
1525    let half = hd / 2;
1526    let max = cfg.max_seq_len;
1527    let mut cos = vec![0.0f32; max * half];
1528    let mut sin = vec![0.0f32; max * half];
1529    for pos in 0..max {
1530        for i in 0..half {
1531            let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
1532            let angle = pos as f64 * freq;
1533            cos[pos * half + i] = angle.cos() as f32;
1534            sin[pos * half + i] = angle.sin() as f32;
1535        }
1536    }
1537    RopeCache {
1538        cos: B::from_slice(&cos),
1539        sin: B::from_slice(&sin),
1540    }
1541}