Skip to main content

ferrum_kernels/backend/
traits.rs

1//! Core Backend trait — the single abstraction over CUDA / Metal / CPU.
2
3use ferrum_types::{FerrumError, Result};
4
5pub use super::capabilities::{
6    BackendCollective, BackendGraph, BackendMoeFused, BackendQuantGguf, BackendQuantMarlin,
7};
8pub use super::types::MoeRouting;
9use super::types::{AttnConfig, KvCacheQuant, SrcDtype};
10
/// Maximum decode-graph layer count. Per-layer call sites that share
/// graph-captured host staging arrays use this as the stride between
/// distinct slots. CUDA-only invariant (other backends ignore the
/// `slot` argument).
///
/// 64 covers models with at most 64 layers. It does NOT cover
/// Llama-3-70B (80 layers), but 70B doesn't run on a single 4090 anyway,
/// so 64 is safe in practice for v0.2. Raise this before graph-capturing
/// any model deeper than 64 layers, otherwise per-layer slots would
/// overlap.
pub const MAX_LAYERS_FOR_GRAPH: usize = 64;
18
19// Note: `TransformerConfig` / `AttnType` / `MlpType` / `RopeConfig` used to
20// live here when `ModelRunner` needed a generic model config. They're now
21// per-model (e.g. `Qwen3Config` in `ferrum-models::models::qwen3`) so each
22// model can carry exactly the architecture parameters it cares about.
23// Backend trait stays model-agnostic.
24
25/// The core abstraction over CUDA / Metal / CPU.
26///
27/// Key design: operations take a `&mut Self::Context` which accumulates work.
28///   - **CPU**: Context is `()` — ops execute immediately.
29///   - **Metal**: Context is a `CommandBuffer` — ops encode into it, flushed on `sync()`.
30///   - **CUDA**: Context is a `CudaStream` — ops launch on the stream, synced on `sync()`.
31///
32/// `layer_forward` passes the context through all ops in a layer.
33/// `ModelRunner` calls `sync()` only when it needs results (e.g., reading logits).
34pub trait Backend: Send + Sync + Sized + 'static {
35    type Buffer: Send + Sync;
36
37    /// Execution context that accumulates GPU work.
38    ///   - CPU: `()` (no-op, ops execute inline)
39    ///   - Metal: wraps a CommandBuffer
40    ///   - CUDA: wraps a CudaStream
41    type Context;
42
43    /// GPU-side timer scoped to this backend. See `super::timer` —
44    /// CPU: `Instant`; Metal: sync-wrap; CUDA: `cuEvent`.
45    /// PLAYBOOK § 1.1.
46    type Timer: super::timer::BackendTimer<Self>;
47
48    /// Factory for `Self::Timer` — exists so call sites that have a
49    /// `<B: Backend>` parameter can spawn a timer without importing the
50    /// concrete impl. PLAYBOOK § 1.2.
51    fn make_timer() -> Self::Timer;
52
53    /// True for the Apple Metal backend.
54    ///
55    /// Keep this as a backend capability instead of matching on type names in
56    /// model code. It is used only for backend-specific safety fallbacks where
57    /// a generic optimized path is known to be incorrect on one backend.
58    fn is_metal_backend() -> bool {
59        false
60    }
61
62    /// Opaque per-backend GPTQ weight representation.
63    ///   - CPU: dequantized f32 weights (run as regular GEMM)
64    ///   - Metal: `()` — unsupported; `gemm_gptq` errors
65    // Note (Phase 3e/4 + Phase C):
66    // - `type QuantStore` (GGUF k-quant storage) was removed in Phase 3e/4
67    //   — stacked-expert MoE GGUF goes through Box<dyn StackedExpertGgufLinear<Self>>
68    //   returned by `load_quant_experts`.
69    // - `type GptqStore` (Marlin/dequant GPTQ storage) was removed in Phase C
70    //   step 4e — stacked-expert Marlin MoE goes through
71    //   Arc<dyn MarlinExpertStack<Self>> returned by `load_gptq_stacked`,
72    //   and single-tensor GPTQ goes through Box<dyn Linear<Self>> returned
73    //   by `load_gptq`. Adding a new Marlin-capable backend is purely a
74    //   new MarlinExpertStack<NewBackend> impl — no Backend trait edits.
75
76    /// Create a new execution context (begin accumulating work).
77    fn new_context() -> Self::Context;
78
79    /// Flush accumulated work and wait for completion.
80    /// CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.
81    fn sync(ctx: &mut Self::Context);
82
83    // Graph capability moved to the `BackendGraph` supertrait at the end
84    // of this file. CUDA implements its overrides; Metal/CPU inherit
85    // unsupported defaults via empty `impl BackendGraph for X {}` blocks.
86
87    // ── GPTQ (INT4 quantization) ────────────────────────────────────────
88    //
89    // Two-step: load (once per weight) → gemm (per forward). The store
90    // holds whatever backend-specific format is fastest; caller code
91    // (GptqLinear) is dtype-agnostic.
92
93    /// Zero the first `len` elements of a Self::Buffer. CUDA path uses
94    /// cuMemsetD16Async; default returns unsupported.
95    fn zero_buffer(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize) -> Result<()> {
96        Err(FerrumError::unsupported(
97            "zero_buffer not implemented for this backend",
98        ))
99    }
100
101    /// Phase D step 2+3: unified typed allocator. Replaces per-dtype
102    /// `alloc_u32` / `alloc_typed_i32` / etc. The buffer is dtype-
103    /// tagged at the wrapper level (`CudaBuf::U32`, `MetalBuf` with
104    /// `Dtype::U32`, `CpuBuf::U32`), so reads/writes through `.as_<T>()`
105    /// accessors get the correct byte count automatically.
106    fn alloc_typed(dtype: super::Dtype, n: usize) -> Self::Buffer;
107
108    /// Upload typed host data — replaces `from_slice_i32` /
109    /// `from_slice_u32` etc. The host element type `T` carries its
110    /// `Dtype` via the `HostDtype` marker so dispatch in the impl
111    /// is a one-line `match T::DTYPE`.
112    fn from_slice_typed<T: super::HostDtype>(data: &[T]) -> Self::Buffer;
113
114    /// In-place typed write — replaces `write_u32` / `write_i32_into`
115    /// / `write_f32_into`. The buffer must already be dtype-tagged
116    /// matching `T::DTYPE` (typically alloc'd via `alloc_typed` or
117    /// `from_slice_typed`).
118    fn write_typed<T: super::HostDtype>(
119        ctx: &mut Self::Context,
120        dst: &mut Self::Buffer,
121        data: &[T],
122    );
123
124    // ── GEMM ────────────────────────────────────────────────────────────
125
126    fn gemm(
127        ctx: &mut Self::Context,
128        a: &Self::Buffer,
129        b: &Self::Buffer,
130        out: &mut Self::Buffer,
131        m: usize,
132        n: usize,
133        k: usize,
134    );
135
136    // ── Norms ───────────────────────────────────────────────────────────
137
138    fn rms_norm(
139        ctx: &mut Self::Context,
140        x: &Self::Buffer,
141        w: &Self::Buffer,
142        eps: f32,
143        out: &mut Self::Buffer,
144        tokens: usize,
145        dim: usize,
146    );
147
148    fn fused_add_rms_norm(
149        ctx: &mut Self::Context,
150        residual: &mut Self::Buffer,
151        x: &Self::Buffer,
152        w: &Self::Buffer,
153        eps: f32,
154        out: &mut Self::Buffer,
155        tokens: usize,
156        dim: usize,
157    );
158
159    // ── Attention ───────────────────────────────────────────────────────
160
161    fn flash_attention(
162        ctx: &mut Self::Context,
163        q: &Self::Buffer,
164        k: &Self::Buffer,
165        v: &Self::Buffer,
166        out: &mut Self::Buffer,
167        batch: usize,
168        q_len: usize,
169        kv_len: usize,
170        pos_offset: usize,
171        cfg: &AttnConfig,
172    );
173
174    /// Multi-Head Latent Attention — DeepSeek V2 / V3's compressed-KV
175    /// attention variant. Extension point only; no backend implements it
176    /// yet. DeepSeek V3 landing in Phase D/E will fill this in.
177    ///
178    /// `q`: full Q `[batch, num_heads, q_len, head_dim]`
179    /// `kv_compressed`: latent KV `[batch, kv_len, kv_lora_rank]`
180    /// `kv_rope`: per-position rope-applied key heads `[batch, kv_len, qk_rope_head_dim]`
181    /// `out`: `[batch, num_heads, q_len, head_dim]`
182    #[allow(clippy::too_many_arguments)]
183    fn mla_attention(
184        _ctx: &mut Self::Context,
185        _q: &Self::Buffer,
186        _kv_compressed: &Self::Buffer,
187        _kv_rope: &Self::Buffer,
188        _out: &mut Self::Buffer,
189        _batch: usize,
190        _q_len: usize,
191        _kv_len: usize,
192        _pos_offset: usize,
193        _cfg: &AttnConfig,
194        _kv_lora_rank: usize,
195        _qk_rope_head_dim: usize,
196    ) -> Result<()> {
197        Err(FerrumError::unsupported(
198            "mla_attention not implemented for this backend; required by \
199             DeepSeek V2/V3 (Phase D/E)",
200        ))
201    }
202
203    // ── Element-wise ────────────────────────────────────────────────────
204    //
205    // Models use `add_inplace` for residual updates and `copy_slice` for the
206    // row-extraction step in prefill. Offset-free copy / non-inplace add are
207    // not needed by the current Model-as-Code path; they can return later if
208    // a model actually requires them.
209
210    /// Copy `len` floats from `src[src_offset..]` to `dst[dst_offset..]`.
211    ///
212    /// Needed for Qwen3Model::prefill to pluck the last token's hidden state
213    /// out of `residual[seq_len, h]` without round-tripping through host RAM.
214    /// `Backend::copy` is the offset-free variant; `copy_slice` additionally
215    /// supports non-zero source and destination offsets.
216    fn copy_slice(
217        ctx: &mut Self::Context,
218        src: &Self::Buffer,
219        src_offset: usize,
220        dst: &mut Self::Buffer,
221        dst_offset: usize,
222        len: usize,
223    );
224
225    // ── Embedding ───────────────────────────────────────────────────────
226
227    fn embedding_lookup(
228        ctx: &mut Self::Context,
229        table: &Self::Buffer,
230        ids: &[u32],
231        out: &mut Self::Buffer,
232        dim: usize,
233    );
234
235    /// Device-buffer variant of `embedding_lookup` for graph-capturable
236    /// MoE routing — the gather step before phase-1 GEMM in
237    /// `moe_forward_bucketed`. The host-slice `embedding_lookup` does
238    /// `clone_htod(ids)` internally, which records stale host pointers
239    /// under CUDA Graph capture replay.
240    ///
241    /// `ids: &Self::Buffer` must be a device I32 buffer of `batch`
242    /// elements (e.g. `Qwen3MoeScratch::route_packed_idx_dev`).
243    /// `batch` is passed explicitly since a typed CudaBuf carries
244    /// its element count but the caller often wants a partial gather.
245    ///
246    /// Default impl: round-trip via `to_vec` + dispatch the host-slice
247    /// variant. CUDA overrides.
248    fn embedding_lookup_dev(
249        ctx: &mut Self::Context,
250        table: &Self::Buffer,
251        ids: &Self::Buffer,
252        out: &mut Self::Buffer,
253        batch: usize,
254        dim: usize,
255    ) {
256        // Default: round-trip. CUDA overrides with a direct device-arg
257        // kernel launch (no clone_htod).
258        let ids_host_f32 = Self::to_vec(ids, batch);
259        let ids_host_u32: Vec<u32> = ids_host_f32.iter().map(|x| x.to_bits()).collect();
260        Self::embedding_lookup(ctx, table, &ids_host_u32, out, dim);
261    }
262
263    // ── Transformer-specific fused ops ─────────────────────────────────
264    // These avoid CPU round-trips for data layout transformations.
265
266    /// Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers.
267    /// Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]
268    fn split_qkv(
269        ctx: &mut Self::Context,
270        qkv: &Self::Buffer,
271        q: &mut Self::Buffer,
272        k: &mut Self::Buffer,
273        v: &mut Self::Buffer,
274        tokens: usize,
275        q_dim: usize,
276        kv_dim: usize,
277    );
278
279    /// Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im],
280    /// then compute SiLU(gate) * up → out [tokens, im].
281    fn fused_silu_mul_split(
282        ctx: &mut Self::Context,
283        gate_up: &Self::Buffer,
284        out: &mut Self::Buffer,
285        tokens: usize,
286        im: usize,
287    );
288
289    /// Fused QK-norm + RoPE + transpose-to-head-major.
290    ///
291    /// `mode` selects the operation:
292    ///   0 = transpose only (typical for V, which needs no norm and no RoPE)
293    ///   1 = per-head RMS norm + RoPE + transpose  (Q/K with QK-norm, Qwen3)
294    ///   2 = RoPE + transpose                       (Q/K without QK-norm, Llama/Mistral)
295    ///
296    /// input:   `[tokens, heads, head_dim]`  (token-major, output of split_qkv)
297    /// output:  `[heads, tokens, head_dim]`  (head-major, ready for flash_attn / kv_cache_append)
298    ///
299    /// `pos_offset` is the position of token 0 (decode uses current seq len;
300    /// prefill uses 0). Within the batch, positions are taken as `pos_offset + i`.
301    ///
302    /// This is the primary attention-input preparation op. Backends that have a
303    /// fused kernel (Metal's `qk_norm_rope_transpose_f32`) will be dramatically
304    /// faster than composing norm + rope + transpose separately; the CPU
305    /// fallback lowers to the individual ops.
306    #[allow(clippy::too_many_arguments)]
307    fn qk_norm_rope(
308        ctx: &mut Self::Context,
309        input: &Self::Buffer,
310        norm_w: &Self::Buffer,
311        cos: &Self::Buffer,
312        sin: &Self::Buffer,
313        output: &mut Self::Buffer,
314        tokens: usize,
315        heads: usize,
316        head_dim: usize,
317        pos_offset: usize,
318        eps: f32,
319        mode: i32,
320    );
321
322    /// Batched kv_cache_append across M caches in one launch. Each item
323    /// writes its (head-major) K-or-V row into its own cache at offset
324    /// read from `cache_lens[i]`. Replaces M sequential
325    /// `kv_cache_append_head_major` calls with a single dispatch.
326    ///
327    /// `new_data` layout: `[m, nkv, hd]` item-major (each item's slice
328    /// is contiguous, identical to the `k/v_normed_batched` produced by
329    /// `qk_norm_rope_batched_per_item`).
330    /// `caches`: per-cache `[nkv, capacity, hd]` head-major.
331    /// `cache_lens`: device buffer (u32 storage, length ≥ m). Caller
332    /// fills via `B::write_u32_into` BEFORE the call. Required for
333    /// CUDA-graph capture: the kernel reads from this stable device
334    /// buffer, so a captured graph can be replayed with new lens by
335    /// just rewriting the buffer between launches.
336    fn kv_cache_append_batched_per_cache(
337        _ctx: &mut Self::Context,
338        _caches: &[&Self::Buffer],
339        _new_data: &Self::Buffer,
340        _cache_lens: &Self::Buffer,
341        _capacity: usize,
342        _m: usize,
343        _nkv: usize,
344        _hd: usize,
345        _slot: usize,
346    ) -> Result<()> {
347        Err(FerrumError::unsupported(
348            "kv_cache_append_batched_per_cache not implemented for this backend",
349        ))
350    }
351
352    /// Batched flash_attention across M decode caches in one launch.
353    /// Replaces the per-item `flash_attention(q_len=1, ...)` × M
354    /// loop in the non-paged batched-decode path.
355    ///
356    /// API takes Vec<&Buffer> for the per-cache K/V buffers (each
357    /// `[nkv, capacity, hd]` head-major) plus host-side `kv_lens`.
358    /// Backends that implement it must extract per-cache device
359    /// pointers, build the device arrays the kernel needs, and launch
360    /// one kernel covering all M items.
361    ///
362    /// `q` layout: [m, nq, hd] item-major (matches the
363    /// `qk_norm_rope_batched_per_item` output for q_len=1).
364    /// `out` layout: [m, nq, hd] item-major — written directly into
365    /// the caller's batched attn_out buffer, no per-item copy needed.
366    ///
367    /// CUDA-only for now (kernel `batched_decode_attention` exists in
368    /// `kernels/batched_decode_attention.cu`).
369    /// `kv_lens`: device buffer (u32 storage, length ≥ m) — same
370    /// design as `kv_cache_append_batched_per_cache::cache_lens`.
371    fn flash_attention_batched_per_cache(
372        _ctx: &mut Self::Context,
373        _q: &Self::Buffer,
374        _k_caches: &[&Self::Buffer],
375        _v_caches: &[&Self::Buffer],
376        _kv_lens: &Self::Buffer,
377        _out: &mut Self::Buffer,
378        _nq: usize,
379        _nkv: usize,
380        _hd: usize,
381        _scale: f32,
382        _max_valid_kv: usize,
383        _capacity: usize,
384        _slot: usize,
385    ) -> Result<()> {
386        Err(FerrumError::unsupported(
387            "flash_attention_batched_per_cache not implemented for this backend",
388        ))
389    }
390
391    /// Batched per-item-position variant of `qk_norm_rope` for the
392    /// non-paged batched-decode path. Each of the `m` items has its own
393    /// absolute RoPE position (read from a device i32 buffer of length
394    /// `m`). Layout is item-major in *both* input and output:
395    ///
396    ///   input  [m, heads, head_dim]
397    ///   output [m, heads, head_dim]   (no head-major transpose)
398    ///
399    /// Item-major output keeps the per-item flash_attention slice
400    /// contiguous (`output[i * heads * head_dim ..]` is item i's whole
401    /// Q tensor in head-major-equivalent layout for q_len=1).
402    ///
403    /// Replaces the M sequential single-item launches in the existing
404    /// `forward_layer_batched_decode` path with one batched dispatch.
405    /// CUDA-only for now; other backends fall through to the default
406    /// `unsupported` and the caller falls back to the per-item loop.
407    fn qk_norm_rope_batched_per_item(
408        _ctx: &mut Self::Context,
409        _input: &Self::Buffer,
410        _norm_w: &Self::Buffer,
411        _cos: &Self::Buffer,
412        _sin: &Self::Buffer,
413        _output: &mut Self::Buffer,
414        _positions: &Self::Buffer,
415        _m: usize,
416        _heads: usize,
417        _head_dim: usize,
418        _eps: f32,
419        _mode: i32,
420    ) -> Result<()> {
421        Err(FerrumError::unsupported(
422            "qk_norm_rope_batched_per_item not implemented for this backend",
423        ))
424    }
425
426    /// Fused split-QKV + QK-norm + RoPE + head-major transpose.
427    ///
428    /// Single-dispatch replacement for the (`split_qkv` → 3× `qk_norm_rope`)
429    /// chain on the decode-attention prelude. Reads the linear-layer
430    /// fused-QKV output once and writes head-major Q/K/V directly into
431    /// attention scratch.
432    ///
433    /// `qkv` layout: `[tokens, q_heads*hd + 2*kv_heads*hd]`.
434    /// `q_out`: `[q_heads, tokens, hd]`. `k_out`/`v_out`: `[kv_heads, tokens, hd]`.
435    /// `qk_mode`: 1 = norm + half-split RoPE for Q/K (Qwen3 with QK-norm),
436    ///            2 = half-split RoPE only for Q/K,
437    ///            3 = interleaved RoPE only for Q/K (GGUF LLaMA / llama.cpp layout).
438    /// V always falls through to transpose-only.
439    ///
440    /// Default returns Unsupported. Backends that implement it are
441    /// expected to be dramatically faster than the four-dispatch chain.
442    #[allow(clippy::too_many_arguments)]
443    fn split_qkv_norm_rope(
444        _ctx: &mut Self::Context,
445        _qkv: &Self::Buffer,
446        _q_norm_w: &Self::Buffer,
447        _k_norm_w: &Self::Buffer,
448        _cos: &Self::Buffer,
449        _sin: &Self::Buffer,
450        _q_out: &mut Self::Buffer,
451        _k_out: &mut Self::Buffer,
452        _v_out: &mut Self::Buffer,
453        _tokens: usize,
454        _q_heads: usize,
455        _kv_heads: usize,
456        _head_dim: usize,
457        _pos_offset: usize,
458        _eps: f32,
459        _qk_mode: i32,
460    ) -> Result<()> {
461        Err(FerrumError::unsupported(
462            "split_qkv_norm_rope not implemented for this backend",
463        ))
464    }
465
466    /// Variant of [`Backend::split_qkv_norm_rope`] that writes the new
467    /// K and V directly into pre-allocated head-major KV cache buffers
468    /// at slot `[kv_heads, cache_len .. cache_len + tokens, hd]`.
469    /// Eliminates the trailing `kv_cache_append_head_major` dispatch on
470    /// the decode hot path. Q still lands in per-token head-major
471    /// scratch (flash-attention reads it as the query).
472    ///
473    /// Default returns Unsupported. Backends without the fused kernel
474    /// can keep using `split_qkv_norm_rope` + `kv_cache_append_head_major`.
475    #[allow(clippy::too_many_arguments)]
476    fn split_qkv_norm_rope_into_cache(
477        _ctx: &mut Self::Context,
478        _qkv: &Self::Buffer,
479        _q_norm_w: &Self::Buffer,
480        _k_norm_w: &Self::Buffer,
481        _cos: &Self::Buffer,
482        _sin: &Self::Buffer,
483        _q_out: &mut Self::Buffer,
484        _cache_k: &mut Self::Buffer,
485        _cache_v: &mut Self::Buffer,
486        _tokens: usize,
487        _q_heads: usize,
488        _kv_heads: usize,
489        _head_dim: usize,
490        _pos_offset: usize,
491        _eps: f32,
492        _qk_mode: i32,
493        _cache_len: usize,
494        _cache_capacity: usize,
495    ) -> Result<()> {
496        Err(FerrumError::unsupported(
497            "split_qkv_norm_rope_into_cache not implemented for this backend",
498        ))
499    }
500
501    // Phase D step 2: alloc_u32 / write_u32 deleted. Callers use the
502    // unified `alloc_typed(Dtype::U32, n)` + `write_typed(&[u32])` API
503    // declared above.
504
505    /// Append new K/V into a pre-allocated head-major cache buffer.
506    ///
507    /// `cache_k` / `cache_v`: `[nkv, capacity, hd]` (head-major, pre-allocated)
508    /// `new_k_head_major` / `new_v_head_major`: `[nkv, new_tokens, hd]`
509    ///   — produced directly by `qk_norm_rope`, no extra transpose needed.
510    ///
511    /// In-place append at slot `[nkv, cache_len..cache_len+new_tokens, hd]`.
512    /// Caller owns `cache_len` bookkeeping.
513    #[allow(clippy::too_many_arguments)]
514    fn kv_cache_append_head_major(
515        ctx: &mut Self::Context,
516        cache_k: &mut Self::Buffer,
517        cache_v: &mut Self::Buffer,
518        cache_len: usize,
519        cache_capacity: usize,
520        new_k_head_major: &Self::Buffer,
521        new_v_head_major: &Self::Buffer,
522        new_tokens: usize,
523        nkv: usize,
524        hd: usize,
525    );
526
527    /// Transpose [heads, tokens, dim] → [tokens, heads, dim].
528    /// Called after `flash_attention` to restore token-major layout for O-proj.
529    fn transpose_head_to_token(
530        ctx: &mut Self::Context,
531        src: &Self::Buffer,
532        dst: &mut Self::Buffer,
533        tokens: usize,
534        heads: usize,
535        dim: usize,
536    );
537
538    /// Inverse of `transpose_head_to_token`: [tokens, heads, dim] →
539    /// [heads, tokens, dim]. Used by the CUDA `paged_decode_attention`
540    /// wrapper to convert `paged_varlen_attention`'s token-major output
541    /// back to the head-major layout that Qwen3MoeModel expects.
542    /// Default panics — backends without a paged-KV CUDA path don't
543    /// hit this code.
544    fn transpose_token_to_head(
545        _ctx: &mut Self::Context,
546        _src: &Self::Buffer,
547        _dst: &mut Self::Buffer,
548        _tokens: usize,
549        _heads: usize,
550        _dim: usize,
551    ) {
552        panic!("transpose_token_to_head not implemented for this backend");
553    }
554
555    /// residual[i] += x[i] (in-place)
556    fn add_inplace(
557        ctx: &mut Self::Context,
558        residual: &mut Self::Buffer,
559        x: &Self::Buffer,
560        len: usize,
561    );
562
563    /// `dst[i] += scale * src[i]` — scalar-broadcast scaled add, in place.
564    ///
565    /// MoE per-token combine writes `out[b] += weight_k * expert_k(x[b])`
566    /// for each top-K expert; this primitive is the per-call accumulate.
567    /// Backends without a dedicated kernel can fall back to the default
568    /// implementation, which round-trips through host memory — correct,
569    /// but slow on a hot path. Override on any backend you actually
570    /// dispatch MoE on.
571    fn scaled_add_inplace(
572        _ctx: &mut Self::Context,
573        dst: &mut Self::Buffer,
574        src: &Self::Buffer,
575        scale: f32,
576        len: usize,
577    ) {
578        let mut dst_v = Self::to_vec(dst, len);
579        let src_v = Self::to_vec(src, len);
580        for i in 0..len {
581            dst_v[i] += scale * src_v[i];
582        }
583        // Move the new buffer into the slot pointed to by `dst`. Safe
584        // because `Self::Buffer: Send + Sync` and the old buffer is
585        // dropped here when overwritten.
586        *dst = Self::from_slice(&dst_v);
587    }
588
589    /// Strided variant of [`Backend::fused_silu_mul_split`] for the
590    /// bucketed MoE path: reads `gate_up` rows starting at
591    /// `in_row_offset`, writes `out` rows starting at `out_row_offset`.
592    #[allow(clippy::too_many_arguments)]
593    fn fused_silu_mul_split_strided(
594        _ctx: &mut Self::Context,
595        _gate_up: &Self::Buffer,
596        _in_row_offset: usize,
597        _out: &mut Self::Buffer,
598        _out_row_offset: usize,
599        _tokens: usize,
600        _intermediate: usize,
601    ) {
602        unimplemented!("fused_silu_mul_split_strided default impl missing");
603    }
604
605    /// Broadcast bias add: `data[r, c] += bias[c]` for every row.
606    /// Required by Bert / Clip / Whisper whose linear projections carry a bias.
607    fn add_bias(
608        ctx: &mut Self::Context,
609        data: &mut Self::Buffer,
610        bias: &Self::Buffer,
611        rows: usize,
612        cols: usize,
613    );
614
615    /// Full LayerNorm (mean + variance normalisation + affine), distinct from
616    /// the `rms_norm` used by Llama-family decoders.
617    ///   `out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]`
618    /// Where `mean` and `var` are reduced over the last dim (cols).
619    #[allow(clippy::too_many_arguments)]
620    fn layer_norm(
621        ctx: &mut Self::Context,
622        x: &Self::Buffer,
623        gamma: &Self::Buffer,
624        beta: &Self::Buffer,
625        eps: f32,
626        out: &mut Self::Buffer,
627        tokens: usize,
628        dim: usize,
629    );
630
631    /// Element-wise GELU activation (erf-based, matches PyTorch default).
632    fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize);
633
634    // ── Buffer management (context-free) ────────────────────────────────
635
636    fn alloc(len: usize) -> Self::Buffer;
637    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;
638    fn from_slice(data: &[f32]) -> Self::Buffer;
639
640    /// Greedy-decode fast path: GPU argmax over each row of a
641    /// `[m, n]` FP16 logits buffer, returning the m token indices on the
642    /// host. Saves `m × n × 2` bytes of D2H per call (e.g. 19.5 MB at
643    /// c=32, vocab=152064) and the host-side argmax scan (~150 µs × m).
644    ///
645    /// Default impl falls back to the slow path: full `to_vec` + host
646    /// argmax. CUDA overrides with a native kernel + tiny D2H (m × 4 B).
647    /// Backends that don't override pay the same cost as
648    /// `to_vec` + host argmax, so callers can call this unconditionally.
649    fn argmax_rows_f16(
650        _ctx: &mut Self::Context,
651        logits: &Self::Buffer,
652        m: usize,
653        n: usize,
654    ) -> Result<Vec<u32>> {
655        let host = Self::to_vec(logits, m * n);
656        let mut out = Vec::with_capacity(m);
657        for row in 0..m {
658            let slice = &host[row * n..(row + 1) * n];
659            let mut max_idx = 0usize;
660            let mut max_val = f32::NEG_INFINITY;
661            for (i, &v) in slice.iter().enumerate() {
662                if v > max_val {
663                    max_val = v;
664                    max_idx = i;
665                }
666            }
667            out.push(max_idx as u32);
668        }
669        Ok(out)
670    }
671
672    /// Load a weight tensor straight from its on-disk byte representation,
673    /// letting the backend pick its preferred storage dtype.
674    ///
675    /// Default impl upcasts bf16/f16 to f32 via an intermediate Vec, matching
676    /// pre-existing loader behaviour. Backends override this to go straight
677    /// from raw bytes into a native half-precision buffer (e.g. Metal with
678    /// `FERRUM_METAL_DTYPE=f16`), avoiding the transient 2× RAM spike.
679    fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer {
680        let data = src_dtype.to_f32_vec(raw);
681        Self::from_slice(&data)
682    }
683
684    // (The Phase A3 unified `gemm_quant(QuantWeights, QuantKind)` stub
685    // that used to live here is superseded by the `load_quant` /
686    // `gemm_quant(QuantStore)` pair earlier in this trait — same idea,
687    // but the store hides the per-kind buffer layout so callers don't
688    // have to construct a per-kind `QuantWeights<'_, Self>` packet.)
689}
690
691// ════════════════════════════════════════════════════════════════════════
692// BackendPagedKv capability (vLLM-style paged KV cache + paged attention)
693// ════════════════════════════════════════════════════════════════════════
694//
695// Paged KV pool with block-table indirection, plus the paged attention
696// kernel variants that read through that indirection. CUDA + Metal both
697// implement the real kernels; CPU `impl BackendPagedKv for CpuBackend {}`
698// inherits unsupported defaults.
699
700/// Capability-trait for backends that support paged KV cache + paged attention.
701pub trait BackendPagedKv: Backend {
702    /// Whether this backend has a paged-KV decode path
703    /// (`paged_decode_attention` etc.). Currently true for Metal, false
704    /// for CPU. Used to decide the default of `FERRUM_METAL_PAGED_KV` —
705    /// the `serve` path should opt in automatically when supported so
706    /// users get the bench-quality concurrent-decode numbers without
707    /// having to learn the flag.
708    fn supports_paged_kv() -> bool {
709        false
710    }
711    /// Pre-populate the per-slot device-pointer scratch arrays used by
712    /// the batched kernels (`kv_cache_append_batched_per_cache` and
713    /// `flash_attention_batched_per_cache`). Required by the CUDA-graph
714    /// capture path: the captured graph contains only kernel launches
715    /// (no captured `memcpy_htod`), so the device scratch must be fresh
716    /// when the graph replays.
717    ///
718    /// Caller passes flat layer-major slices: `k_caches[li * m + i]` and
719    /// `v_caches[li * m + i]`. Backend extracts each cache's device
720    /// pointer and writes into its corresponding slot in the device
721    /// scratch via SYNCHRONOUS memcpy (not captured by stream capture).
722    ///
723    /// CUDA-only; other backends fall through to the default
724    /// `unsupported` and the caller skips the population call.
725    fn populate_batched_pointers(
726        _ctx: &mut Self::Context,
727        _k_caches: &[&Self::Buffer],
728        _v_caches: &[&Self::Buffer],
729        _num_layers: usize,
730        _m: usize,
731    ) -> Result<()> {
732        Err(FerrumError::unsupported(
733            "populate_batched_pointers not implemented for this backend",
734        ))
735    }
736    /// Paged-KV variant of [`Self::split_qkv_norm_rope_into_cache`].
737    ///
738    /// Same fused split + qk-norm + RoPE, but K/V are written into a
739    /// paged pool `[num_blocks, kv_heads, block_size, head_dim]`
740    /// indexed via `block_table[logical_block]` → physical_block.
741    /// Q still goes to head-major scratch.
742    ///
743    /// Default returns Unsupported. Backends that lack a paged kernel
744    /// keep using the contiguous variant.
745    /// `qkv_byte_offset` / `q_out_byte_offset` let the caller pass a
746    /// slice of a larger batched buffer (used by the multi-seq paged
747    /// path in `decode_batch_internal`). For single-seq dispatch they
748    /// should be 0.
749    #[allow(clippy::too_many_arguments)]
750    fn split_qkv_norm_rope_into_paged_cache(
751        _ctx: &mut Self::Context,
752        _qkv: &Self::Buffer,
753        _qkv_byte_offset: u64,
754        _q_norm_w: &Self::Buffer,
755        _k_norm_w: &Self::Buffer,
756        _cos: &Self::Buffer,
757        _sin: &Self::Buffer,
758        _q_out: &mut Self::Buffer,
759        _q_out_byte_offset: u64,
760        _cache_k: &mut Self::Buffer,
761        _cache_v: &mut Self::Buffer,
762        _block_table: &Self::Buffer,
763        _tokens: usize,
764        _q_heads: usize,
765        _kv_heads: usize,
766        _head_dim: usize,
767        _pos_offset: usize,
768        _eps: f32,
769        _qk_mode: i32,
770        _cache_len: usize,
771        _block_size: usize,
772        _max_num_blocks_per_seq: usize,
773    ) -> Result<()> {
774        Err(FerrumError::unsupported(
775            "split_qkv_norm_rope_into_paged_cache not implemented for this backend",
776        ))
777    }
778    /// Paged-KV variant of [`Self::flash_attention`].
779    ///
780    /// Decode (`q_len == 1`):
781    ///   `q`/`out`: `[num_seqs, num_heads, head_dim]` (token-major)
782    ///
783    /// Causal prefill (`q_len > 1`, single seq):
784    ///   `q`/`out`: `[num_heads, q_len, head_dim]` (head-major — the
785    ///              layout produced by `split_qkv_norm_rope_into_paged_cache`)
786    ///   The kernel applies a per-q-token causal mask using
787    ///   `context_lens[seq]` as the FINAL kv_len (= `pos_offset + q_len`):
788    ///   token i sees positions `[0, context_lens - q_len + 1 + i)`.
789    ///
790    /// Common to both:
791    ///   `k_pool`/`v_pool`: `[num_blocks, num_kv_heads, block_size, head_dim]`
792    ///   `block_tables`: `[num_seqs, max_num_blocks_per_seq]` u32
793    ///   `context_lens`: `[num_seqs]` u32
794    ///
795    /// Backends without a paged kernel return Unsupported; callers are
796    /// expected to fall back to contiguous KV.
797    #[allow(clippy::too_many_arguments)]
798    fn paged_decode_attention(
799        _ctx: &mut Self::Context,
800        _q: &Self::Buffer,
801        _k_pool: &Self::Buffer,
802        _v_pool: &Self::Buffer,
803        _out: &mut Self::Buffer,
804        _block_tables: &Self::Buffer,
805        _context_lens: &Self::Buffer,
806        _num_seqs: usize,
807        _num_heads: usize,
808        _num_kv_heads: usize,
809        _head_dim: usize,
810        _block_size: usize,
811        _max_num_blocks_per_seq: usize,
812        _q_len: usize,
813    ) -> Result<()> {
814        Err(FerrumError::unsupported(
815            "paged_decode_attention not implemented for this backend",
816        ))
817    }
818    /// Capability: does this backend implement
819    /// `split_qkv_norm_rope_into_paged_cache_varlen` and
820    /// `paged_varlen_attention`? Required by the unified mixed-batch
821    /// forward path used by `LlamaFamilyModel::unified_forward`. Default
822    /// false; backends that ship the varlen kernels override.
823    fn supports_varlen_qkv() -> bool {
824        false
825    }
826    /// Varlen variant of [`Self::split_qkv_norm_rope_into_paged_cache`].
827    ///
828    /// Single launch covering ALL sequences in the batch. Reads
829    /// `pos_offsets[seq]`, `cu_seqlens_q[seq]`, and the per-seq
830    /// block_table from device buffers — graph-capturable (the per-iter
831    /// state is in buffers, not kernel scalars). Replaces the per-item
832    /// dispatch loop in `unified_forward_layer` with one call.
833    ///
834    /// Layouts:
835    /// - `qkv`: `[m_total, q_dim + 2 * kv_dim]` token-major
836    /// - `q_out`: `[m_total, q_heads, head_dim]` token-major (matches
837    ///   what `paged_varlen_attention` reads)
838    /// - `cache_k` / `cache_v`: paged pool same as `paged_varlen_attention`
839    /// - `cu_seqlens_q`: `[num_seqs + 1]` u32 prefix sum
840    /// - `pos_offsets`: `[num_seqs]` u32, starting kv_pos per seq
841    /// - `block_tables`: `[num_seqs, max_blocks_per_seq]` i32 stacked
842    #[allow(clippy::too_many_arguments)]
843    fn split_qkv_norm_rope_into_paged_cache_varlen(
844        _ctx: &mut Self::Context,
845        _qkv: &Self::Buffer,
846        _q_norm_w: &Self::Buffer,
847        _k_norm_w: &Self::Buffer,
848        _cos: &Self::Buffer,
849        _sin: &Self::Buffer,
850        _q_out: &mut Self::Buffer,
851        _cache_k: &mut Self::Buffer,
852        _cache_v: &mut Self::Buffer,
853        _cu_seqlens_q: &Self::Buffer,
854        _pos_offsets: &Self::Buffer,
855        _block_tables: &Self::Buffer,
856        _num_seqs: usize,
857        _m_total: usize,
858        _q_heads: usize,
859        _kv_heads: usize,
860        _head_dim: usize,
861        _eps: f32,
862        _qk_mode: i32,
863        _block_size: usize,
864        _max_blocks_per_seq: usize,
865    ) -> Result<()> {
866        Err(FerrumError::unsupported(
867            "split_qkv_norm_rope_into_paged_cache_varlen not implemented for this backend",
868        ))
869    }
870    /// Variable-length paged attention with GQA + causal mask.
871    ///
872    /// Supports a unified mixed batch where each sequence contributes
873    /// 1 (decode) or N (prefill chunk) query tokens — the workhorse for
874    /// chunked-prefill. See `kernels/paged_varlen_attention.cu` for the
875    /// kernel itself.
876    ///
877    /// Layouts:
878    /// - `q` / `out`: `[total_q_tokens, num_heads, head_dim]` (token-
879    ///   major, FP16). `total_q_tokens` = `cu_seqlens_q[num_seqs]`.
880    /// - `k_pool` / `v_pool`: paged block pool, layout matches
881    ///   `paged_decode_attention`.
882    /// - `cu_seqlens_q`: `[num_seqs + 1]` u32 prefix sum, with
883    ///   `cu_seqlens_q[0] = 0` and `cu_seqlens_q[num_seqs] = total_q_tokens`.
884    /// - `pos_offsets`: `[num_seqs]` u32, the starting absolute KV
885    ///   position of each seq's first q token (= prior `kv_len`).
886    /// - `block_tables`: `[num_seqs, max_num_blocks_per_seq]` i32 grid.
887    ///
888    /// Each query token attends causally to all KV positions
889    /// `[0, pos_offsets[s] + local_idx]`.
890    #[allow(clippy::too_many_arguments)]
891    fn paged_varlen_attention(
892        _ctx: &mut Self::Context,
893        _q: &Self::Buffer,
894        _k_pool: &Self::Buffer,
895        _v_pool: &Self::Buffer,
896        _out: &mut Self::Buffer,
897        _cu_seqlens_q: &Self::Buffer,
898        _pos_offsets: &Self::Buffer,
899        _block_tables: &Self::Buffer,
900        _num_seqs: usize,
901        _total_q_tokens: usize,
902        _max_kv_len: usize,
903        _num_heads: usize,
904        _num_kv_heads: usize,
905        _head_dim: usize,
906        _block_size: usize,
907        _max_num_blocks_per_seq: usize,
908    ) -> Result<()> {
909        Err(FerrumError::unsupported(
910            "paged_varlen_attention not implemented for this backend",
911        ))
912    }
913
914    /// Opt-in vLLM FlashAttention-2 FFI path for FA-layout paged KV.
915    ///
916    /// This is intentionally separate from [`Self::paged_varlen_attention`]:
917    /// it needs the final per-sequence KV lengths (`seq_lens`) and an explicit
918    /// LSE scratch buffer because the external FA2 runner writes softmax LSE.
919    /// Default returns Err(unsupported); CUDA overrides when a runtime shim is
920    /// provided via `FERRUM_FA2_DIRECT_FFI_SHIM`.
921    #[allow(clippy::too_many_arguments)]
922    fn paged_varlen_attention_fa2_ffi(
923        _ctx: &mut Self::Context,
924        _q: &Self::Buffer,
925        _k_pool: &Self::Buffer,
926        _v_pool: &Self::Buffer,
927        _out: &mut Self::Buffer,
928        _lse: &mut Self::Buffer,
929        _cu_seqlens_q: &Self::Buffer,
930        _seq_lens: &Self::Buffer,
931        _block_tables: &Self::Buffer,
932        _num_seqs: usize,
933        _total_q_tokens: usize,
934        _max_q_len: usize,
935        _max_kv_len: usize,
936        _num_heads: usize,
937        _num_kv_heads: usize,
938        _head_dim: usize,
939        _block_size: usize,
940        _max_num_blocks_per_seq: usize,
941    ) -> Result<()> {
942        Err(FerrumError::unsupported(
943            "paged_varlen_attention_fa2_ffi not implemented for this backend",
944        ))
945    }
946
947    /// Batched paged decode attention — multi-seq, single token per seq.
948    /// Faster path for the unified_forward layer when m_total == num_seqs
949    /// (every item is a single-token decode). Skips the cu_seqlens_q
950    /// linear scan that `paged_varlen_attention` does in the fully-mixed
951    /// case.
952    ///
953    /// Layouts:
954    ///   q              : [num_seqs, num_q_heads, head_dim]
955    ///   k_pool/v_pool  : paged pool (same as paged_varlen)
956    ///   block_tables   : [num_seqs, max_num_blocks_per_seq]
957    ///   valid_kv_lens  : [num_seqs] — current kv_len per seq
958    ///   out            : [num_seqs, num_q_heads, head_dim]
959    ///
960    /// Default returns Err(unsupported); CUDA backend overrides.
961    #[allow(clippy::too_many_arguments)]
962    fn paged_batched_decode_attention(
963        _ctx: &mut Self::Context,
964        _q: &Self::Buffer,
965        _k_pool: &Self::Buffer,
966        _v_pool: &Self::Buffer,
967        _out: &mut Self::Buffer,
968        _block_tables: &Self::Buffer,
969        _valid_kv_lens: &Self::Buffer,
970        _num_seqs: usize,
971        _max_kv_len: usize,
972        _num_heads: usize,
973        _num_kv_heads: usize,
974        _head_dim: usize,
975        _block_size: usize,
976        _max_num_blocks_per_seq: usize,
977    ) -> Result<()> {
978        Err(FerrumError::unsupported(
979            "paged_batched_decode_attention not implemented for this backend",
980        ))
981    }
982
983    /// Capability: backend has vLLM-layout paged KV write kernels and the
984    /// `paged_attention_v2` decode kernel. Models that opt into this layout
985    /// at construction time (via `FERRUM_USE_VLLM_PAGED_ATTN=1`) must
986    /// dispatch ALL paged writes and reads through the `_vllm` variants —
987    /// the layouts are not compatible. Default `false`.
988    fn supports_vllm_paged_attn() -> bool {
989        false
990    }
991
992    /// vLLM-layout variant of
993    /// [`Self::split_qkv_norm_rope_into_paged_cache`]. K/V are written in
994    /// vLLM's `paged_attention_v2` layout: K is
995    /// `[num_blocks, kv_heads, head_dim/x, block_size, x]` (x = 16/sizeof(elem)),
996    /// V is `[num_blocks, kv_heads, head_dim, block_size]`. Q output and
997    /// every other argument matches the non-vllm variant exactly so the
998    /// model layer can swap dispatchers based on a single flag.
999    #[allow(clippy::too_many_arguments)]
1000    fn split_qkv_norm_rope_into_paged_cache_vllm(
1001        _ctx: &mut Self::Context,
1002        _qkv: &Self::Buffer,
1003        _qkv_byte_offset: u64,
1004        _q_norm_w: &Self::Buffer,
1005        _k_norm_w: &Self::Buffer,
1006        _cos: &Self::Buffer,
1007        _sin: &Self::Buffer,
1008        _q_out: &mut Self::Buffer,
1009        _q_out_byte_offset: u64,
1010        _cache_k: &mut Self::Buffer,
1011        _cache_v: &mut Self::Buffer,
1012        _block_table: &Self::Buffer,
1013        _tokens: usize,
1014        _q_heads: usize,
1015        _kv_heads: usize,
1016        _head_dim: usize,
1017        _pos_offset: usize,
1018        _eps: f32,
1019        _qk_mode: i32,
1020        _cache_len: usize,
1021        _block_size: usize,
1022        _max_num_blocks_per_seq: usize,
1023    ) -> Result<()> {
1024        Err(FerrumError::unsupported(
1025            "split_qkv_norm_rope_into_paged_cache_vllm not implemented for this backend",
1026        ))
1027    }
1028
1029    /// vLLM-layout variant of
1030    /// [`Self::split_qkv_norm_rope_into_paged_cache_varlen`]. Same signature
1031    /// — only the K/V cache layout changes.
1032    #[allow(clippy::too_many_arguments)]
1033    fn split_qkv_norm_rope_into_paged_cache_varlen_vllm(
1034        _ctx: &mut Self::Context,
1035        _qkv: &Self::Buffer,
1036        _q_norm_w: &Self::Buffer,
1037        _k_norm_w: &Self::Buffer,
1038        _cos: &Self::Buffer,
1039        _sin: &Self::Buffer,
1040        _q_out: &mut Self::Buffer,
1041        _cache_k: &mut Self::Buffer,
1042        _cache_v: &mut Self::Buffer,
1043        _cu_seqlens_q: &Self::Buffer,
1044        _pos_offsets: &Self::Buffer,
1045        _block_tables: &Self::Buffer,
1046        _num_seqs: usize,
1047        _m_total: usize,
1048        _q_heads: usize,
1049        _kv_heads: usize,
1050        _head_dim: usize,
1051        _eps: f32,
1052        _qk_mode: i32,
1053        _block_size: usize,
1054        _max_blocks_per_seq: usize,
1055    ) -> Result<()> {
1056        Err(FerrumError::unsupported(
1057            "split_qkv_norm_rope_into_paged_cache_varlen_vllm not implemented for this backend",
1058        ))
1059    }
1060
1061    /// vLLM `paged_attention_v2` — multi-partition split-K decode attention
1062    /// reading the vLLM K/V layout. `q_len` is implicitly 1 (decode only;
1063    /// vLLM's v2 kernel does not support q_len > 1). `max_seq_len` is the
1064    /// max kv_len across the batch — used to size the partition reduction.
1065    #[allow(clippy::too_many_arguments)]
1066    fn paged_decode_attention_v2(
1067        _ctx: &mut Self::Context,
1068        _q: &Self::Buffer,
1069        _k_pool: &Self::Buffer,
1070        _v_pool: &Self::Buffer,
1071        _out: &mut Self::Buffer,
1072        _block_tables: &Self::Buffer,
1073        _context_lens: &Self::Buffer,
1074        _num_seqs: usize,
1075        _num_heads: usize,
1076        _num_kv_heads: usize,
1077        _head_dim: usize,
1078        _block_size: usize,
1079        _max_num_blocks_per_seq: usize,
1080        _max_seq_len: usize,
1081    ) -> Result<()> {
1082        Err(FerrumError::unsupported(
1083            "paged_decode_attention_v2 not implemented for this backend",
1084        ))
1085    }
1086
1087    /// q_len>1 prefill/chunk-prefill attention over vLLM-layout paged KV.
1088    /// This keeps cache layout consistent when `FERRUM_USE_VLLM_PAGED_ATTN=1`
1089    /// and the prompt path writes K/V in the layout consumed later by
1090    /// `paged_decode_attention_v2`.
1091    #[allow(clippy::too_many_arguments)]
1092    fn paged_varlen_attention_vllm_layout(
1093        _ctx: &mut Self::Context,
1094        _q: &Self::Buffer,
1095        _k_pool: &Self::Buffer,
1096        _v_pool: &Self::Buffer,
1097        _out: &mut Self::Buffer,
1098        _block_tables: &Self::Buffer,
1099        _context_lens: &Self::Buffer,
1100        _num_seqs: usize,
1101        _num_heads: usize,
1102        _num_kv_heads: usize,
1103        _head_dim: usize,
1104        _block_size: usize,
1105        _max_num_blocks_per_seq: usize,
1106        _q_len: usize,
1107    ) -> Result<()> {
1108        Err(FerrumError::unsupported(
1109            "paged_varlen_attention_vllm_layout not implemented for this backend",
1110        ))
1111    }
1112
1113    /// Variable-length paged attention over vLLM-layout paged KV.
1114    ///
1115    /// Unlike [`Self::paged_varlen_attention_vllm_layout`], this accepts the
1116    /// same varlen index tensors as [`Self::paged_varlen_attention`] and writes
1117    /// token-major output directly. It is the unified mixed-batch companion for
1118    /// `split_qkv_norm_rope_into_paged_cache_varlen_vllm`.
1119    #[allow(clippy::too_many_arguments)]
1120    fn paged_varlen_attention_vllm(
1121        _ctx: &mut Self::Context,
1122        _q: &Self::Buffer,
1123        _k_pool: &Self::Buffer,
1124        _v_pool: &Self::Buffer,
1125        _out: &mut Self::Buffer,
1126        _cu_seqlens_q: &Self::Buffer,
1127        _pos_offsets: &Self::Buffer,
1128        _block_tables: &Self::Buffer,
1129        _num_seqs: usize,
1130        _total_q_tokens: usize,
1131        _max_kv_len: usize,
1132        _num_heads: usize,
1133        _num_kv_heads: usize,
1134        _head_dim: usize,
1135        _block_size: usize,
1136        _max_num_blocks_per_seq: usize,
1137    ) -> Result<()> {
1138        Err(FerrumError::unsupported(
1139            "paged_varlen_attention_vllm not implemented for this backend",
1140        ))
1141    }
1142
1143    /// Q-tiled vLLM-layout varlen attention. `tile_seqs` and `tile_starts`
1144    /// describe a compact list of q-token tiles, avoiding empty grid blocks
1145    /// for mixed batches that contain both long prefill items and q_len=1
1146    /// decode items. Semantics match [`Self::paged_varlen_attention_vllm`].
1147    #[allow(clippy::too_many_arguments)]
1148    fn paged_varlen_attention_vllm_tiled_q4(
1149        _ctx: &mut Self::Context,
1150        _q: &Self::Buffer,
1151        _k_pool: &Self::Buffer,
1152        _v_pool: &Self::Buffer,
1153        _out: &mut Self::Buffer,
1154        _cu_seqlens_q: &Self::Buffer,
1155        _pos_offsets: &Self::Buffer,
1156        _block_tables: &Self::Buffer,
1157        _tile_seqs: &Self::Buffer,
1158        _tile_starts: &Self::Buffer,
1159        _num_tiles: usize,
1160        _max_kv_len: usize,
1161        _num_heads: usize,
1162        _num_kv_heads: usize,
1163        _head_dim: usize,
1164        _block_size: usize,
1165        _max_num_blocks_per_seq: usize,
1166    ) -> Result<()> {
1167        Err(FerrumError::unsupported(
1168            "paged_varlen_attention_vllm_tiled_q4 not implemented for this backend",
1169        ))
1170    }
1171}
1172
1173// ════════════════════════════════════════════════════════════════════════
// Capability bundles — readable alias-style traits over the supertrait set
1175// ════════════════════════════════════════════════════════════════════════
1176//
// Models declare what they need via these bundles instead of spelling out
// every supertrait. Because of the blanket impls below, any backend that
// satisfies the underlying supertraits automatically becomes an
// `LlmBackend` / `QuantLlmBackend` / `MoeLlmBackend`; no per-backend impl is needed.
1181
1182/// Minimum capability set for a decoder-only LLM: the core compute trait
1183/// plus paged-KV cache + graph-capture support. Every concrete backend
1184/// (CUDA / Metal / CPU) satisfies this.
1185pub trait LlmBackend: Backend + BackendGraph + BackendPagedKv {}
1186impl<T> LlmBackend for T where T: Backend + BackendGraph + BackendPagedKv {}
1187
1188/// LLM backend that also supports quantized weight loading (GPTQ Marlin
1189/// for CUDA; GGUF k-quant for Metal). Required by models that hold
1190/// `Box<dyn Linear<B>>` where the Linear impl might be a quant variant.
1191pub trait QuantLlmBackend: LlmBackend + BackendQuantMarlin + BackendQuantGguf {}
1192impl<T> QuantLlmBackend for T where T: LlmBackend + BackendQuantMarlin + BackendQuantGguf {}
1193
1194/// MoE-capable LLM backend: adds the fused MoE routing + post-op kernels
1195/// to the quant LLM bundle. Required by Qwen3-MoE / future MoE models.
1196pub trait MoeLlmBackend: QuantLlmBackend + BackendMoeFused {}
1197impl<T> MoeLlmBackend for T where T: QuantLlmBackend + BackendMoeFused {}
1198
1199// ════════════════════════════════════════════════════════════════════════
1200// KV cache dtype axis (dim 5 of the 5-dimension architecture)
1201// ════════════════════════════════════════════════════════════════════════
1202//
1203// Each model's KV cache has its own precision independent of the model's
1204// compute precision. vLLM 0.6+ ships INT8 / FP8 KV caches that halve KV
1205// memory at small (<1%) accuracy hit. Today ferrum's KV is hardcoded
1206// FP16 on CUDA / Metal — to support INT8/FP8 KV in a future PR, the
1207// type system needs an explicit axis.
1208//
1209// Phase 4 scope: scaffolding only. All concrete backends impl
1210// `BackendKvDtype<KvFp16>` so existing models keep working unchanged.
1211// Future PR: implement BackendKvDtype<KvInt8> on CUDA + a new model
1212// type-parameter `K: KvDtypeKind` to wire it through.
1213
1214// `KvDtypeKind` + `KvFp16` / `KvBf16` / `KvInt8` / `KvFp8` markers moved
1215// to `ferrum_interfaces::kv_dtype` (no GPU deps, so the right place is
1216// the contract crate). Re-exported here so existing callers keep
1217// compiling against `crate::backend::KvFp16` etc.
1218pub use ferrum_interfaces::kv_dtype::{KvBf16, KvDtypeKind, KvFp16, KvFp8, KvInt8};
1219
1220/// Capability-trait for backends that can store + read a KV cache of
1221/// type `K`.
1222///
1223/// The two associated types carry the K-specific storage shape:
1224///   - `KvBuffer`: per-layer K/V element storage. For `K = KvFp16` it
1225///     is the backend's normal `Self::Buffer` (FP16). For `K = KvInt8`
1226///     it is the backend's INT8 buffer (e.g. `CudaSlice<i8>` on CUDA).
1227///   - `KvScales`: per-token-per-kv-head scales. For `K = KvFp16` this
1228///     is the unit type `()` (no scales). For `K = KvInt8` / `KvFp8`
1229///     it is a backend-specific FP16 buffer.
1230///
1231/// Models that want INT8 KV use:
1232///   `where B: BackendKvDtype<KvInt8>`
1233/// — the buffers in `KvCache<B, KvInt8>` are then `CudaSlice<i8>` and
1234/// `CudaSlice<f16>`, distinct from the FP16 path's `Self::Buffer`.
1235pub trait BackendKvDtype<K: KvDtypeKind>: BackendPagedKv {
1236    /// Per-layer K/V element storage.
1237    type KvBuffer: Send + Sync;
1238    /// Per-token per-kv-head scale storage. `()` for FP16 (no scales).
1239    type KvScales: Send + Sync + Default;
1240}
1241
1242/// INT8 KV cache operations (Dim 5).
1243///
1244/// `BackendKvDtype<KvInt8>` only declares the storage types; it does not
1245/// know how to write INT8 K/V into a paged pool or run paged decode
1246/// attention against an INT8 cache. Those launchers live here so the
1247/// model layer can call them through a single `B: BackendInt8KvOps` bound
1248/// without dropping into backend-specific code.
1249///
1250/// Today only `CudaBackend` provides a real implementation (delegating to
1251/// [`crate::int8_kv::launch_int8_kv_cache_append`] and
1252/// [`crate::int8_kv::launch_int8_paged_decode_attention`]). Other backends
1253/// inherit the default `unimplemented!()` body — the registry factory
1254/// rejects `(Device::CPU/Metal, KvCacheDtype::Int8)` before the model
1255/// gets a chance to call into these.
1256#[allow(clippy::too_many_arguments)]
1257pub trait BackendInt8KvOps: Backend + BackendKvDtype<KvInt8> {
1258    /// Allocate the per-layer INT8 paged cache for one sequence.
1259    /// Default panics — backends without INT8 support never reach this
1260    /// path (factory rejects (Cpu/Metal, Int8) before ensure_kv runs).
1261    fn alloc_paged_int8_layer(
1262        _max_blocks_per_seq: usize,
1263        _block_size: usize,
1264        _num_kv_heads: usize,
1265        _head_dim: usize,
1266    ) -> KvCacheQuant<Self, KvInt8> {
1267        unimplemented!("alloc_paged_int8_layer not supported on this backend")
1268    }
1269
1270    /// Append `tokens` FP16 K/V values into the paged INT8 pool.
1271    /// `paged_block_indices` is the host-side mirror of the per-seq
1272    /// logical→physical block table (already populated at `ensure_kv` time
1273    /// — see `KvCacheQuant::paged_block_indices`). Passing the host slice
1274    /// avoids a per-token D2H + sync barrier; backend computes the slot
1275    /// mapping host-side, async-H2D's it, and chains the append kernel
1276    /// on the same stream — fully overlapping with prior work.
1277    /// `cache_len_before` is the current number of valid tokens; the
1278    /// backend quantizes FP16 → INT8 with per-(token, kv-head) FP16 scale
1279    /// and writes both into the layer's INT8 / scale buffers.
1280    fn int8_kv_append_paged(
1281        _ctx: &mut Self::Context,
1282        _k_in: &Self::Buffer,
1283        _v_in: &Self::Buffer,
1284        _layer_k: &mut <Self as BackendKvDtype<KvInt8>>::KvBuffer,
1285        _layer_v: &mut <Self as BackendKvDtype<KvInt8>>::KvBuffer,
1286        _layer_k_scales: &mut <Self as BackendKvDtype<KvInt8>>::KvScales,
1287        _layer_v_scales: &mut <Self as BackendKvDtype<KvInt8>>::KvScales,
1288        _paged_block_indices: &[u32],
1289        _cache_len_before: usize,
1290        _tokens: usize,
1291        _block_size: usize,
1292        _num_kv_heads: usize,
1293        _head_dim: usize,
1294    ) -> Result<()> {
1295        Err(FerrumError::unsupported(
1296            "int8_kv_append_paged not implemented for this backend",
1297        ))
1298    }
1299
1300    /// Run paged decode attention reading from an INT8 cache. Q is FP16,
1301    /// output is FP16; the kernel dequantizes K/V on the fly using the
1302    /// per-token scales. `valid_kv_len` is the post-append cache length
1303    /// (i.e. the kernel attends over `[0, valid_kv_len)` tokens).
1304    fn int8_paged_decode_attention(
1305        _ctx: &mut Self::Context,
1306        _q: &Self::Buffer,
1307        _layer_k: &<Self as BackendKvDtype<KvInt8>>::KvBuffer,
1308        _layer_v: &<Self as BackendKvDtype<KvInt8>>::KvBuffer,
1309        _layer_k_scales: &<Self as BackendKvDtype<KvInt8>>::KvScales,
1310        _layer_v_scales: &<Self as BackendKvDtype<KvInt8>>::KvScales,
1311        _block_table: &Self::Buffer,
1312        _output: &mut Self::Buffer,
1313        _num_q_heads: usize,
1314        _num_kv_heads: usize,
1315        _head_dim: usize,
1316        _valid_kv_len: usize,
1317        _block_size: usize,
1318        _scale: f32,
1319    ) -> Result<()> {
1320        Err(FerrumError::unsupported(
1321            "int8_paged_decode_attention not implemented for this backend",
1322        ))
1323    }
1324}
1325
// Cpu/Metal do NOT implement `BackendInt8KvOps`. Since the pivot to the
// `KvLayer<B>` trait, `KvInt8: KvLayer<B>` only holds where
// `B: BackendInt8KvOps`. So `LlamaFamilyModel<CpuBackend, KvInt8>` is a
// compile error: no INT8 `KvLayer` impl satisfies it. The type system
// enforces the constraint without runtime stubs.