Skip to main content

ferrum_kernels/backend/
traits.rs

1//! Core Backend trait — the single abstraction over CUDA / Metal / CPU.
2
3use ferrum_types::{FerrumError, Result};
4
5pub use super::capabilities::{
6    BackendCollective, BackendGraph, BackendMoeFused, BackendQuantGguf, BackendQuantMarlin,
7};
8pub use super::types::MoeRouting;
9use super::types::{AttnConfig, KvCacheQuant, SrcDtype};
10
/// Maximum decode-graph layer count. Per-layer call sites that share
/// graph-captured host staging arrays use this as the stride between
/// distinct slots. CUDA-only invariant (other backends ignore the
/// `slot` argument). 64 covers every LLM family currently run through
/// the graph path. Llama-3-70B (80 layers) is the known exception: it
/// exceeds this limit, but 70B doesn't run on a single 4090 anyway, so
/// 64 is safe in practice for v0.2.
17pub const MAX_LAYERS_FOR_GRAPH: usize = 64;
18
19// Note: `TransformerConfig` / `AttnType` / `MlpType` / `RopeConfig` used to
20// live here when `ModelRunner` needed a generic model config. They're now
21// per-model (e.g. `Qwen3Config` in `ferrum-models::models::qwen3`) so each
22// model can carry exactly the architecture parameters it cares about.
23// Backend trait stays model-agnostic.
24
25/// The core abstraction over CUDA / Metal / CPU.
26///
27/// Key design: operations take a `&mut Self::Context` which accumulates work.
28///   - **CPU**: Context is `()` — ops execute immediately.
29///   - **Metal**: Context is a `CommandBuffer` — ops encode into it, flushed on `sync()`.
30///   - **CUDA**: Context is a `CudaStream` — ops launch on the stream, synced on `sync()`.
31///
32/// `layer_forward` passes the context through all ops in a layer.
33/// `ModelRunner` calls `sync()` only when it needs results (e.g., reading logits).
34pub trait Backend: Send + Sync + Sized + 'static {
35    type Buffer: Send + Sync;
36
37    /// Execution context that accumulates GPU work.
38    ///   - CPU: `()` (no-op, ops execute inline)
39    ///   - Metal: wraps a CommandBuffer
40    ///   - CUDA: wraps a CudaStream
41    type Context;
42
43    /// GPU-side timer scoped to this backend. See `super::timer` —
44    /// CPU: `Instant`; Metal: sync-wrap; CUDA: `cuEvent`.
45    /// PLAYBOOK § 1.1.
46    type Timer: super::timer::BackendTimer<Self>;
47
48    /// Factory for `Self::Timer` — exists so call sites that have a
49    /// `<B: Backend>` parameter can spawn a timer without importing the
50    /// concrete impl. PLAYBOOK § 1.2.
51    fn make_timer() -> Self::Timer;
52
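    // Historical note: this spot used to document `type GptqStore`, the
    // opaque per-backend GPTQ weight representation:
    //   - CPU: dequantized f32 weights (run as regular GEMM)
    //   - Metal: `()` — unsupported; `gemm_gptq` errors
    // (Kept as plain `//` comments so rustdoc does not attach this text
    // to `new_context` below.)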
56    // Note (Phase 3e/4 + Phase C):
57    // - `type QuantStore` (GGUF k-quant storage) was removed in Phase 3e/4
58    //   — stacked-expert MoE GGUF goes through Box<dyn StackedExpertGgufLinear<Self>>
59    //   returned by `load_quant_experts`.
60    // - `type GptqStore` (Marlin/dequant GPTQ storage) was removed in Phase C
61    //   step 4e — stacked-expert Marlin MoE goes through
62    //   Arc<dyn MarlinExpertStack<Self>> returned by `load_gptq_stacked`,
63    //   and single-tensor GPTQ goes through Box<dyn Linear<Self>> returned
64    //   by `load_gptq`. Adding a new Marlin-capable backend is purely a
65    //   new MarlinExpertStack<NewBackend> impl — no Backend trait edits.
66
67    /// Create a new execution context (begin accumulating work).
68    fn new_context() -> Self::Context;
69
70    /// Flush accumulated work and wait for completion.
71    /// CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.
72    fn sync(ctx: &mut Self::Context);
73
74    // Graph capability moved to the `BackendGraph` supertrait at the end
75    // of this file. CUDA implements its overrides; Metal/CPU inherit
76    // unsupported defaults via empty `impl BackendGraph for X {}` blocks.
77
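    // ── Buffer zeroing + typed alloc / upload / write ───────────────────
    //
    // (This section used to hold the GPTQ (INT4 quantization) two-step
    // API: load once per weight → gemm per forward, with a store holding
    // whatever backend-specific format is fastest so caller code
    // (GptqLinear) stays dtype-agnostic. Those entry points no longer
    // live on this trait — see the Phase C note above.)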
83
84    /// Zero the first `len` elements of a Self::Buffer. CUDA path uses
85    /// cuMemsetD16Async; default returns unsupported.
86    fn zero_buffer(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize) -> Result<()> {
87        Err(FerrumError::unsupported(
88            "zero_buffer not implemented for this backend",
89        ))
90    }
91
92    /// Phase D step 2+3: unified typed allocator. Replaces per-dtype
93    /// `alloc_u32` / `alloc_typed_i32` / etc. The buffer is dtype-
94    /// tagged at the wrapper level (`CudaBuf::U32`, `MetalBuf` with
95    /// `Dtype::U32`, `CpuBuf::U32`), so reads/writes through `.as_<T>()`
96    /// accessors get the correct byte count automatically.
97    fn alloc_typed(dtype: super::Dtype, n: usize) -> Self::Buffer;
98
99    /// Upload typed host data — replaces `from_slice_i32` /
100    /// `from_slice_u32` etc. The host element type `T` carries its
101    /// `Dtype` via the `HostDtype` marker so dispatch in the impl
102    /// is a one-line `match T::DTYPE`.
103    fn from_slice_typed<T: super::HostDtype>(data: &[T]) -> Self::Buffer;
104
105    /// In-place typed write — replaces `write_u32` / `write_i32_into`
106    /// / `write_f32_into`. The buffer must already be dtype-tagged
107    /// matching `T::DTYPE` (typically alloc'd via `alloc_typed` or
108    /// `from_slice_typed`).
109    fn write_typed<T: super::HostDtype>(
110        ctx: &mut Self::Context,
111        dst: &mut Self::Buffer,
112        data: &[T],
113    );
114
115    // ── GEMM ────────────────────────────────────────────────────────────
116
117    fn gemm(
118        ctx: &mut Self::Context,
119        a: &Self::Buffer,
120        b: &Self::Buffer,
121        out: &mut Self::Buffer,
122        m: usize,
123        n: usize,
124        k: usize,
125    );
126
127    // ── Norms ───────────────────────────────────────────────────────────
128
129    fn rms_norm(
130        ctx: &mut Self::Context,
131        x: &Self::Buffer,
132        w: &Self::Buffer,
133        eps: f32,
134        out: &mut Self::Buffer,
135        tokens: usize,
136        dim: usize,
137    );
138
139    fn fused_add_rms_norm(
140        ctx: &mut Self::Context,
141        residual: &mut Self::Buffer,
142        x: &Self::Buffer,
143        w: &Self::Buffer,
144        eps: f32,
145        out: &mut Self::Buffer,
146        tokens: usize,
147        dim: usize,
148    );
149
150    // ── Attention ───────────────────────────────────────────────────────
151
152    fn flash_attention(
153        ctx: &mut Self::Context,
154        q: &Self::Buffer,
155        k: &Self::Buffer,
156        v: &Self::Buffer,
157        out: &mut Self::Buffer,
158        batch: usize,
159        q_len: usize,
160        kv_len: usize,
161        pos_offset: usize,
162        cfg: &AttnConfig,
163    );
164
165    /// Multi-Head Latent Attention — DeepSeek V2 / V3's compressed-KV
166    /// attention variant. Extension point only; no backend implements it
167    /// yet. DeepSeek V3 landing in Phase D/E will fill this in.
168    ///
169    /// `q`: full Q `[batch, num_heads, q_len, head_dim]`
170    /// `kv_compressed`: latent KV `[batch, kv_len, kv_lora_rank]`
171    /// `kv_rope`: per-position rope-applied key heads `[batch, kv_len, qk_rope_head_dim]`
172    /// `out`: `[batch, num_heads, q_len, head_dim]`
173    #[allow(clippy::too_many_arguments)]
174    fn mla_attention(
175        _ctx: &mut Self::Context,
176        _q: &Self::Buffer,
177        _kv_compressed: &Self::Buffer,
178        _kv_rope: &Self::Buffer,
179        _out: &mut Self::Buffer,
180        _batch: usize,
181        _q_len: usize,
182        _kv_len: usize,
183        _pos_offset: usize,
184        _cfg: &AttnConfig,
185        _kv_lora_rank: usize,
186        _qk_rope_head_dim: usize,
187    ) -> Result<()> {
188        Err(FerrumError::unsupported(
189            "mla_attention not implemented for this backend; required by \
190             DeepSeek V2/V3 (Phase D/E)",
191        ))
192    }
193
194    // ── Element-wise ────────────────────────────────────────────────────
195    //
196    // Models use `add_inplace` for residual updates and `copy_slice` for the
197    // row-extraction step in prefill. Offset-free copy / non-inplace add are
198    // not needed by the current Model-as-Code path; they can return later if
199    // a model actually requires them.
200
201    /// Copy `len` floats from `src[src_offset..]` to `dst[dst_offset..]`.
202    ///
203    /// Needed for Qwen3Model::prefill to pluck the last token's hidden state
204    /// out of `residual[seq_len, h]` without round-tripping through host RAM.
    /// There is no separate offset-free `copy` on this trait; pass zero
    /// for both offsets to copy from the start of each buffer.
207    fn copy_slice(
208        ctx: &mut Self::Context,
209        src: &Self::Buffer,
210        src_offset: usize,
211        dst: &mut Self::Buffer,
212        dst_offset: usize,
213        len: usize,
214    );
215
216    // ── Embedding ───────────────────────────────────────────────────────
217
218    fn embedding_lookup(
219        ctx: &mut Self::Context,
220        table: &Self::Buffer,
221        ids: &[u32],
222        out: &mut Self::Buffer,
223        dim: usize,
224    );
225
226    /// Device-buffer variant of `embedding_lookup` for graph-capturable
227    /// MoE routing — the gather step before phase-1 GEMM in
228    /// `moe_forward_bucketed`. The host-slice `embedding_lookup` does
229    /// `clone_htod(ids)` internally, which records stale host pointers
230    /// under CUDA Graph capture replay.
231    ///
232    /// `ids: &Self::Buffer` must be a device I32 buffer of `batch`
233    /// elements (e.g. `Qwen3MoeScratch::route_packed_idx_dev`).
234    /// `batch` is passed explicitly since a typed CudaBuf carries
235    /// its element count but the caller often wants a partial gather.
236    ///
237    /// Default impl: round-trip via `to_vec` + dispatch the host-slice
238    /// variant. CUDA overrides.
239    fn embedding_lookup_dev(
240        ctx: &mut Self::Context,
241        table: &Self::Buffer,
242        ids: &Self::Buffer,
243        out: &mut Self::Buffer,
244        batch: usize,
245        dim: usize,
246    ) {
247        // Default: round-trip. CUDA overrides with a direct device-arg
248        // kernel launch (no clone_htod).
249        let ids_host_f32 = Self::to_vec(ids, batch);
250        let ids_host_u32: Vec<u32> = ids_host_f32.iter().map(|x| x.to_bits()).collect();
251        Self::embedding_lookup(ctx, table, &ids_host_u32, out, dim);
252    }
253
254    // ── Transformer-specific fused ops ─────────────────────────────────
255    // These avoid CPU round-trips for data layout transformations.
256
257    /// Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers.
258    /// Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]
259    fn split_qkv(
260        ctx: &mut Self::Context,
261        qkv: &Self::Buffer,
262        q: &mut Self::Buffer,
263        k: &mut Self::Buffer,
264        v: &mut Self::Buffer,
265        tokens: usize,
266        q_dim: usize,
267        kv_dim: usize,
268    );
269
270    /// Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im],
271    /// then compute SiLU(gate) * up → out [tokens, im].
272    fn fused_silu_mul_split(
273        ctx: &mut Self::Context,
274        gate_up: &Self::Buffer,
275        out: &mut Self::Buffer,
276        tokens: usize,
277        im: usize,
278    );
279
280    /// Fused QK-norm + RoPE + transpose-to-head-major.
281    ///
282    /// `mode` selects the operation:
283    ///   0 = transpose only (typical for V, which needs no norm and no RoPE)
284    ///   1 = per-head RMS norm + RoPE + transpose  (Q/K with QK-norm, Qwen3)
285    ///   2 = RoPE + transpose                       (Q/K without QK-norm, Llama/Mistral)
286    ///
287    /// input:   `[tokens, heads, head_dim]`  (token-major, output of split_qkv)
288    /// output:  `[heads, tokens, head_dim]`  (head-major, ready for flash_attn / kv_cache_append)
289    ///
290    /// `pos_offset` is the position of token 0 (decode uses current seq len;
291    /// prefill uses 0). Within the batch, positions are taken as `pos_offset + i`.
292    ///
293    /// This is the primary attention-input preparation op. Backends that have a
294    /// fused kernel (Metal's `qk_norm_rope_transpose_f32`) will be dramatically
295    /// faster than composing norm + rope + transpose separately; the CPU
296    /// fallback lowers to the individual ops.
297    #[allow(clippy::too_many_arguments)]
298    fn qk_norm_rope(
299        ctx: &mut Self::Context,
300        input: &Self::Buffer,
301        norm_w: &Self::Buffer,
302        cos: &Self::Buffer,
303        sin: &Self::Buffer,
304        output: &mut Self::Buffer,
305        tokens: usize,
306        heads: usize,
307        head_dim: usize,
308        pos_offset: usize,
309        eps: f32,
310        mode: i32,
311    );
312
313    /// Batched kv_cache_append across M caches in one launch. Each item
314    /// writes its (head-major) K-or-V row into its own cache at offset
315    /// read from `cache_lens[i]`. Replaces M sequential
316    /// `kv_cache_append_head_major` calls with a single dispatch.
317    ///
318    /// `new_data` layout: `[m, nkv, hd]` item-major (each item's slice
319    /// is contiguous, identical to the `k/v_normed_batched` produced by
320    /// `qk_norm_rope_batched_per_item`).
321    /// `caches`: per-cache `[nkv, capacity, hd]` head-major.
322    /// `cache_lens`: device buffer (u32 storage, length ≥ m). Caller
323    /// fills via `B::write_u32_into` BEFORE the call. Required for
324    /// CUDA-graph capture: the kernel reads from this stable device
325    /// buffer, so a captured graph can be replayed with new lens by
326    /// just rewriting the buffer between launches.
327    fn kv_cache_append_batched_per_cache(
328        _ctx: &mut Self::Context,
329        _caches: &[&Self::Buffer],
330        _new_data: &Self::Buffer,
331        _cache_lens: &Self::Buffer,
332        _capacity: usize,
333        _m: usize,
334        _nkv: usize,
335        _hd: usize,
336        _slot: usize,
337    ) -> Result<()> {
338        Err(FerrumError::unsupported(
339            "kv_cache_append_batched_per_cache not implemented for this backend",
340        ))
341    }
342
343    /// Batched flash_attention across M decode caches in one launch.
344    /// Replaces the per-item `flash_attention(q_len=1, ...)` × M
345    /// loop in the non-paged batched-decode path.
346    ///
347    /// API takes Vec<&Buffer> for the per-cache K/V buffers (each
348    /// `[nkv, capacity, hd]` head-major) plus host-side `kv_lens`.
349    /// Backends that implement it must extract per-cache device
350    /// pointers, build the device arrays the kernel needs, and launch
351    /// one kernel covering all M items.
352    ///
353    /// `q` layout: [m, nq, hd] item-major (matches the
354    /// `qk_norm_rope_batched_per_item` output for q_len=1).
355    /// `out` layout: [m, nq, hd] item-major — written directly into
356    /// the caller's batched attn_out buffer, no per-item copy needed.
357    ///
358    /// CUDA-only for now (kernel `batched_decode_attention` exists in
359    /// `kernels/batched_decode_attention.cu`).
360    /// `kv_lens`: device buffer (u32 storage, length ≥ m) — same
361    /// design as `kv_cache_append_batched_per_cache::cache_lens`.
362    fn flash_attention_batched_per_cache(
363        _ctx: &mut Self::Context,
364        _q: &Self::Buffer,
365        _k_caches: &[&Self::Buffer],
366        _v_caches: &[&Self::Buffer],
367        _kv_lens: &Self::Buffer,
368        _out: &mut Self::Buffer,
369        _nq: usize,
370        _nkv: usize,
371        _hd: usize,
372        _scale: f32,
373        _max_valid_kv: usize,
374        _capacity: usize,
375        _slot: usize,
376    ) -> Result<()> {
377        Err(FerrumError::unsupported(
378            "flash_attention_batched_per_cache not implemented for this backend",
379        ))
380    }
381
382    /// Batched per-item-position variant of `qk_norm_rope` for the
383    /// non-paged batched-decode path. Each of the `m` items has its own
384    /// absolute RoPE position (read from a device i32 buffer of length
385    /// `m`). Layout is item-major in *both* input and output:
386    ///
387    ///   input  [m, heads, head_dim]
388    ///   output [m, heads, head_dim]   (no head-major transpose)
389    ///
390    /// Item-major output keeps the per-item flash_attention slice
391    /// contiguous (`output[i * heads * head_dim ..]` is item i's whole
392    /// Q tensor in head-major-equivalent layout for q_len=1).
393    ///
394    /// Replaces the M sequential single-item launches in the existing
395    /// `forward_layer_batched_decode` path with one batched dispatch.
396    /// CUDA-only for now; other backends fall through to the default
397    /// `unsupported` and the caller falls back to the per-item loop.
398    fn qk_norm_rope_batched_per_item(
399        _ctx: &mut Self::Context,
400        _input: &Self::Buffer,
401        _norm_w: &Self::Buffer,
402        _cos: &Self::Buffer,
403        _sin: &Self::Buffer,
404        _output: &mut Self::Buffer,
405        _positions: &Self::Buffer,
406        _m: usize,
407        _heads: usize,
408        _head_dim: usize,
409        _eps: f32,
410        _mode: i32,
411    ) -> Result<()> {
412        Err(FerrumError::unsupported(
413            "qk_norm_rope_batched_per_item not implemented for this backend",
414        ))
415    }
416
417    /// Fused split-QKV + QK-norm + RoPE + head-major transpose.
418    ///
419    /// Single-dispatch replacement for the (`split_qkv` → 3× `qk_norm_rope`)
420    /// chain on the decode-attention prelude. Reads the linear-layer
421    /// fused-QKV output once and writes head-major Q/K/V directly into
422    /// attention scratch.
423    ///
424    /// `qkv` layout: `[tokens, q_heads*hd + 2*kv_heads*hd]`.
425    /// `q_out`: `[q_heads, tokens, hd]`. `k_out`/`v_out`: `[kv_heads, tokens, hd]`.
426    /// `qk_mode`: 1 = norm + RoPE for Q/K (Qwen3 with QK-norm),
427    ///            2 = RoPE only for Q/K (no QK-norm; Llama-style).
428    /// V always falls through to transpose-only.
429    ///
430    /// Default returns Unsupported. Backends that implement it are
431    /// expected to be dramatically faster than the four-dispatch chain.
432    #[allow(clippy::too_many_arguments)]
433    fn split_qkv_norm_rope(
434        _ctx: &mut Self::Context,
435        _qkv: &Self::Buffer,
436        _q_norm_w: &Self::Buffer,
437        _k_norm_w: &Self::Buffer,
438        _cos: &Self::Buffer,
439        _sin: &Self::Buffer,
440        _q_out: &mut Self::Buffer,
441        _k_out: &mut Self::Buffer,
442        _v_out: &mut Self::Buffer,
443        _tokens: usize,
444        _q_heads: usize,
445        _kv_heads: usize,
446        _head_dim: usize,
447        _pos_offset: usize,
448        _eps: f32,
449        _qk_mode: i32,
450    ) -> Result<()> {
451        Err(FerrumError::unsupported(
452            "split_qkv_norm_rope not implemented for this backend",
453        ))
454    }
455
456    /// Variant of [`Backend::split_qkv_norm_rope`] that writes the new
457    /// K and V directly into pre-allocated head-major KV cache buffers
458    /// at slot `[kv_heads, cache_len .. cache_len + tokens, hd]`.
459    /// Eliminates the trailing `kv_cache_append_head_major` dispatch on
460    /// the decode hot path. Q still lands in per-token head-major
461    /// scratch (flash-attention reads it as the query).
462    ///
463    /// Default returns Unsupported. Backends without the fused kernel
464    /// can keep using `split_qkv_norm_rope` + `kv_cache_append_head_major`.
465    #[allow(clippy::too_many_arguments)]
466    fn split_qkv_norm_rope_into_cache(
467        _ctx: &mut Self::Context,
468        _qkv: &Self::Buffer,
469        _q_norm_w: &Self::Buffer,
470        _k_norm_w: &Self::Buffer,
471        _cos: &Self::Buffer,
472        _sin: &Self::Buffer,
473        _q_out: &mut Self::Buffer,
474        _cache_k: &mut Self::Buffer,
475        _cache_v: &mut Self::Buffer,
476        _tokens: usize,
477        _q_heads: usize,
478        _kv_heads: usize,
479        _head_dim: usize,
480        _pos_offset: usize,
481        _eps: f32,
482        _qk_mode: i32,
483        _cache_len: usize,
484        _cache_capacity: usize,
485    ) -> Result<()> {
486        Err(FerrumError::unsupported(
487            "split_qkv_norm_rope_into_cache not implemented for this backend",
488        ))
489    }
490
491    // Phase D step 2: alloc_u32 / write_u32 deleted. Callers use the
492    // unified `alloc_typed(Dtype::U32, n)` + `write_typed(&[u32])` API
493    // declared above.
494
495    /// Append new K/V into a pre-allocated head-major cache buffer.
496    ///
497    /// `cache_k` / `cache_v`: `[nkv, capacity, hd]` (head-major, pre-allocated)
498    /// `new_k_head_major` / `new_v_head_major`: `[nkv, new_tokens, hd]`
499    ///   — produced directly by `qk_norm_rope`, no extra transpose needed.
500    ///
501    /// In-place append at slot `[nkv, cache_len..cache_len+new_tokens, hd]`.
502    /// Caller owns `cache_len` bookkeeping.
503    #[allow(clippy::too_many_arguments)]
504    fn kv_cache_append_head_major(
505        ctx: &mut Self::Context,
506        cache_k: &mut Self::Buffer,
507        cache_v: &mut Self::Buffer,
508        cache_len: usize,
509        cache_capacity: usize,
510        new_k_head_major: &Self::Buffer,
511        new_v_head_major: &Self::Buffer,
512        new_tokens: usize,
513        nkv: usize,
514        hd: usize,
515    );
516
517    /// Transpose [heads, tokens, dim] → [tokens, heads, dim].
518    /// Called after `flash_attention` to restore token-major layout for O-proj.
519    fn transpose_head_to_token(
520        ctx: &mut Self::Context,
521        src: &Self::Buffer,
522        dst: &mut Self::Buffer,
523        tokens: usize,
524        heads: usize,
525        dim: usize,
526    );
527
    /// Inverse of `transpose_head_to_token`: [tokens, heads, dim] →
    /// [heads, tokens, dim]. Used by the CUDA `paged_decode_attention`
    /// wrapper to convert `paged_varlen_attention`'s token-major output
    /// back to the head-major layout that Qwen3MoeModel expects.
    ///
    /// # Panics
    ///
    /// The default implementation panics unconditionally. Unlike most
    /// optional ops on this trait it has no `Result` return to report
    /// "unsupported" through; backends without a paged-KV CUDA path
    /// don't hit this code.
    fn transpose_token_to_head(
        _ctx: &mut Self::Context,
        _src: &Self::Buffer,
        _dst: &mut Self::Buffer,
        _tokens: usize,
        _heads: usize,
        _dim: usize,
    ) {
        panic!("transpose_token_to_head not implemented for this backend");
    }
544
545    /// residual[i] += x[i] (in-place)
546    fn add_inplace(
547        ctx: &mut Self::Context,
548        residual: &mut Self::Buffer,
549        x: &Self::Buffer,
550        len: usize,
551    );
552
553    /// `dst[i] += scale * src[i]` — scalar-broadcast scaled add, in place.
554    ///
555    /// MoE per-token combine writes `out[b] += weight_k * expert_k(x[b])`
556    /// for each top-K expert; this primitive is the per-call accumulate.
557    /// Backends without a dedicated kernel can fall back to the default
558    /// implementation, which round-trips through host memory — correct,
559    /// but slow on a hot path. Override on any backend you actually
560    /// dispatch MoE on.
561    fn scaled_add_inplace(
562        _ctx: &mut Self::Context,
563        dst: &mut Self::Buffer,
564        src: &Self::Buffer,
565        scale: f32,
566        len: usize,
567    ) {
568        let mut dst_v = Self::to_vec(dst, len);
569        let src_v = Self::to_vec(src, len);
570        for i in 0..len {
571            dst_v[i] += scale * src_v[i];
572        }
573        // Move the new buffer into the slot pointed to by `dst`. Safe
574        // because `Self::Buffer: Send + Sync` and the old buffer is
575        // dropped here when overwritten.
576        *dst = Self::from_slice(&dst_v);
577    }
578
    /// Strided variant of [`Backend::fused_silu_mul_split`] for the
    /// bucketed MoE path: reads `gate_up` rows starting at
    /// `in_row_offset`, writes `out` rows starting at `out_row_offset`.
    ///
    /// # Panics
    ///
    /// The default implementation panics unconditionally via
    /// `unimplemented!` — there is no host fallback, so a backend must
    /// override this before the bucketed MoE path can call it.
    #[allow(clippy::too_many_arguments)]
    fn fused_silu_mul_split_strided(
        _ctx: &mut Self::Context,
        _gate_up: &Self::Buffer,
        _in_row_offset: usize,
        _out: &mut Self::Buffer,
        _out_row_offset: usize,
        _tokens: usize,
        _intermediate: usize,
    ) {
        unimplemented!("fused_silu_mul_split_strided default impl missing");
    }
594
595    /// Broadcast bias add: `data[r, c] += bias[c]` for every row.
596    /// Required by Bert / Clip / Whisper whose linear projections carry a bias.
597    fn add_bias(
598        ctx: &mut Self::Context,
599        data: &mut Self::Buffer,
600        bias: &Self::Buffer,
601        rows: usize,
602        cols: usize,
603    );
604
605    /// Full LayerNorm (mean + variance normalisation + affine), distinct from
606    /// the `rms_norm` used by Llama-family decoders.
607    ///   `out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]`
608    /// Where `mean` and `var` are reduced over the last dim (cols).
609    #[allow(clippy::too_many_arguments)]
610    fn layer_norm(
611        ctx: &mut Self::Context,
612        x: &Self::Buffer,
613        gamma: &Self::Buffer,
614        beta: &Self::Buffer,
615        eps: f32,
616        out: &mut Self::Buffer,
617        tokens: usize,
618        dim: usize,
619    );
620
621    /// Element-wise GELU activation (erf-based, matches PyTorch default).
622    fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize);
623
624    // ── Buffer management (context-free) ────────────────────────────────
625
626    fn alloc(len: usize) -> Self::Buffer;
627    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;
628    fn from_slice(data: &[f32]) -> Self::Buffer;
629
630    /// Greedy-decode fast path: GPU argmax over each row of a
631    /// `[m, n]` FP16 logits buffer, returning the m token indices on the
632    /// host. Saves `m × n × 2` bytes of D2H per call (e.g. 19.5 MB at
633    /// c=32, vocab=152064) and the host-side argmax scan (~150 µs × m).
634    ///
635    /// Default impl falls back to the slow path: full `to_vec` + host
636    /// argmax. CUDA overrides with a native kernel + tiny D2H (m × 4 B).
637    /// Backends that don't override pay the same cost as
638    /// `to_vec` + host argmax, so callers can call this unconditionally.
639    fn argmax_rows_f16(
640        _ctx: &mut Self::Context,
641        logits: &Self::Buffer,
642        m: usize,
643        n: usize,
644    ) -> Result<Vec<u32>> {
645        let host = Self::to_vec(logits, m * n);
646        let mut out = Vec::with_capacity(m);
647        for row in 0..m {
648            let slice = &host[row * n..(row + 1) * n];
649            let mut max_idx = 0usize;
650            let mut max_val = f32::NEG_INFINITY;
651            for (i, &v) in slice.iter().enumerate() {
652                if v > max_val {
653                    max_val = v;
654                    max_idx = i;
655                }
656            }
657            out.push(max_idx as u32);
658        }
659        Ok(out)
660    }
661
662    /// Load a weight tensor straight from its on-disk byte representation,
663    /// letting the backend pick its preferred storage dtype.
664    ///
665    /// Default impl upcasts bf16/f16 to f32 via an intermediate Vec, matching
666    /// pre-existing loader behaviour. Backends override this to go straight
667    /// from raw bytes into a native half-precision buffer (e.g. Metal with
668    /// `FERRUM_METAL_DTYPE=f16`), avoiding the transient 2× RAM spike.
669    fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer {
670        let data = src_dtype.to_f32_vec(raw);
671        Self::from_slice(&data)
672    }
673
    // (The Phase A3 unified `gemm_quant(QuantWeights, QuantKind)` stub
    // that used to live here was superseded by a `load_quant` /
    // `gemm_quant(QuantStore)` pair, which hid the per-kind buffer layout
    // so callers didn't have to construct a per-kind
    // `QuantWeights<'_, Self>` packet. `QuantStore` itself has since been
    // removed from this trait — see the Phase 3e/4 note near the top.)
679}
680
681// ════════════════════════════════════════════════════════════════════════
682// BackendPagedKv capability (vLLM-style paged KV cache + paged attention)
683// ════════════════════════════════════════════════════════════════════════
684//
685// Paged KV pool with block-table indirection, plus the paged attention
686// kernel variants that read through that indirection. CUDA + Metal both
687// implement the real kernels; CPU `impl BackendPagedKv for CpuBackend {}`
688// inherits unsupported defaults.
689
690/// Capability-trait for backends that support paged KV cache + paged attention.
691pub trait BackendPagedKv: Backend {
692    /// Whether this backend has a paged-KV decode path
693    /// (`paged_decode_attention` etc.). Currently true for Metal, false
694    /// for CPU. Used to decide the default of `FERRUM_METAL_PAGED_KV` —
695    /// the `serve` path should opt in automatically when supported so
696    /// users get the bench-quality concurrent-decode numbers without
697    /// having to learn the flag.
698    fn supports_paged_kv() -> bool {
699        false
700    }
701    /// Pre-populate the per-slot device-pointer scratch arrays used by
702    /// the batched kernels (`kv_cache_append_batched_per_cache` and
703    /// `flash_attention_batched_per_cache`). Required by the CUDA-graph
704    /// capture path: the captured graph contains only kernel launches
705    /// (no captured `memcpy_htod`), so the device scratch must be fresh
706    /// when the graph replays.
707    ///
708    /// Caller passes flat layer-major slices: `k_caches[li * m + i]` and
709    /// `v_caches[li * m + i]`. Backend extracts each cache's device
710    /// pointer and writes into its corresponding slot in the device
711    /// scratch via SYNCHRONOUS memcpy (not captured by stream capture).
712    ///
713    /// CUDA-only; other backends fall through to the default
714    /// `unsupported` and the caller skips the population call.
715    fn populate_batched_pointers(
716        _ctx: &mut Self::Context,
717        _k_caches: &[&Self::Buffer],
718        _v_caches: &[&Self::Buffer],
719        _num_layers: usize,
720        _m: usize,
721    ) -> Result<()> {
722        Err(FerrumError::unsupported(
723            "populate_batched_pointers not implemented for this backend",
724        ))
725    }
726    /// Paged-KV variant of [`Self::split_qkv_norm_rope_into_cache`].
727    ///
728    /// Same fused split + qk-norm + RoPE, but K/V are written into a
729    /// paged pool `[num_blocks, kv_heads, block_size, head_dim]`
730    /// indexed via `block_table[logical_block]` → physical_block.
731    /// Q still goes to head-major scratch.
732    ///
733    /// Default returns Unsupported. Backends that lack a paged kernel
734    /// keep using the contiguous variant.
735    /// `qkv_byte_offset` / `q_out_byte_offset` let the caller pass a
736    /// slice of a larger batched buffer (used by the multi-seq paged
737    /// path in `decode_batch_internal`). For single-seq dispatch they
738    /// should be 0.
739    #[allow(clippy::too_many_arguments)]
740    fn split_qkv_norm_rope_into_paged_cache(
741        _ctx: &mut Self::Context,
742        _qkv: &Self::Buffer,
743        _qkv_byte_offset: u64,
744        _q_norm_w: &Self::Buffer,
745        _k_norm_w: &Self::Buffer,
746        _cos: &Self::Buffer,
747        _sin: &Self::Buffer,
748        _q_out: &mut Self::Buffer,
749        _q_out_byte_offset: u64,
750        _cache_k: &mut Self::Buffer,
751        _cache_v: &mut Self::Buffer,
752        _block_table: &Self::Buffer,
753        _tokens: usize,
754        _q_heads: usize,
755        _kv_heads: usize,
756        _head_dim: usize,
757        _pos_offset: usize,
758        _eps: f32,
759        _qk_mode: i32,
760        _cache_len: usize,
761        _block_size: usize,
762        _max_num_blocks_per_seq: usize,
763    ) -> Result<()> {
764        Err(FerrumError::unsupported(
765            "split_qkv_norm_rope_into_paged_cache not implemented for this backend",
766        ))
767    }
768    /// Paged-KV variant of [`Self::flash_attention`].
769    ///
770    /// Decode (`q_len == 1`):
771    ///   `q`/`out`: `[num_seqs, num_heads, head_dim]` (token-major)
772    ///
773    /// Causal prefill (`q_len > 1`, single seq):
774    ///   `q`/`out`: `[num_heads, q_len, head_dim]` (head-major — the
775    ///              layout produced by `split_qkv_norm_rope_into_paged_cache`)
776    ///   The kernel applies a per-q-token causal mask using
777    ///   `context_lens[seq]` as the FINAL kv_len (= `pos_offset + q_len`):
778    ///   token i sees positions `[0, context_lens - q_len + 1 + i)`.
779    ///
780    /// Common to both:
781    ///   `k_pool`/`v_pool`: `[num_blocks, num_kv_heads, block_size, head_dim]`
782    ///   `block_tables`: `[num_seqs, max_num_blocks_per_seq]` u32
783    ///   `context_lens`: `[num_seqs]` u32
784    ///
785    /// Backends without a paged kernel return Unsupported; callers are
786    /// expected to fall back to contiguous KV.
787    #[allow(clippy::too_many_arguments)]
788    fn paged_decode_attention(
789        _ctx: &mut Self::Context,
790        _q: &Self::Buffer,
791        _k_pool: &Self::Buffer,
792        _v_pool: &Self::Buffer,
793        _out: &mut Self::Buffer,
794        _block_tables: &Self::Buffer,
795        _context_lens: &Self::Buffer,
796        _num_seqs: usize,
797        _num_heads: usize,
798        _num_kv_heads: usize,
799        _head_dim: usize,
800        _block_size: usize,
801        _max_num_blocks_per_seq: usize,
802        _q_len: usize,
803    ) -> Result<()> {
804        Err(FerrumError::unsupported(
805            "paged_decode_attention not implemented for this backend",
806        ))
807    }
808    /// Capability: does this backend implement
809    /// `split_qkv_norm_rope_into_paged_cache_varlen` and
810    /// `paged_varlen_attention`? Required by the unified mixed-batch
811    /// forward path used by `LlamaFamilyModel::unified_forward`. Default
812    /// false; backends that ship the varlen kernels override.
813    fn supports_varlen_qkv() -> bool {
814        false
815    }
816    /// Varlen variant of [`Self::split_qkv_norm_rope_into_paged_cache`].
817    ///
818    /// Single launch covering ALL sequences in the batch. Reads
819    /// `pos_offsets[seq]`, `cu_seqlens_q[seq]`, and the per-seq
820    /// block_table from device buffers — graph-capturable (the per-iter
821    /// state is in buffers, not kernel scalars). Replaces the per-item
822    /// dispatch loop in `unified_forward_layer` with one call.
823    ///
824    /// Layouts:
825    /// - `qkv`: `[m_total, q_dim + 2 * kv_dim]` token-major
826    /// - `q_out`: `[m_total, q_heads, head_dim]` token-major (matches
827    ///   what `paged_varlen_attention` reads)
828    /// - `cache_k` / `cache_v`: paged pool same as `paged_varlen_attention`
829    /// - `cu_seqlens_q`: `[num_seqs + 1]` u32 prefix sum
830    /// - `pos_offsets`: `[num_seqs]` u32, starting kv_pos per seq
831    /// - `block_tables`: `[num_seqs, max_blocks_per_seq]` i32 stacked
832    #[allow(clippy::too_many_arguments)]
833    fn split_qkv_norm_rope_into_paged_cache_varlen(
834        _ctx: &mut Self::Context,
835        _qkv: &Self::Buffer,
836        _q_norm_w: &Self::Buffer,
837        _k_norm_w: &Self::Buffer,
838        _cos: &Self::Buffer,
839        _sin: &Self::Buffer,
840        _q_out: &mut Self::Buffer,
841        _cache_k: &mut Self::Buffer,
842        _cache_v: &mut Self::Buffer,
843        _cu_seqlens_q: &Self::Buffer,
844        _pos_offsets: &Self::Buffer,
845        _block_tables: &Self::Buffer,
846        _num_seqs: usize,
847        _m_total: usize,
848        _q_heads: usize,
849        _kv_heads: usize,
850        _head_dim: usize,
851        _eps: f32,
852        _qk_mode: i32,
853        _block_size: usize,
854        _max_blocks_per_seq: usize,
855    ) -> Result<()> {
856        Err(FerrumError::unsupported(
857            "split_qkv_norm_rope_into_paged_cache_varlen not implemented for this backend",
858        ))
859    }
860    /// Variable-length paged attention with GQA + causal mask.
861    ///
862    /// Supports a unified mixed batch where each sequence contributes
863    /// 1 (decode) or N (prefill chunk) query tokens — the workhorse for
864    /// chunked-prefill. See `kernels/paged_varlen_attention.cu` for the
865    /// kernel itself.
866    ///
867    /// Layouts:
868    /// - `q` / `out`: `[total_q_tokens, num_heads, head_dim]` (token-
869    ///   major, FP16). `total_q_tokens` = `cu_seqlens_q[num_seqs]`.
870    /// - `k_pool` / `v_pool`: paged block pool, layout matches
871    ///   `paged_decode_attention`.
872    /// - `cu_seqlens_q`: `[num_seqs + 1]` u32 prefix sum, with
873    ///   `cu_seqlens_q[0] = 0` and `cu_seqlens_q[num_seqs] = total_q_tokens`.
874    /// - `pos_offsets`: `[num_seqs]` u32, the starting absolute KV
875    ///   position of each seq's first q token (= prior `kv_len`).
876    /// - `block_tables`: `[num_seqs, max_num_blocks_per_seq]` i32 grid.
877    ///
878    /// Each query token attends causally to all KV positions
879    /// `[0, pos_offsets[s] + local_idx]`.
880    #[allow(clippy::too_many_arguments)]
881    fn paged_varlen_attention(
882        _ctx: &mut Self::Context,
883        _q: &Self::Buffer,
884        _k_pool: &Self::Buffer,
885        _v_pool: &Self::Buffer,
886        _out: &mut Self::Buffer,
887        _cu_seqlens_q: &Self::Buffer,
888        _pos_offsets: &Self::Buffer,
889        _block_tables: &Self::Buffer,
890        _num_seqs: usize,
891        _total_q_tokens: usize,
892        _max_kv_len: usize,
893        _num_heads: usize,
894        _num_kv_heads: usize,
895        _head_dim: usize,
896        _block_size: usize,
897        _max_num_blocks_per_seq: usize,
898    ) -> Result<()> {
899        Err(FerrumError::unsupported(
900            "paged_varlen_attention not implemented for this backend",
901        ))
902    }
903
904    /// Opt-in vLLM FlashAttention-2 FFI path for FA-layout paged KV.
905    ///
906    /// This is intentionally separate from [`Self::paged_varlen_attention`]:
907    /// it needs the final per-sequence KV lengths (`seq_lens`) and an explicit
908    /// LSE scratch buffer because the external FA2 runner writes softmax LSE.
909    /// Default returns Err(unsupported); CUDA overrides when a runtime shim is
910    /// provided via `FERRUM_FA2_DIRECT_FFI_SHIM`.
911    #[allow(clippy::too_many_arguments)]
912    fn paged_varlen_attention_fa2_ffi(
913        _ctx: &mut Self::Context,
914        _q: &Self::Buffer,
915        _k_pool: &Self::Buffer,
916        _v_pool: &Self::Buffer,
917        _out: &mut Self::Buffer,
918        _lse: &mut Self::Buffer,
919        _cu_seqlens_q: &Self::Buffer,
920        _seq_lens: &Self::Buffer,
921        _block_tables: &Self::Buffer,
922        _num_seqs: usize,
923        _total_q_tokens: usize,
924        _max_q_len: usize,
925        _max_kv_len: usize,
926        _num_heads: usize,
927        _num_kv_heads: usize,
928        _head_dim: usize,
929        _block_size: usize,
930        _max_num_blocks_per_seq: usize,
931    ) -> Result<()> {
932        Err(FerrumError::unsupported(
933            "paged_varlen_attention_fa2_ffi not implemented for this backend",
934        ))
935    }
936
937    /// Batched paged decode attention — multi-seq, single token per seq.
938    /// Faster path for the unified_forward layer when m_total == num_seqs
939    /// (every item is a single-token decode). Skips the cu_seqlens_q
940    /// linear scan that `paged_varlen_attention` does in the fully-mixed
941    /// case.
942    ///
943    /// Layouts:
944    ///   q              : [num_seqs, num_q_heads, head_dim]
945    ///   k_pool/v_pool  : paged pool (same as paged_varlen)
946    ///   block_tables   : [num_seqs, max_num_blocks_per_seq]
947    ///   valid_kv_lens  : [num_seqs] — current kv_len per seq
948    ///   out            : [num_seqs, num_q_heads, head_dim]
949    ///
950    /// Default returns Err(unsupported); CUDA backend overrides.
951    #[allow(clippy::too_many_arguments)]
952    fn paged_batched_decode_attention(
953        _ctx: &mut Self::Context,
954        _q: &Self::Buffer,
955        _k_pool: &Self::Buffer,
956        _v_pool: &Self::Buffer,
957        _out: &mut Self::Buffer,
958        _block_tables: &Self::Buffer,
959        _valid_kv_lens: &Self::Buffer,
960        _num_seqs: usize,
961        _max_kv_len: usize,
962        _num_heads: usize,
963        _num_kv_heads: usize,
964        _head_dim: usize,
965        _block_size: usize,
966        _max_num_blocks_per_seq: usize,
967    ) -> Result<()> {
968        Err(FerrumError::unsupported(
969            "paged_batched_decode_attention not implemented for this backend",
970        ))
971    }
972
973    /// Capability: backend has vLLM-layout paged KV write kernels and the
974    /// `paged_attention_v2` decode kernel. Models that opt into this layout
975    /// at construction time (via `FERRUM_USE_VLLM_PAGED_ATTN=1`) must
976    /// dispatch ALL paged writes and reads through the `_vllm` variants —
977    /// the layouts are not compatible. Default `false`.
978    fn supports_vllm_paged_attn() -> bool {
979        false
980    }
981
982    /// vLLM-layout variant of
983    /// [`Self::split_qkv_norm_rope_into_paged_cache`]. K/V are written in
984    /// vLLM's `paged_attention_v2` layout: K is
985    /// `[num_blocks, kv_heads, head_dim/x, block_size, x]` (x = 16/sizeof(elem)),
986    /// V is `[num_blocks, kv_heads, head_dim, block_size]`. Q output and
987    /// every other argument matches the non-vllm variant exactly so the
988    /// model layer can swap dispatchers based on a single flag.
989    #[allow(clippy::too_many_arguments)]
990    fn split_qkv_norm_rope_into_paged_cache_vllm(
991        _ctx: &mut Self::Context,
992        _qkv: &Self::Buffer,
993        _qkv_byte_offset: u64,
994        _q_norm_w: &Self::Buffer,
995        _k_norm_w: &Self::Buffer,
996        _cos: &Self::Buffer,
997        _sin: &Self::Buffer,
998        _q_out: &mut Self::Buffer,
999        _q_out_byte_offset: u64,
1000        _cache_k: &mut Self::Buffer,
1001        _cache_v: &mut Self::Buffer,
1002        _block_table: &Self::Buffer,
1003        _tokens: usize,
1004        _q_heads: usize,
1005        _kv_heads: usize,
1006        _head_dim: usize,
1007        _pos_offset: usize,
1008        _eps: f32,
1009        _qk_mode: i32,
1010        _cache_len: usize,
1011        _block_size: usize,
1012        _max_num_blocks_per_seq: usize,
1013    ) -> Result<()> {
1014        Err(FerrumError::unsupported(
1015            "split_qkv_norm_rope_into_paged_cache_vllm not implemented for this backend",
1016        ))
1017    }
1018
1019    /// vLLM-layout variant of
1020    /// [`Self::split_qkv_norm_rope_into_paged_cache_varlen`]. Same signature
1021    /// — only the K/V cache layout changes.
1022    #[allow(clippy::too_many_arguments)]
1023    fn split_qkv_norm_rope_into_paged_cache_varlen_vllm(
1024        _ctx: &mut Self::Context,
1025        _qkv: &Self::Buffer,
1026        _q_norm_w: &Self::Buffer,
1027        _k_norm_w: &Self::Buffer,
1028        _cos: &Self::Buffer,
1029        _sin: &Self::Buffer,
1030        _q_out: &mut Self::Buffer,
1031        _cache_k: &mut Self::Buffer,
1032        _cache_v: &mut Self::Buffer,
1033        _cu_seqlens_q: &Self::Buffer,
1034        _pos_offsets: &Self::Buffer,
1035        _block_tables: &Self::Buffer,
1036        _num_seqs: usize,
1037        _m_total: usize,
1038        _q_heads: usize,
1039        _kv_heads: usize,
1040        _head_dim: usize,
1041        _eps: f32,
1042        _qk_mode: i32,
1043        _block_size: usize,
1044        _max_blocks_per_seq: usize,
1045    ) -> Result<()> {
1046        Err(FerrumError::unsupported(
1047            "split_qkv_norm_rope_into_paged_cache_varlen_vllm not implemented for this backend",
1048        ))
1049    }
1050
1051    /// vLLM `paged_attention_v2` — multi-partition split-K decode attention
1052    /// reading the vLLM K/V layout. `q_len` is implicitly 1 (decode only;
1053    /// vLLM's v2 kernel does not support q_len > 1). `max_seq_len` is the
1054    /// max kv_len across the batch — used to size the partition reduction.
1055    #[allow(clippy::too_many_arguments)]
1056    fn paged_decode_attention_v2(
1057        _ctx: &mut Self::Context,
1058        _q: &Self::Buffer,
1059        _k_pool: &Self::Buffer,
1060        _v_pool: &Self::Buffer,
1061        _out: &mut Self::Buffer,
1062        _block_tables: &Self::Buffer,
1063        _context_lens: &Self::Buffer,
1064        _num_seqs: usize,
1065        _num_heads: usize,
1066        _num_kv_heads: usize,
1067        _head_dim: usize,
1068        _block_size: usize,
1069        _max_num_blocks_per_seq: usize,
1070        _max_seq_len: usize,
1071    ) -> Result<()> {
1072        Err(FerrumError::unsupported(
1073            "paged_decode_attention_v2 not implemented for this backend",
1074        ))
1075    }
1076
1077    /// q_len>1 prefill/chunk-prefill attention over vLLM-layout paged KV.
1078    /// This keeps cache layout consistent when `FERRUM_USE_VLLM_PAGED_ATTN=1`
1079    /// and the prompt path writes K/V in the layout consumed later by
1080    /// `paged_decode_attention_v2`.
1081    #[allow(clippy::too_many_arguments)]
1082    fn paged_varlen_attention_vllm_layout(
1083        _ctx: &mut Self::Context,
1084        _q: &Self::Buffer,
1085        _k_pool: &Self::Buffer,
1086        _v_pool: &Self::Buffer,
1087        _out: &mut Self::Buffer,
1088        _block_tables: &Self::Buffer,
1089        _context_lens: &Self::Buffer,
1090        _num_seqs: usize,
1091        _num_heads: usize,
1092        _num_kv_heads: usize,
1093        _head_dim: usize,
1094        _block_size: usize,
1095        _max_num_blocks_per_seq: usize,
1096        _q_len: usize,
1097    ) -> Result<()> {
1098        Err(FerrumError::unsupported(
1099            "paged_varlen_attention_vllm_layout not implemented for this backend",
1100        ))
1101    }
1102
1103    /// Variable-length paged attention over vLLM-layout paged KV.
1104    ///
1105    /// Unlike [`Self::paged_varlen_attention_vllm_layout`], this accepts the
1106    /// same varlen index tensors as [`Self::paged_varlen_attention`] and writes
1107    /// token-major output directly. It is the unified mixed-batch companion for
1108    /// `split_qkv_norm_rope_into_paged_cache_varlen_vllm`.
1109    #[allow(clippy::too_many_arguments)]
1110    fn paged_varlen_attention_vllm(
1111        _ctx: &mut Self::Context,
1112        _q: &Self::Buffer,
1113        _k_pool: &Self::Buffer,
1114        _v_pool: &Self::Buffer,
1115        _out: &mut Self::Buffer,
1116        _cu_seqlens_q: &Self::Buffer,
1117        _pos_offsets: &Self::Buffer,
1118        _block_tables: &Self::Buffer,
1119        _num_seqs: usize,
1120        _total_q_tokens: usize,
1121        _max_kv_len: usize,
1122        _num_heads: usize,
1123        _num_kv_heads: usize,
1124        _head_dim: usize,
1125        _block_size: usize,
1126        _max_num_blocks_per_seq: usize,
1127    ) -> Result<()> {
1128        Err(FerrumError::unsupported(
1129            "paged_varlen_attention_vllm not implemented for this backend",
1130        ))
1131    }
1132
1133    /// Q-tiled vLLM-layout varlen attention. `tile_seqs` and `tile_starts`
1134    /// describe a compact list of q-token tiles, avoiding empty grid blocks
1135    /// for mixed batches that contain both long prefill items and q_len=1
1136    /// decode items. Semantics match [`Self::paged_varlen_attention_vllm`].
1137    #[allow(clippy::too_many_arguments)]
1138    fn paged_varlen_attention_vllm_tiled_q4(
1139        _ctx: &mut Self::Context,
1140        _q: &Self::Buffer,
1141        _k_pool: &Self::Buffer,
1142        _v_pool: &Self::Buffer,
1143        _out: &mut Self::Buffer,
1144        _cu_seqlens_q: &Self::Buffer,
1145        _pos_offsets: &Self::Buffer,
1146        _block_tables: &Self::Buffer,
1147        _tile_seqs: &Self::Buffer,
1148        _tile_starts: &Self::Buffer,
1149        _num_tiles: usize,
1150        _max_kv_len: usize,
1151        _num_heads: usize,
1152        _num_kv_heads: usize,
1153        _head_dim: usize,
1154        _block_size: usize,
1155        _max_num_blocks_per_seq: usize,
1156    ) -> Result<()> {
1157        Err(FerrumError::unsupported(
1158            "paged_varlen_attention_vllm_tiled_q4 not implemented for this backend",
1159        ))
1160    }
1161}
1162
1163// ════════════════════════════════════════════════════════════════════════
// Capability bundles — readable alias-like traits over the supertrait set
1165// ════════════════════════════════════════════════════════════════════════
1166//
1167// Models declare what they need via these bundles instead of spelling out
1168// every supertrait. Rust auto-derives the impl via blanket impls below,
1169// so any backend that satisfies the underlying supertraits automatically
1170// becomes a `LlmBackend` / `QuantLlmBackend` / `MoeLlmBackend`.
1171
1172/// Minimum capability set for a decoder-only LLM: the core compute trait
1173/// plus paged-KV cache + graph-capture support. Every concrete backend
1174/// (CUDA / Metal / CPU) satisfies this.
1175pub trait LlmBackend: Backend + BackendGraph + BackendPagedKv {}
1176impl<T> LlmBackend for T where T: Backend + BackendGraph + BackendPagedKv {}
1177
1178/// LLM backend that also supports quantized weight loading (GPTQ Marlin
1179/// for CUDA; GGUF k-quant for Metal). Required by models that hold
1180/// `Box<dyn Linear<B>>` where the Linear impl might be a quant variant.
1181pub trait QuantLlmBackend: LlmBackend + BackendQuantMarlin + BackendQuantGguf {}
1182impl<T> QuantLlmBackend for T where T: LlmBackend + BackendQuantMarlin + BackendQuantGguf {}
1183
1184/// MoE-capable LLM backend: adds the fused MoE routing + post-op kernels
1185/// to the quant LLM bundle. Required by Qwen3-MoE / future MoE models.
1186pub trait MoeLlmBackend: QuantLlmBackend + BackendMoeFused {}
1187impl<T> MoeLlmBackend for T where T: QuantLlmBackend + BackendMoeFused {}
1188
1189// ════════════════════════════════════════════════════════════════════════
1190// KV cache dtype axis (dim 5 of the 5-dimension architecture)
1191// ════════════════════════════════════════════════════════════════════════
1192//
1193// Each model's KV cache has its own precision independent of the model's
1194// compute precision. vLLM 0.6+ ships INT8 / FP8 KV caches that halve KV
1195// memory at small (<1%) accuracy hit. Today ferrum's KV is hardcoded
1196// FP16 on CUDA / Metal — to support INT8/FP8 KV in a future PR, the
1197// type system needs an explicit axis.
1198//
1199// Phase 4 scope: scaffolding only. All concrete backends impl
1200// `BackendKvDtype<KvFp16>` so existing models keep working unchanged.
1201// Future PR: implement BackendKvDtype<KvInt8> on CUDA + a new model
1202// type-parameter `K: KvDtypeKind` to wire it through.
1203
1204// `KvDtypeKind` + `KvFp16` / `KvBf16` / `KvInt8` / `KvFp8` markers moved
1205// to `ferrum_interfaces::kv_dtype` (no GPU deps, so the right place is
1206// the contract crate). Re-exported here so existing callers keep
1207// compiling against `crate::backend::KvFp16` etc.
1208pub use ferrum_interfaces::kv_dtype::{KvBf16, KvDtypeKind, KvFp16, KvFp8, KvInt8};
1209
1210/// Capability-trait for backends that can store + read a KV cache of
1211/// type `K`.
1212///
1213/// The two associated types carry the K-specific storage shape:
1214///   - `KvBuffer`: per-layer K/V element storage. For `K = KvFp16` it
1215///     is the backend's normal `Self::Buffer` (FP16). For `K = KvInt8`
1216///     it is the backend's INT8 buffer (e.g. `CudaSlice<i8>` on CUDA).
1217///   - `KvScales`: per-token-per-kv-head scales. For `K = KvFp16` this
1218///     is the unit type `()` (no scales). For `K = KvInt8` / `KvFp8`
1219///     it is a backend-specific FP16 buffer.
1220///
1221/// Models that want INT8 KV use:
1222///   `where B: BackendKvDtype<KvInt8>`
1223/// — the buffers in `KvCache<B, KvInt8>` are then `CudaSlice<i8>` and
1224/// `CudaSlice<f16>`, distinct from the FP16 path's `Self::Buffer`.
1225pub trait BackendKvDtype<K: KvDtypeKind>: BackendPagedKv {
1226    /// Per-layer K/V element storage.
1227    type KvBuffer: Send + Sync;
1228    /// Per-token per-kv-head scale storage. `()` for FP16 (no scales).
1229    type KvScales: Send + Sync + Default;
1230}
1231
1232/// INT8 KV cache operations (Dim 5).
1233///
1234/// `BackendKvDtype<KvInt8>` only declares the storage types; it does not
1235/// know how to write INT8 K/V into a paged pool or run paged decode
1236/// attention against an INT8 cache. Those launchers live here so the
1237/// model layer can call them through a single `B: BackendInt8KvOps` bound
1238/// without dropping into backend-specific code.
1239///
1240/// Today only `CudaBackend` provides a real implementation (delegating to
1241/// [`crate::int8_kv::launch_int8_kv_cache_append`] and
1242/// [`crate::int8_kv::launch_int8_paged_decode_attention`]). Other backends
1243/// inherit the default `unimplemented!()` body — the registry factory
1244/// rejects `(Device::CPU/Metal, KvCacheDtype::Int8)` before the model
1245/// gets a chance to call into these.
1246#[allow(clippy::too_many_arguments)]
1247pub trait BackendInt8KvOps: Backend + BackendKvDtype<KvInt8> {
1248    /// Allocate the per-layer INT8 paged cache for one sequence.
1249    /// Default panics — backends without INT8 support never reach this
1250    /// path (factory rejects (Cpu/Metal, Int8) before ensure_kv runs).
1251    fn alloc_paged_int8_layer(
1252        _max_blocks_per_seq: usize,
1253        _block_size: usize,
1254        _num_kv_heads: usize,
1255        _head_dim: usize,
1256    ) -> KvCacheQuant<Self, KvInt8> {
1257        unimplemented!("alloc_paged_int8_layer not supported on this backend")
1258    }
1259
1260    /// Append `tokens` FP16 K/V values into the paged INT8 pool.
1261    /// `paged_block_indices` is the host-side mirror of the per-seq
1262    /// logical→physical block table (already populated at `ensure_kv` time
1263    /// — see `KvCacheQuant::paged_block_indices`). Passing the host slice
1264    /// avoids a per-token D2H + sync barrier; backend computes the slot
1265    /// mapping host-side, async-H2D's it, and chains the append kernel
1266    /// on the same stream — fully overlapping with prior work.
1267    /// `cache_len_before` is the current number of valid tokens; the
1268    /// backend quantizes FP16 → INT8 with per-(token, kv-head) FP16 scale
1269    /// and writes both into the layer's INT8 / scale buffers.
1270    fn int8_kv_append_paged(
1271        _ctx: &mut Self::Context,
1272        _k_in: &Self::Buffer,
1273        _v_in: &Self::Buffer,
1274        _layer_k: &mut <Self as BackendKvDtype<KvInt8>>::KvBuffer,
1275        _layer_v: &mut <Self as BackendKvDtype<KvInt8>>::KvBuffer,
1276        _layer_k_scales: &mut <Self as BackendKvDtype<KvInt8>>::KvScales,
1277        _layer_v_scales: &mut <Self as BackendKvDtype<KvInt8>>::KvScales,
1278        _paged_block_indices: &[u32],
1279        _cache_len_before: usize,
1280        _tokens: usize,
1281        _block_size: usize,
1282        _num_kv_heads: usize,
1283        _head_dim: usize,
1284    ) -> Result<()> {
1285        Err(FerrumError::unsupported(
1286            "int8_kv_append_paged not implemented for this backend",
1287        ))
1288    }
1289
1290    /// Run paged decode attention reading from an INT8 cache. Q is FP16,
1291    /// output is FP16; the kernel dequantizes K/V on the fly using the
1292    /// per-token scales. `valid_kv_len` is the post-append cache length
1293    /// (i.e. the kernel attends over `[0, valid_kv_len)` tokens).
1294    fn int8_paged_decode_attention(
1295        _ctx: &mut Self::Context,
1296        _q: &Self::Buffer,
1297        _layer_k: &<Self as BackendKvDtype<KvInt8>>::KvBuffer,
1298        _layer_v: &<Self as BackendKvDtype<KvInt8>>::KvBuffer,
1299        _layer_k_scales: &<Self as BackendKvDtype<KvInt8>>::KvScales,
1300        _layer_v_scales: &<Self as BackendKvDtype<KvInt8>>::KvScales,
1301        _block_table: &Self::Buffer,
1302        _output: &mut Self::Buffer,
1303        _num_q_heads: usize,
1304        _num_kv_heads: usize,
1305        _head_dim: usize,
1306        _valid_kv_len: usize,
1307        _block_size: usize,
1308        _scale: f32,
1309    ) -> Result<()> {
1310        Err(FerrumError::unsupported(
1311            "int8_paged_decode_attention not implemented for this backend",
1312        ))
1313    }
1314}
1315
// Cpu/Metal do NOT impl `BackendInt8KvOps` — the trait pivot to
// `KvLayer<B>` means `KvInt8: KvLayer<B>` only holds where
// `B: BackendInt8KvOps`, so `LlamaFamilyModel<CpuBackend, KvInt8>` is a
// compile error (no INT8 KvLayer impl satisfies it). The type system
// enforces the constraint without runtime stubs.