Skip to main content

ferrum_kernels/backend/
traits.rs

//! Core Backend trait — the single abstraction over CUDA / Metal / CPU.
2
3use ferrum_types::{FerrumError, Result};
4
5pub use super::capabilities::{
6    BackendCollective, BackendGraph, BackendMoeFused, BackendQuantGguf, BackendQuantMarlin,
7};
8pub use super::types::MoeRouting;
9use super::types::{AttnConfig, KvCacheQuant, SrcDtype};
10
/// Maximum number of decoder layers a captured decode graph can address.
///
/// Per-layer call sites that share graph-captured host staging arrays use
/// this as the stride between distinct slots (the `slot` argument of the
/// batched per-cache ops). Only CUDA relies on it; other backends ignore
/// `slot`.
///
/// Models deeper than 64 layers (e.g. Llama-3-70B with 80 layers) exceed
/// this bound. That is accepted for v0.2 because such models do not fit on
/// a single 4090 anyway.
pub const MAX_LAYERS_FOR_GRAPH: usize = 64;
18
// Note: `TransformerConfig` / `AttnType` / `MlpType` / `RopeConfig` used to
// live here when `ModelRunner` needed a generic model config. They are now
// per-model (e.g. `Qwen3Config` in `ferrum-models::models::qwen3`), so each
// model carries exactly the architecture parameters it cares about, and the
// Backend trait stays model-agnostic.
24
25/// The core abstraction over CUDA / Metal / CPU.
26///
27/// Key design: operations take a `&mut Self::Context` which accumulates work.
28///   - **CPU**: Context is `()` — ops execute immediately.
29///   - **Metal**: Context is a `CommandBuffer` — ops encode into it, flushed on `sync()`.
30///   - **CUDA**: Context is a `CudaStream` — ops launch on the stream, synced on `sync()`.
31///
32/// `layer_forward` passes the context through all ops in a layer.
33/// `ModelRunner` calls `sync()` only when it needs results (e.g., reading logits).
34pub trait Backend: Send + Sync + Sized + 'static {
35    type Buffer: Send + Sync;
36
37    /// Execution context that accumulates GPU work.
38    ///   - CPU: `()` (no-op, ops execute inline)
39    ///   - Metal: wraps a CommandBuffer
40    ///   - CUDA: wraps a CudaStream
41    type Context;
42
43    /// GPU-side timer scoped to this backend. See `super::timer` —
44    /// CPU: `Instant`; Metal: sync-wrap; CUDA: `cuEvent`.
45    /// PLAYBOOK § 1.1.
46    type Timer: super::timer::BackendTimer<Self>;
47
48    /// Factory for `Self::Timer` — exists so call sites that have a
49    /// `<B: Backend>` parameter can spawn a timer without importing the
50    /// concrete impl. PLAYBOOK § 1.2.
51    fn make_timer() -> Self::Timer;
52
53    /// Opaque per-backend GPTQ weight representation.
54    ///   - CPU: dequantized f32 weights (run as regular GEMM)
55    ///   - Metal: `()` — unsupported; `gemm_gptq` errors
56    // Note (Phase 3e/4 + Phase C):
57    // - `type QuantStore` (GGUF k-quant storage) was removed in Phase 3e/4
58    //   — stacked-expert MoE GGUF goes through Box<dyn StackedExpertGgufLinear<Self>>
59    //   returned by `load_quant_experts`.
60    // - `type GptqStore` (Marlin/dequant GPTQ storage) was removed in Phase C
61    //   step 4e — stacked-expert Marlin MoE goes through
62    //   Arc<dyn MarlinExpertStack<Self>> returned by `load_gptq_stacked`,
63    //   and single-tensor GPTQ goes through Box<dyn Linear<Self>> returned
64    //   by `load_gptq`. Adding a new Marlin-capable backend is purely a
65    //   new MarlinExpertStack<NewBackend> impl — no Backend trait edits.
66
67    /// Create a new execution context (begin accumulating work).
68    fn new_context() -> Self::Context;
69
70    /// Run `body` while binding context-free backend operations to an
71    /// explicit device ordinal when the backend supports multi-device scopes.
72    ///
73    /// Most backends have no per-ordinal concept and use the default no-op
74    /// implementation. CUDA overrides this once its stream/context caches are
75    /// device-keyed, allowing layer-split stages to load and execute on their
76    /// selected GPU instead of relying on process-global defaults.
77    fn with_device_ordinal<R>(_device_ordinal: Option<usize>, body: impl FnOnce() -> R) -> R {
78        body()
79    }
80
81    /// Whether [`Self::with_device_ordinal`] actually switches backend
82    /// execution to the requested ordinal.
83    fn supports_device_ordinal_scope() -> bool {
84        false
85    }
86
87    /// Flush accumulated work and wait for completion.
88    /// CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.
89    fn sync(ctx: &mut Self::Context);
90
91    /// Prepare pending GPU work for a following host readback.
92    ///
93    /// Most backends either execute eagerly or synchronize as part of their
94    /// device-to-host copy. Metal shared-buffer reads use the CPU pointer
95    /// directly, so Metal must flush its command buffer before `to_vec`.
96    fn sync_before_host_readback(_ctx: &mut Self::Context) {}
97
98    /// Byte width of buffers returned by [`Self::alloc`].
99    ///
100    /// CUDA activation scratch is fp16, while Metal and CPU scratch are fp32.
101    /// Generic model code uses this for byte offsets into batched scratch
102    /// buffers without checking concrete backend types.
103    fn activation_elem_size_bytes() -> usize {
104        std::mem::size_of::<half::f16>()
105    }
106
107    /// Whether `LlamaFamilyModel::decode_batch_internal` may use its optimized
108    /// batched decode path on this backend.
109    ///
110    /// Backends that do not yet produce correct follow-up logits under
111    /// concurrent dense decode should override this to force the per-item
112    /// fallback until the optimized path is fixed.
113    fn supports_llama_family_batched_decode() -> bool {
114        true
115    }
116
117    // Graph capability moved to the `BackendGraph` supertrait at the end
118    // of this file. CUDA implements its overrides; Metal/CPU inherit
119    // unsupported defaults via empty `impl BackendGraph for X {}` blocks.
120
121    // ── GPTQ (INT4 quantization) ────────────────────────────────────────
122    //
123    // Two-step: load (once per weight) → gemm (per forward). The store
124    // holds whatever backend-specific format is fastest; caller code
125    // (GptqLinear) is dtype-agnostic.
126
127    /// Zero the first `len` elements of a Self::Buffer. CUDA path uses
128    /// cuMemsetD16Async; default returns unsupported.
129    fn zero_buffer(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize) -> Result<()> {
130        Err(FerrumError::unsupported(
131            "zero_buffer not implemented for this backend",
132        ))
133    }
134
135    /// Phase D step 2+3: unified typed allocator. Replaces per-dtype
136    /// `alloc_u32` / `alloc_typed_i32` / etc. The buffer is dtype-
137    /// tagged at the wrapper level (`CudaBuf::U32`, `MetalBuf` with
138    /// `Dtype::U32`, `CpuBuf::U32`), so reads/writes through `.as_<T>()`
139    /// accessors get the correct byte count automatically.
140    fn alloc_typed(dtype: super::Dtype, n: usize) -> Self::Buffer;
141
142    /// Upload typed host data — replaces `from_slice_i32` /
143    /// `from_slice_u32` etc. The host element type `T` carries its
144    /// `Dtype` via the `HostDtype` marker so dispatch in the impl
145    /// is a one-line `match T::DTYPE`.
146    fn from_slice_typed<T: super::HostDtype>(data: &[T]) -> Self::Buffer;
147
148    /// In-place typed write — replaces `write_u32` / `write_i32_into`
149    /// / `write_f32_into`. The buffer must already be dtype-tagged
150    /// matching `T::DTYPE` (typically alloc'd via `alloc_typed` or
151    /// `from_slice_typed`).
152    fn write_typed<T: super::HostDtype>(
153        ctx: &mut Self::Context,
154        dst: &mut Self::Buffer,
155        data: &[T],
156    );
157
158    // ── GEMM ────────────────────────────────────────────────────────────
159
160    fn gemm(
161        ctx: &mut Self::Context,
162        a: &Self::Buffer,
163        b: &Self::Buffer,
164        out: &mut Self::Buffer,
165        m: usize,
166        n: usize,
167        k: usize,
168    );
169
170    // ── Norms ───────────────────────────────────────────────────────────
171
172    fn rms_norm(
173        ctx: &mut Self::Context,
174        x: &Self::Buffer,
175        w: &Self::Buffer,
176        eps: f32,
177        out: &mut Self::Buffer,
178        tokens: usize,
179        dim: usize,
180    );
181
182    fn fused_add_rms_norm(
183        ctx: &mut Self::Context,
184        residual: &mut Self::Buffer,
185        x: &Self::Buffer,
186        w: &Self::Buffer,
187        eps: f32,
188        out: &mut Self::Buffer,
189        tokens: usize,
190        dim: usize,
191    );
192
193    // ── Attention ───────────────────────────────────────────────────────
194
195    fn flash_attention(
196        ctx: &mut Self::Context,
197        q: &Self::Buffer,
198        k: &Self::Buffer,
199        v: &Self::Buffer,
200        out: &mut Self::Buffer,
201        batch: usize,
202        q_len: usize,
203        kv_len: usize,
204        pos_offset: usize,
205        cfg: &AttnConfig,
206    );
207
208    /// Multi-Head Latent Attention — DeepSeek V2 / V3's compressed-KV
209    /// attention variant. Extension point only; no backend implements it
210    /// yet. DeepSeek V3 landing in Phase D/E will fill this in.
211    ///
212    /// `q`: full Q `[batch, num_heads, q_len, head_dim]`
213    /// `kv_compressed`: latent KV `[batch, kv_len, kv_lora_rank]`
214    /// `kv_rope`: per-position rope-applied key heads `[batch, kv_len, qk_rope_head_dim]`
215    /// `out`: `[batch, num_heads, q_len, head_dim]`
216    #[allow(clippy::too_many_arguments)]
217    fn mla_attention(
218        _ctx: &mut Self::Context,
219        _q: &Self::Buffer,
220        _kv_compressed: &Self::Buffer,
221        _kv_rope: &Self::Buffer,
222        _out: &mut Self::Buffer,
223        _batch: usize,
224        _q_len: usize,
225        _kv_len: usize,
226        _pos_offset: usize,
227        _cfg: &AttnConfig,
228        _kv_lora_rank: usize,
229        _qk_rope_head_dim: usize,
230    ) -> Result<()> {
231        Err(FerrumError::unsupported(
232            "mla_attention not implemented for this backend; required by \
233             DeepSeek V2/V3 (Phase D/E)",
234        ))
235    }
236
237    // ── Element-wise ────────────────────────────────────────────────────
238    //
239    // Models use `add_inplace` for residual updates and `copy_slice` for the
240    // row-extraction step in prefill. Offset-free copy / non-inplace add are
241    // not needed by the current Model-as-Code path; they can return later if
242    // a model actually requires them.
243
244    /// Copy `len` floats from `src[src_offset..]` to `dst[dst_offset..]`.
245    ///
246    /// Needed for Qwen3Model::prefill to pluck the last token's hidden state
247    /// out of `residual[seq_len, h]` without round-tripping through host RAM.
248    /// `Backend::copy` is the offset-free variant; `copy_slice` additionally
249    /// supports non-zero source and destination offsets.
250    fn copy_slice(
251        ctx: &mut Self::Context,
252        src: &Self::Buffer,
253        src_offset: usize,
254        dst: &mut Self::Buffer,
255        dst_offset: usize,
256        len: usize,
257    );
258
259    // ── Embedding ───────────────────────────────────────────────────────
260
261    fn embedding_lookup(
262        ctx: &mut Self::Context,
263        table: &Self::Buffer,
264        ids: &[u32],
265        out: &mut Self::Buffer,
266        dim: usize,
267    );
268
269    /// Device-buffer variant of `embedding_lookup` for graph-capturable
270    /// MoE routing — the gather step before phase-1 GEMM in
271    /// `moe_forward_bucketed`. The host-slice `embedding_lookup` does
272    /// `clone_htod(ids)` internally, which records stale host pointers
273    /// under CUDA Graph capture replay.
274    ///
275    /// `ids: &Self::Buffer` must be a device I32 buffer of `batch`
276    /// elements (e.g. `Qwen3MoeScratch::route_packed_idx_dev`).
277    /// `batch` is passed explicitly since a typed CudaBuf carries
278    /// its element count but the caller often wants a partial gather.
279    ///
280    /// Default impl: round-trip via `to_vec` + dispatch the host-slice
281    /// variant. CUDA overrides.
282    fn embedding_lookup_dev(
283        ctx: &mut Self::Context,
284        table: &Self::Buffer,
285        ids: &Self::Buffer,
286        out: &mut Self::Buffer,
287        batch: usize,
288        dim: usize,
289    ) {
290        // Default: round-trip. CUDA overrides with a direct device-arg
291        // kernel launch (no clone_htod).
292        let ids_host_f32 = Self::to_vec(ids, batch);
293        let ids_host_u32: Vec<u32> = ids_host_f32.iter().map(|x| x.to_bits()).collect();
294        Self::embedding_lookup(ctx, table, &ids_host_u32, out, dim);
295    }
296
297    // ── Transformer-specific fused ops ─────────────────────────────────
298    // These avoid CPU round-trips for data layout transformations.
299
300    /// Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers.
301    /// Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]
302    fn split_qkv(
303        ctx: &mut Self::Context,
304        qkv: &Self::Buffer,
305        q: &mut Self::Buffer,
306        k: &mut Self::Buffer,
307        v: &mut Self::Buffer,
308        tokens: usize,
309        q_dim: usize,
310        kv_dim: usize,
311    );
312
313    /// Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im],
314    /// then compute SiLU(gate) * up → out [tokens, im].
315    fn fused_silu_mul_split(
316        ctx: &mut Self::Context,
317        gate_up: &Self::Buffer,
318        out: &mut Self::Buffer,
319        tokens: usize,
320        im: usize,
321    );
322
323    /// Fused QK-norm + RoPE + transpose-to-head-major.
324    ///
325    /// `mode` selects the operation:
326    ///   0 = transpose only (typical for V, which needs no norm and no RoPE)
327    ///   1 = per-head RMS norm + RoPE + transpose  (Q/K with QK-norm, Qwen3)
328    ///   2 = RoPE + transpose                       (Q/K without QK-norm, Llama/Mistral)
329    ///
330    /// input:   `[tokens, heads, head_dim]`  (token-major, output of split_qkv)
331    /// output:  `[heads, tokens, head_dim]`  (head-major, ready for flash_attn / kv_cache_append)
332    ///
333    /// `pos_offset` is the position of token 0 (decode uses current seq len;
334    /// prefill uses 0). Within the batch, positions are taken as `pos_offset + i`.
335    ///
336    /// This is the primary attention-input preparation op. Backends that have a
337    /// fused kernel (Metal's `qk_norm_rope_transpose_f32`) will be dramatically
338    /// faster than composing norm + rope + transpose separately; the CPU
339    /// fallback lowers to the individual ops.
340    #[allow(clippy::too_many_arguments)]
341    fn qk_norm_rope(
342        ctx: &mut Self::Context,
343        input: &Self::Buffer,
344        norm_w: &Self::Buffer,
345        cos: &Self::Buffer,
346        sin: &Self::Buffer,
347        output: &mut Self::Buffer,
348        tokens: usize,
349        heads: usize,
350        head_dim: usize,
351        pos_offset: usize,
352        eps: f32,
353        mode: i32,
354    );
355
356    /// Batched kv_cache_append across M caches in one launch. Each item
357    /// writes its (head-major) K-or-V row into its own cache at offset
358    /// read from `cache_lens[i]`. Replaces M sequential
359    /// `kv_cache_append_head_major` calls with a single dispatch.
360    ///
361    /// `new_data` layout: `[m, nkv, hd]` item-major (each item's slice
362    /// is contiguous, identical to the `k/v_normed_batched` produced by
363    /// `qk_norm_rope_batched_per_item`).
364    /// `caches`: per-cache `[nkv, capacity, hd]` head-major.
365    /// `cache_lens`: device buffer (u32 storage, length ≥ m). Caller
366    /// fills via `B::write_u32_into` BEFORE the call. Required for
367    /// CUDA-graph capture: the kernel reads from this stable device
368    /// buffer, so a captured graph can be replayed with new lens by
369    /// just rewriting the buffer between launches.
370    fn kv_cache_append_batched_per_cache(
371        _ctx: &mut Self::Context,
372        _caches: &[&Self::Buffer],
373        _new_data: &Self::Buffer,
374        _cache_lens: &Self::Buffer,
375        _capacity: usize,
376        _m: usize,
377        _nkv: usize,
378        _hd: usize,
379        _slot: usize,
380    ) -> Result<()> {
381        Err(FerrumError::unsupported(
382            "kv_cache_append_batched_per_cache not implemented for this backend",
383        ))
384    }
385
386    /// Batched flash_attention across M decode caches in one launch.
387    /// Replaces the per-item `flash_attention(q_len=1, ...)` × M
388    /// loop in the non-paged batched-decode path.
389    ///
390    /// API takes Vec<&Buffer> for the per-cache K/V buffers (each
391    /// `[nkv, capacity, hd]` head-major) plus host-side `kv_lens`.
392    /// Backends that implement it must extract per-cache device
393    /// pointers, build the device arrays the kernel needs, and launch
394    /// one kernel covering all M items.
395    ///
396    /// `q` layout: [m, nq, hd] item-major (matches the
397    /// `qk_norm_rope_batched_per_item` output for q_len=1).
398    /// `out` layout: [m, nq, hd] item-major — written directly into
399    /// the caller's batched attn_out buffer, no per-item copy needed.
400    ///
401    /// CUDA-only for now (kernel `batched_decode_attention` exists in
402    /// `kernels/batched_decode_attention.cu`).
403    /// `kv_lens`: device buffer (u32 storage, length ≥ m) — same
404    /// design as `kv_cache_append_batched_per_cache::cache_lens`.
405    fn flash_attention_batched_per_cache(
406        _ctx: &mut Self::Context,
407        _q: &Self::Buffer,
408        _k_caches: &[&Self::Buffer],
409        _v_caches: &[&Self::Buffer],
410        _kv_lens: &Self::Buffer,
411        _out: &mut Self::Buffer,
412        _nq: usize,
413        _nkv: usize,
414        _hd: usize,
415        _scale: f32,
416        _max_valid_kv: usize,
417        _capacity: usize,
418        _slot: usize,
419    ) -> Result<()> {
420        Err(FerrumError::unsupported(
421            "flash_attention_batched_per_cache not implemented for this backend",
422        ))
423    }
424
425    /// Batched per-item-position variant of `qk_norm_rope` for the
426    /// non-paged batched-decode path. Each of the `m` items has its own
427    /// absolute RoPE position (read from a device i32 buffer of length
428    /// `m`). Layout is item-major in *both* input and output:
429    ///
430    ///   input  [m, heads, head_dim]
431    ///   output [m, heads, head_dim]   (no head-major transpose)
432    ///
433    /// Item-major output keeps the per-item flash_attention slice
434    /// contiguous (`output[i * heads * head_dim ..]` is item i's whole
435    /// Q tensor in head-major-equivalent layout for q_len=1).
436    ///
437    /// Replaces the M sequential single-item launches in the existing
438    /// `forward_layer_batched_decode` path with one batched dispatch.
439    /// CUDA-only for now; other backends fall through to the default
440    /// `unsupported` and the caller falls back to the per-item loop.
441    fn qk_norm_rope_batched_per_item(
442        _ctx: &mut Self::Context,
443        _input: &Self::Buffer,
444        _norm_w: &Self::Buffer,
445        _cos: &Self::Buffer,
446        _sin: &Self::Buffer,
447        _output: &mut Self::Buffer,
448        _positions: &Self::Buffer,
449        _m: usize,
450        _heads: usize,
451        _head_dim: usize,
452        _eps: f32,
453        _mode: i32,
454    ) -> Result<()> {
455        Err(FerrumError::unsupported(
456            "qk_norm_rope_batched_per_item not implemented for this backend",
457        ))
458    }
459
460    /// Fused split-QKV + QK-norm + RoPE + head-major transpose.
461    ///
462    /// Single-dispatch replacement for the (`split_qkv` → 3× `qk_norm_rope`)
463    /// chain on the decode-attention prelude. Reads the linear-layer
464    /// fused-QKV output once and writes head-major Q/K/V directly into
465    /// attention scratch.
466    ///
467    /// `qkv` layout: `[tokens, q_heads*hd + 2*kv_heads*hd]`.
468    /// `q_out`: `[q_heads, tokens, hd]`. `k_out`/`v_out`: `[kv_heads, tokens, hd]`.
469    /// `qk_mode`: 1 = norm + half-split RoPE for Q/K (Qwen3 with QK-norm),
470    ///            2 = half-split RoPE only for Q/K,
471    ///            3 = interleaved RoPE only for Q/K (GGUF LLaMA / llama.cpp layout).
472    /// V always falls through to transpose-only.
473    ///
474    /// Default returns Unsupported. Backends that implement it are
475    /// expected to be dramatically faster than the four-dispatch chain.
476    #[allow(clippy::too_many_arguments)]
477    fn split_qkv_norm_rope(
478        _ctx: &mut Self::Context,
479        _qkv: &Self::Buffer,
480        _q_norm_w: &Self::Buffer,
481        _k_norm_w: &Self::Buffer,
482        _cos: &Self::Buffer,
483        _sin: &Self::Buffer,
484        _q_out: &mut Self::Buffer,
485        _k_out: &mut Self::Buffer,
486        _v_out: &mut Self::Buffer,
487        _tokens: usize,
488        _q_heads: usize,
489        _kv_heads: usize,
490        _head_dim: usize,
491        _pos_offset: usize,
492        _eps: f32,
493        _qk_mode: i32,
494    ) -> Result<()> {
495        Err(FerrumError::unsupported(
496            "split_qkv_norm_rope not implemented for this backend",
497        ))
498    }
499
500    /// Variant of [`Backend::split_qkv_norm_rope`] that writes the new
501    /// K and V directly into pre-allocated head-major KV cache buffers
502    /// at slot `[kv_heads, cache_len .. cache_len + tokens, hd]`.
503    /// Eliminates the trailing `kv_cache_append_head_major` dispatch on
504    /// the decode hot path. Q still lands in per-token head-major
505    /// scratch (flash-attention reads it as the query).
506    ///
507    /// Default returns Unsupported. Backends without the fused kernel
508    /// can keep using `split_qkv_norm_rope` + `kv_cache_append_head_major`.
509    #[allow(clippy::too_many_arguments)]
510    fn split_qkv_norm_rope_into_cache(
511        _ctx: &mut Self::Context,
512        _qkv: &Self::Buffer,
513        _q_norm_w: &Self::Buffer,
514        _k_norm_w: &Self::Buffer,
515        _cos: &Self::Buffer,
516        _sin: &Self::Buffer,
517        _q_out: &mut Self::Buffer,
518        _cache_k: &mut Self::Buffer,
519        _cache_v: &mut Self::Buffer,
520        _tokens: usize,
521        _q_heads: usize,
522        _kv_heads: usize,
523        _head_dim: usize,
524        _pos_offset: usize,
525        _eps: f32,
526        _qk_mode: i32,
527        _cache_len: usize,
528        _cache_capacity: usize,
529    ) -> Result<()> {
530        Err(FerrumError::unsupported(
531            "split_qkv_norm_rope_into_cache not implemented for this backend",
532        ))
533    }
534
535    // Phase D step 2: alloc_u32 / write_u32 deleted. Callers use the
536    // unified `alloc_typed(Dtype::U32, n)` + `write_typed(&[u32])` API
537    // declared above.
538
539    /// Append new K/V into a pre-allocated head-major cache buffer.
540    ///
541    /// `cache_k` / `cache_v`: `[nkv, capacity, hd]` (head-major, pre-allocated)
542    /// `new_k_head_major` / `new_v_head_major`: `[nkv, new_tokens, hd]`
543    ///   — produced directly by `qk_norm_rope`, no extra transpose needed.
544    ///
545    /// In-place append at slot `[nkv, cache_len..cache_len+new_tokens, hd]`.
546    /// Caller owns `cache_len` bookkeeping.
547    #[allow(clippy::too_many_arguments)]
548    fn kv_cache_append_head_major(
549        ctx: &mut Self::Context,
550        cache_k: &mut Self::Buffer,
551        cache_v: &mut Self::Buffer,
552        cache_len: usize,
553        cache_capacity: usize,
554        new_k_head_major: &Self::Buffer,
555        new_v_head_major: &Self::Buffer,
556        new_tokens: usize,
557        nkv: usize,
558        hd: usize,
559    );
560
561    /// Transpose [heads, tokens, dim] → [tokens, heads, dim].
562    /// Called after `flash_attention` to restore token-major layout for O-proj.
563    fn transpose_head_to_token(
564        ctx: &mut Self::Context,
565        src: &Self::Buffer,
566        dst: &mut Self::Buffer,
567        tokens: usize,
568        heads: usize,
569        dim: usize,
570    );
571
572    /// Inverse of `transpose_head_to_token`: [tokens, heads, dim] →
573    /// [heads, tokens, dim]. Used by the CUDA `paged_decode_attention`
574    /// wrapper to convert `paged_varlen_attention`'s token-major output
575    /// back to the head-major layout that Qwen3MoeModel expects.
576    /// Default panics — backends without a paged-KV CUDA path don't
577    /// hit this code.
578    fn transpose_token_to_head(
579        _ctx: &mut Self::Context,
580        _src: &Self::Buffer,
581        _dst: &mut Self::Buffer,
582        _tokens: usize,
583        _heads: usize,
584        _dim: usize,
585    ) {
586        panic!("transpose_token_to_head not implemented for this backend");
587    }
588
589    /// residual[i] += x[i] (in-place)
590    fn add_inplace(
591        ctx: &mut Self::Context,
592        residual: &mut Self::Buffer,
593        x: &Self::Buffer,
594        len: usize,
595    );
596
597    /// `dst[i] += scale * src[i]` — scalar-broadcast scaled add, in place.
598    ///
599    /// MoE per-token combine writes `out[b] += weight_k * expert_k(x[b])`
600    /// for each top-K expert; this primitive is the per-call accumulate.
601    /// Backends without a dedicated kernel can fall back to the default
602    /// implementation, which round-trips through host memory — correct,
603    /// but slow on a hot path. Override on any backend you actually
604    /// dispatch MoE on.
605    fn scaled_add_inplace(
606        _ctx: &mut Self::Context,
607        dst: &mut Self::Buffer,
608        src: &Self::Buffer,
609        scale: f32,
610        len: usize,
611    ) {
612        let mut dst_v = Self::to_vec(dst, len);
613        let src_v = Self::to_vec(src, len);
614        for i in 0..len {
615            dst_v[i] += scale * src_v[i];
616        }
617        // Move the new buffer into the slot pointed to by `dst`. Safe
618        // because `Self::Buffer: Send + Sync` and the old buffer is
619        // dropped here when overwritten.
620        *dst = Self::from_slice(&dst_v);
621    }
622
623    /// Strided variant of [`Backend::fused_silu_mul_split`] for the
624    /// bucketed MoE path: reads `gate_up` rows starting at
625    /// `in_row_offset`, writes `out` rows starting at `out_row_offset`.
626    #[allow(clippy::too_many_arguments)]
627    fn fused_silu_mul_split_strided(
628        _ctx: &mut Self::Context,
629        _gate_up: &Self::Buffer,
630        _in_row_offset: usize,
631        _out: &mut Self::Buffer,
632        _out_row_offset: usize,
633        _tokens: usize,
634        _intermediate: usize,
635    ) {
636        unimplemented!("fused_silu_mul_split_strided default impl missing");
637    }
638
639    /// Broadcast bias add: `data[r, c] += bias[c]` for every row.
640    /// Required by Bert / Clip / Whisper whose linear projections carry a bias.
641    fn add_bias(
642        ctx: &mut Self::Context,
643        data: &mut Self::Buffer,
644        bias: &Self::Buffer,
645        rows: usize,
646        cols: usize,
647    );
648
649    /// Full LayerNorm (mean + variance normalisation + affine), distinct from
650    /// the `rms_norm` used by Llama-family decoders.
651    ///   `out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]`
652    /// Where `mean` and `var` are reduced over the last dim (cols).
653    #[allow(clippy::too_many_arguments)]
654    fn layer_norm(
655        ctx: &mut Self::Context,
656        x: &Self::Buffer,
657        gamma: &Self::Buffer,
658        beta: &Self::Buffer,
659        eps: f32,
660        out: &mut Self::Buffer,
661        tokens: usize,
662        dim: usize,
663    );
664
665    /// Element-wise GELU activation (erf-based, matches PyTorch default).
666    fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize);
667
668    // ── Buffer management (context-free) ────────────────────────────────
669
670    fn alloc(len: usize) -> Self::Buffer;
671    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;
672    fn from_slice(data: &[f32]) -> Self::Buffer;
673
674    fn write_f32_to_activation(ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[f32]) {
675        if data.is_empty() {
676            return;
677        }
678        let src = Self::from_slice(data);
679        Self::copy_slice(ctx, &src, 0, dst, 0, data.len());
680    }
681
682    /// Greedy-decode fast path: GPU argmax over each row of a
683    /// `[m, n]` FP16 logits buffer, returning the m token indices on the
684    /// host. Saves `m × n × 2` bytes of D2H per call (e.g. 19.5 MB at
685    /// c=32, vocab=152064) and the host-side argmax scan (~150 µs × m).
686    ///
687    /// Default impl falls back to the slow path: full `to_vec` + host
688    /// argmax. CUDA overrides with a native kernel + tiny D2H (m × 4 B).
689    /// Backends that don't override pay the same cost as
690    /// `to_vec` + host argmax, so callers can call this unconditionally.
691    fn argmax_rows_f16(
692        _ctx: &mut Self::Context,
693        logits: &Self::Buffer,
694        m: usize,
695        n: usize,
696    ) -> Result<Vec<u32>> {
697        let host = Self::to_vec(logits, m * n);
698        let mut out = Vec::with_capacity(m);
699        for row in 0..m {
700            let slice = &host[row * n..(row + 1) * n];
701            let mut max_idx = 0usize;
702            let mut max_val = f32::NEG_INFINITY;
703            for (i, &v) in slice.iter().enumerate() {
704                if v > max_val {
705                    max_val = v;
706                    max_idx = i;
707                }
708            }
709            out.push(max_idx as u32);
710        }
711        Ok(out)
712    }
713
714    /// Load a weight tensor straight from its on-disk byte representation,
715    /// letting the backend pick its preferred storage dtype.
716    ///
717    /// Default impl upcasts bf16/f16 to f32 via an intermediate Vec, matching
718    /// pre-existing loader behaviour. Backends override this to go straight
719    /// from raw bytes into a native half-precision buffer (e.g. Metal with
720    /// `FERRUM_METAL_DTYPE=f16`), avoiding the transient 2× RAM spike.
721    fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer {
722        let data = src_dtype.to_f32_vec(raw);
723        Self::from_slice(&data)
724    }
725
726    // (The Phase A3 unified `gemm_quant(QuantWeights, QuantKind)` stub
727    // that used to live here is superseded by the `load_quant` /
728    // `gemm_quant(QuantStore)` pair earlier in this trait — same idea,
729    // but the store hides the per-kind buffer layout so callers don't
730    // have to construct a per-kind `QuantWeights<'_, Self>` packet.)
731}
732
// ════════════════════════════════════════════════════════════════════════
// BackendPagedKv capability (vLLM-style paged KV cache + paged attention)
// ════════════════════════════════════════════════════════════════════════
//
// Paged KV pool with block-table indirection, plus the paged attention
// kernel variants that read through that indirection. CUDA + Metal both
// implement the real kernels; CPU `impl BackendPagedKv for CpuBackend {}`
// inherits unsupported defaults.
741
742/// Capability-trait for backends that support paged KV cache + paged attention.
743pub trait BackendPagedKv: Backend {
744    /// Whether this backend has a paged-KV decode path
745    /// (`paged_decode_attention` etc.). Currently true for Metal, false
746    /// for CPU. Used to decide the default of `FERRUM_METAL_PAGED_KV` —
747    /// the `serve` path should opt in automatically when supported so
748    /// users get the bench-quality concurrent-decode numbers without
749    /// having to learn the flag.
750    fn supports_paged_kv() -> bool {
751        false
752    }
753    /// Pre-populate the per-slot device-pointer scratch arrays used by
754    /// the batched kernels (`kv_cache_append_batched_per_cache` and
755    /// `flash_attention_batched_per_cache`). Required by the CUDA-graph
756    /// capture path: the captured graph contains only kernel launches
757    /// (no captured `memcpy_htod`), so the device scratch must be fresh
758    /// when the graph replays.
759    ///
760    /// Caller passes flat layer-major slices: `k_caches[li * m + i]` and
761    /// `v_caches[li * m + i]`. Backend extracts each cache's device
762    /// pointer and writes into its corresponding slot in the device
763    /// scratch via SYNCHRONOUS memcpy (not captured by stream capture).
764    ///
765    /// CUDA-only; other backends fall through to the default
766    /// `unsupported` and the caller skips the population call.
767    fn populate_batched_pointers(
768        _ctx: &mut Self::Context,
769        _k_caches: &[&Self::Buffer],
770        _v_caches: &[&Self::Buffer],
771        _num_layers: usize,
772        _m: usize,
773    ) -> Result<()> {
774        Err(FerrumError::unsupported(
775            "populate_batched_pointers not implemented for this backend",
776        ))
777    }
778    /// Paged-KV variant of [`Self::split_qkv_norm_rope_into_cache`].
779    ///
780    /// Same fused split + qk-norm + RoPE, but K/V are written into a
781    /// paged pool `[num_blocks, kv_heads, block_size, head_dim]`
782    /// indexed via `block_table[logical_block]` → physical_block.
783    /// Q still goes to head-major scratch.
784    ///
785    /// Default returns Unsupported. Backends that lack a paged kernel
786    /// keep using the contiguous variant.
787    /// `qkv_byte_offset` / `q_out_byte_offset` let the caller pass a
788    /// slice of a larger batched buffer (used by the multi-seq paged
789    /// path in `decode_batch_internal`). For single-seq dispatch they
790    /// should be 0.
791    #[allow(clippy::too_many_arguments)]
792    fn split_qkv_norm_rope_into_paged_cache(
793        _ctx: &mut Self::Context,
794        _qkv: &Self::Buffer,
795        _qkv_byte_offset: u64,
796        _q_norm_w: &Self::Buffer,
797        _k_norm_w: &Self::Buffer,
798        _cos: &Self::Buffer,
799        _sin: &Self::Buffer,
800        _q_out: &mut Self::Buffer,
801        _q_out_byte_offset: u64,
802        _cache_k: &mut Self::Buffer,
803        _cache_v: &mut Self::Buffer,
804        _block_table: &Self::Buffer,
805        _tokens: usize,
806        _q_heads: usize,
807        _kv_heads: usize,
808        _head_dim: usize,
809        _pos_offset: usize,
810        _eps: f32,
811        _qk_mode: i32,
812        _cache_len: usize,
813        _block_size: usize,
814        _max_num_blocks_per_seq: usize,
815    ) -> Result<()> {
816        Err(FerrumError::unsupported(
817            "split_qkv_norm_rope_into_paged_cache not implemented for this backend",
818        ))
819    }
820    /// Paged-KV variant of [`Self::flash_attention`].
821    ///
822    /// Decode (`q_len == 1`):
823    ///   `q`/`out`: `[num_seqs, num_heads, head_dim]` (token-major)
824    ///
825    /// Causal prefill (`q_len > 1`, single seq):
826    ///   `q`/`out`: `[num_heads, q_len, head_dim]` (head-major — the
827    ///              layout produced by `split_qkv_norm_rope_into_paged_cache`)
828    ///   The kernel applies a per-q-token causal mask using
829    ///   `context_lens[seq]` as the FINAL kv_len (= `pos_offset + q_len`):
830    ///   token i sees positions `[0, context_lens - q_len + 1 + i)`.
831    ///
832    /// Common to both:
833    ///   `k_pool`/`v_pool`: `[num_blocks, num_kv_heads, block_size, head_dim]`
834    ///   `block_tables`: `[num_seqs, max_num_blocks_per_seq]` u32
835    ///   `context_lens`: `[num_seqs]` u32
836    ///
837    /// Backends without a paged kernel return Unsupported; callers are
838    /// expected to fall back to contiguous KV.
839    #[allow(clippy::too_many_arguments)]
840    fn paged_decode_attention(
841        _ctx: &mut Self::Context,
842        _q: &Self::Buffer,
843        _k_pool: &Self::Buffer,
844        _v_pool: &Self::Buffer,
845        _out: &mut Self::Buffer,
846        _block_tables: &Self::Buffer,
847        _context_lens: &Self::Buffer,
848        _num_seqs: usize,
849        _num_heads: usize,
850        _num_kv_heads: usize,
851        _head_dim: usize,
852        _block_size: usize,
853        _max_num_blocks_per_seq: usize,
854        _q_len: usize,
855    ) -> Result<()> {
856        Err(FerrumError::unsupported(
857            "paged_decode_attention not implemented for this backend",
858        ))
859    }
860    /// Capability: does this backend implement
861    /// `split_qkv_norm_rope_into_paged_cache_varlen` and
862    /// `paged_varlen_attention`? Required by the unified mixed-batch
863    /// forward path used by `LlamaFamilyModel::unified_forward`. Default
864    /// false; backends that ship the varlen kernels override.
865    fn supports_varlen_qkv() -> bool {
866        false
867    }
868    /// Varlen variant of [`Self::split_qkv_norm_rope_into_paged_cache`].
869    ///
870    /// Single launch covering ALL sequences in the batch. Reads
871    /// `pos_offsets[seq]`, `cu_seqlens_q[seq]`, and the per-seq
872    /// block_table from device buffers — graph-capturable (the per-iter
873    /// state is in buffers, not kernel scalars). Replaces the per-item
874    /// dispatch loop in `unified_forward_layer` with one call.
875    ///
876    /// Layouts:
877    /// - `qkv`: `[m_total, q_dim + 2 * kv_dim]` token-major
878    /// - `q_out`: `[m_total, q_heads, head_dim]` token-major (matches
879    ///   what `paged_varlen_attention` reads)
880    /// - `cache_k` / `cache_v`: paged pool same as `paged_varlen_attention`
881    /// - `cu_seqlens_q`: `[num_seqs + 1]` u32 prefix sum
882    /// - `pos_offsets`: `[num_seqs]` u32, starting kv_pos per seq
883    /// - `block_tables`: `[num_seqs, max_blocks_per_seq]` i32 stacked
884    #[allow(clippy::too_many_arguments)]
885    fn split_qkv_norm_rope_into_paged_cache_varlen(
886        _ctx: &mut Self::Context,
887        _qkv: &Self::Buffer,
888        _q_norm_w: &Self::Buffer,
889        _k_norm_w: &Self::Buffer,
890        _cos: &Self::Buffer,
891        _sin: &Self::Buffer,
892        _q_out: &mut Self::Buffer,
893        _cache_k: &mut Self::Buffer,
894        _cache_v: &mut Self::Buffer,
895        _cu_seqlens_q: &Self::Buffer,
896        _pos_offsets: &Self::Buffer,
897        _block_tables: &Self::Buffer,
898        _num_seqs: usize,
899        _m_total: usize,
900        _q_heads: usize,
901        _kv_heads: usize,
902        _head_dim: usize,
903        _eps: f32,
904        _qk_mode: i32,
905        _block_size: usize,
906        _max_blocks_per_seq: usize,
907    ) -> Result<()> {
908        Err(FerrumError::unsupported(
909            "split_qkv_norm_rope_into_paged_cache_varlen not implemented for this backend",
910        ))
911    }
912    /// Variable-length paged attention with GQA + causal mask.
913    ///
914    /// Supports a unified mixed batch where each sequence contributes
915    /// 1 (decode) or N (prefill chunk) query tokens — the workhorse for
916    /// chunked-prefill. See `kernels/paged_varlen_attention.cu` for the
917    /// kernel itself.
918    ///
919    /// Layouts:
920    /// - `q` / `out`: `[total_q_tokens, num_heads, head_dim]` (token-
921    ///   major, FP16). `total_q_tokens` = `cu_seqlens_q[num_seqs]`.
922    /// - `k_pool` / `v_pool`: paged block pool, layout matches
923    ///   `paged_decode_attention`.
924    /// - `cu_seqlens_q`: `[num_seqs + 1]` u32 prefix sum, with
925    ///   `cu_seqlens_q[0] = 0` and `cu_seqlens_q[num_seqs] = total_q_tokens`.
926    /// - `pos_offsets`: `[num_seqs]` u32, the starting absolute KV
927    ///   position of each seq's first q token (= prior `kv_len`).
928    /// - `block_tables`: `[num_seqs, max_num_blocks_per_seq]` i32 grid.
929    ///
930    /// Each query token attends causally to all KV positions
931    /// `[0, pos_offsets[s] + local_idx]`.
932    #[allow(clippy::too_many_arguments)]
933    fn paged_varlen_attention(
934        _ctx: &mut Self::Context,
935        _q: &Self::Buffer,
936        _k_pool: &Self::Buffer,
937        _v_pool: &Self::Buffer,
938        _out: &mut Self::Buffer,
939        _cu_seqlens_q: &Self::Buffer,
940        _pos_offsets: &Self::Buffer,
941        _block_tables: &Self::Buffer,
942        _num_seqs: usize,
943        _total_q_tokens: usize,
944        _max_kv_len: usize,
945        _num_heads: usize,
946        _num_kv_heads: usize,
947        _head_dim: usize,
948        _block_size: usize,
949        _max_num_blocks_per_seq: usize,
950    ) -> Result<()> {
951        Err(FerrumError::unsupported(
952            "paged_varlen_attention not implemented for this backend",
953        ))
954    }
955
956    /// Opt-in vLLM FlashAttention-2 FFI path for FA-layout paged KV.
957    ///
958    /// This is intentionally separate from [`Self::paged_varlen_attention`]:
959    /// it needs the final per-sequence KV lengths (`seq_lens`) and an explicit
960    /// LSE scratch buffer because the external FA2 runner writes softmax LSE.
961    /// Default returns Err(unsupported); CUDA overrides when a runtime shim is
962    /// provided via `FERRUM_FA2_DIRECT_FFI_SHIM`.
963    #[allow(clippy::too_many_arguments)]
964    fn paged_varlen_attention_fa2_ffi(
965        _ctx: &mut Self::Context,
966        _q: &Self::Buffer,
967        _k_pool: &Self::Buffer,
968        _v_pool: &Self::Buffer,
969        _out: &mut Self::Buffer,
970        _lse: &mut Self::Buffer,
971        _cu_seqlens_q: &Self::Buffer,
972        _seq_lens: &Self::Buffer,
973        _block_tables: &Self::Buffer,
974        _num_seqs: usize,
975        _total_q_tokens: usize,
976        _max_q_len: usize,
977        _max_kv_len: usize,
978        _num_heads: usize,
979        _num_kv_heads: usize,
980        _head_dim: usize,
981        _block_size: usize,
982        _max_num_blocks_per_seq: usize,
983    ) -> Result<()> {
984        Err(FerrumError::unsupported(
985            "paged_varlen_attention_fa2_ffi not implemented for this backend",
986        ))
987    }
988
989    /// Batched paged decode attention — multi-seq, single token per seq.
990    /// Faster path for the unified_forward layer when m_total == num_seqs
991    /// (every item is a single-token decode). Skips the cu_seqlens_q
992    /// linear scan that `paged_varlen_attention` does in the fully-mixed
993    /// case.
994    ///
995    /// Layouts:
996    ///   q              : [num_seqs, num_q_heads, head_dim]
997    ///   k_pool/v_pool  : paged pool (same as paged_varlen)
998    ///   block_tables   : [num_seqs, max_num_blocks_per_seq]
999    ///   valid_kv_lens  : [num_seqs] — current kv_len per seq
1000    ///   out            : [num_seqs, num_q_heads, head_dim]
1001    ///
1002    /// Default returns Err(unsupported); CUDA backend overrides.
1003    #[allow(clippy::too_many_arguments)]
1004    fn paged_batched_decode_attention(
1005        _ctx: &mut Self::Context,
1006        _q: &Self::Buffer,
1007        _k_pool: &Self::Buffer,
1008        _v_pool: &Self::Buffer,
1009        _out: &mut Self::Buffer,
1010        _block_tables: &Self::Buffer,
1011        _valid_kv_lens: &Self::Buffer,
1012        _num_seqs: usize,
1013        _max_kv_len: usize,
1014        _num_heads: usize,
1015        _num_kv_heads: usize,
1016        _head_dim: usize,
1017        _block_size: usize,
1018        _max_num_blocks_per_seq: usize,
1019    ) -> Result<()> {
1020        Err(FerrumError::unsupported(
1021            "paged_batched_decode_attention not implemented for this backend",
1022        ))
1023    }
1024
1025    /// Capability: backend has vLLM-layout paged KV write kernels and the
1026    /// `paged_attention_v2` decode kernel. Models that opt into this layout
1027    /// at construction time (via `FERRUM_USE_VLLM_PAGED_ATTN=1`) must
1028    /// dispatch ALL paged writes and reads through the `_vllm` variants —
1029    /// the layouts are not compatible. Default `false`.
1030    fn supports_vllm_paged_attn() -> bool {
1031        false
1032    }
1033
1034    /// vLLM-layout variant of
1035    /// [`Self::split_qkv_norm_rope_into_paged_cache`]. K/V are written in
1036    /// vLLM's `paged_attention_v2` layout: K is
1037    /// `[num_blocks, kv_heads, head_dim/x, block_size, x]` (x = 16/sizeof(elem)),
1038    /// V is `[num_blocks, kv_heads, head_dim, block_size]`. Q output and
1039    /// every other argument matches the non-vllm variant exactly so the
1040    /// model layer can swap dispatchers based on a single flag.
1041    #[allow(clippy::too_many_arguments)]
1042    fn split_qkv_norm_rope_into_paged_cache_vllm(
1043        _ctx: &mut Self::Context,
1044        _qkv: &Self::Buffer,
1045        _qkv_byte_offset: u64,
1046        _q_norm_w: &Self::Buffer,
1047        _k_norm_w: &Self::Buffer,
1048        _cos: &Self::Buffer,
1049        _sin: &Self::Buffer,
1050        _q_out: &mut Self::Buffer,
1051        _q_out_byte_offset: u64,
1052        _cache_k: &mut Self::Buffer,
1053        _cache_v: &mut Self::Buffer,
1054        _block_table: &Self::Buffer,
1055        _tokens: usize,
1056        _q_heads: usize,
1057        _kv_heads: usize,
1058        _head_dim: usize,
1059        _pos_offset: usize,
1060        _eps: f32,
1061        _qk_mode: i32,
1062        _cache_len: usize,
1063        _block_size: usize,
1064        _max_num_blocks_per_seq: usize,
1065    ) -> Result<()> {
1066        Err(FerrumError::unsupported(
1067            "split_qkv_norm_rope_into_paged_cache_vllm not implemented for this backend",
1068        ))
1069    }
1070
1071    /// vLLM-layout variant of
1072    /// [`Self::split_qkv_norm_rope_into_paged_cache_varlen`]. Same signature
1073    /// — only the K/V cache layout changes.
1074    #[allow(clippy::too_many_arguments)]
1075    fn split_qkv_norm_rope_into_paged_cache_varlen_vllm(
1076        _ctx: &mut Self::Context,
1077        _qkv: &Self::Buffer,
1078        _q_norm_w: &Self::Buffer,
1079        _k_norm_w: &Self::Buffer,
1080        _cos: &Self::Buffer,
1081        _sin: &Self::Buffer,
1082        _q_out: &mut Self::Buffer,
1083        _cache_k: &mut Self::Buffer,
1084        _cache_v: &mut Self::Buffer,
1085        _cu_seqlens_q: &Self::Buffer,
1086        _pos_offsets: &Self::Buffer,
1087        _block_tables: &Self::Buffer,
1088        _num_seqs: usize,
1089        _m_total: usize,
1090        _q_heads: usize,
1091        _kv_heads: usize,
1092        _head_dim: usize,
1093        _eps: f32,
1094        _qk_mode: i32,
1095        _block_size: usize,
1096        _max_blocks_per_seq: usize,
1097    ) -> Result<()> {
1098        Err(FerrumError::unsupported(
1099            "split_qkv_norm_rope_into_paged_cache_varlen_vllm not implemented for this backend",
1100        ))
1101    }
1102
1103    /// vLLM `paged_attention_v2` — multi-partition split-K decode attention
1104    /// reading the vLLM K/V layout. `q_len` is implicitly 1 (decode only;
1105    /// vLLM's v2 kernel does not support q_len > 1). `max_seq_len` is the
1106    /// max kv_len across the batch — used to size the partition reduction.
1107    #[allow(clippy::too_many_arguments)]
1108    fn paged_decode_attention_v2(
1109        _ctx: &mut Self::Context,
1110        _q: &Self::Buffer,
1111        _k_pool: &Self::Buffer,
1112        _v_pool: &Self::Buffer,
1113        _out: &mut Self::Buffer,
1114        _block_tables: &Self::Buffer,
1115        _context_lens: &Self::Buffer,
1116        _num_seqs: usize,
1117        _num_heads: usize,
1118        _num_kv_heads: usize,
1119        _head_dim: usize,
1120        _block_size: usize,
1121        _max_num_blocks_per_seq: usize,
1122        _max_seq_len: usize,
1123    ) -> Result<()> {
1124        Err(FerrumError::unsupported(
1125            "paged_decode_attention_v2 not implemented for this backend",
1126        ))
1127    }
1128
1129    /// q_len>1 prefill/chunk-prefill attention over vLLM-layout paged KV.
1130    /// This keeps cache layout consistent when `FERRUM_USE_VLLM_PAGED_ATTN=1`
1131    /// and the prompt path writes K/V in the layout consumed later by
1132    /// `paged_decode_attention_v2`.
1133    #[allow(clippy::too_many_arguments)]
1134    fn paged_varlen_attention_vllm_layout(
1135        _ctx: &mut Self::Context,
1136        _q: &Self::Buffer,
1137        _k_pool: &Self::Buffer,
1138        _v_pool: &Self::Buffer,
1139        _out: &mut Self::Buffer,
1140        _block_tables: &Self::Buffer,
1141        _context_lens: &Self::Buffer,
1142        _num_seqs: usize,
1143        _num_heads: usize,
1144        _num_kv_heads: usize,
1145        _head_dim: usize,
1146        _block_size: usize,
1147        _max_num_blocks_per_seq: usize,
1148        _q_len: usize,
1149    ) -> Result<()> {
1150        Err(FerrumError::unsupported(
1151            "paged_varlen_attention_vllm_layout not implemented for this backend",
1152        ))
1153    }
1154
1155    /// Variable-length paged attention over vLLM-layout paged KV.
1156    ///
1157    /// Unlike [`Self::paged_varlen_attention_vllm_layout`], this accepts the
1158    /// same varlen index tensors as [`Self::paged_varlen_attention`] and writes
1159    /// token-major output directly. It is the unified mixed-batch companion for
1160    /// `split_qkv_norm_rope_into_paged_cache_varlen_vllm`.
1161    #[allow(clippy::too_many_arguments)]
1162    fn paged_varlen_attention_vllm(
1163        _ctx: &mut Self::Context,
1164        _q: &Self::Buffer,
1165        _k_pool: &Self::Buffer,
1166        _v_pool: &Self::Buffer,
1167        _out: &mut Self::Buffer,
1168        _cu_seqlens_q: &Self::Buffer,
1169        _pos_offsets: &Self::Buffer,
1170        _block_tables: &Self::Buffer,
1171        _num_seqs: usize,
1172        _total_q_tokens: usize,
1173        _max_kv_len: usize,
1174        _num_heads: usize,
1175        _num_kv_heads: usize,
1176        _head_dim: usize,
1177        _block_size: usize,
1178        _max_num_blocks_per_seq: usize,
1179    ) -> Result<()> {
1180        Err(FerrumError::unsupported(
1181            "paged_varlen_attention_vllm not implemented for this backend",
1182        ))
1183    }
1184
1185    /// Q-tiled vLLM-layout varlen attention. `tile_seqs` and `tile_starts`
1186    /// describe a compact list of q-token tiles, avoiding empty grid blocks
1187    /// for mixed batches that contain both long prefill items and q_len=1
1188    /// decode items. Semantics match [`Self::paged_varlen_attention_vllm`].
1189    #[allow(clippy::too_many_arguments)]
1190    fn paged_varlen_attention_vllm_tiled_q4(
1191        _ctx: &mut Self::Context,
1192        _q: &Self::Buffer,
1193        _k_pool: &Self::Buffer,
1194        _v_pool: &Self::Buffer,
1195        _out: &mut Self::Buffer,
1196        _cu_seqlens_q: &Self::Buffer,
1197        _pos_offsets: &Self::Buffer,
1198        _block_tables: &Self::Buffer,
1199        _tile_seqs: &Self::Buffer,
1200        _tile_starts: &Self::Buffer,
1201        _num_tiles: usize,
1202        _max_kv_len: usize,
1203        _num_heads: usize,
1204        _num_kv_heads: usize,
1205        _head_dim: usize,
1206        _block_size: usize,
1207        _max_num_blocks_per_seq: usize,
1208    ) -> Result<()> {
1209        Err(FerrumError::unsupported(
1210            "paged_varlen_attention_vllm_tiled_q4 not implemented for this backend",
1211        ))
1212    }
1213}
1214
1215// ════════════════════════════════════════════════════════════════════════
1216// Capability bundles — readable type aliases over the supertrait set
1217// ════════════════════════════════════════════════════════════════════════
1218//
1219// Models declare what they need via these bundles instead of spelling out
1220// every supertrait. Rust auto-derives the impl via blanket impls below,
1221// so any backend that satisfies the underlying supertraits automatically
1222// becomes a `LlmBackend` / `QuantLlmBackend` / `MoeLlmBackend`.
1223
1224/// Minimum capability set for a decoder-only LLM: the core compute trait
1225/// plus paged-KV cache + graph-capture support. Every concrete backend
1226/// (CUDA / Metal / CPU) satisfies this.
1227pub trait LlmBackend: Backend + BackendGraph + BackendPagedKv {}
1228impl<T> LlmBackend for T where T: Backend + BackendGraph + BackendPagedKv {}
1229
1230/// LLM backend that also supports quantized weight loading (GPTQ Marlin
1231/// for CUDA; GGUF k-quant for Metal). Required by models that hold
1232/// `Box<dyn Linear<B>>` where the Linear impl might be a quant variant.
1233pub trait QuantLlmBackend: LlmBackend + BackendQuantMarlin + BackendQuantGguf {}
1234impl<T> QuantLlmBackend for T where T: LlmBackend + BackendQuantMarlin + BackendQuantGguf {}
1235
1236/// MoE-capable LLM backend: adds the fused MoE routing + post-op kernels
1237/// to the quant LLM bundle. Required by Qwen3-MoE / future MoE models.
1238pub trait MoeLlmBackend: QuantLlmBackend + BackendMoeFused {}
1239impl<T> MoeLlmBackend for T where T: QuantLlmBackend + BackendMoeFused {}
1240
1241// ════════════════════════════════════════════════════════════════════════
1242// KV cache dtype axis (dim 5 of the 5-dimension architecture)
1243// ════════════════════════════════════════════════════════════════════════
1244//
1245// Each model's KV cache has its own precision independent of the model's
1246// compute precision. vLLM 0.6+ ships INT8 / FP8 KV caches that halve KV
1247// memory at small (<1%) accuracy hit. Today ferrum's KV is hardcoded
1248// FP16 on CUDA / Metal — to support INT8/FP8 KV in a future PR, the
1249// type system needs an explicit axis.
1250//
1251// Phase 4 scope: scaffolding only. All concrete backends impl
1252// `BackendKvDtype<KvFp16>` so existing models keep working unchanged.
1253// Future PR: implement BackendKvDtype<KvInt8> on CUDA + a new model
1254// type-parameter `K: KvDtypeKind` to wire it through.
1255
1256// `KvDtypeKind` + `KvFp16` / `KvBf16` / `KvInt8` / `KvFp8` markers moved
1257// to `ferrum_interfaces::kv_dtype` (no GPU deps, so the right place is
1258// the contract crate). Re-exported here so existing callers keep
1259// compiling against `crate::backend::KvFp16` etc.
1260pub use ferrum_interfaces::kv_dtype::{KvBf16, KvDtypeKind, KvFp16, KvFp8, KvInt8};
1261
1262/// Capability-trait for backends that can store + read a KV cache of
1263/// type `K`.
1264///
1265/// The two associated types carry the K-specific storage shape:
1266///   - `KvBuffer`: per-layer K/V element storage. For `K = KvFp16` it
1267///     is the backend's normal `Self::Buffer` (FP16). For `K = KvInt8`
1268///     it is the backend's INT8 buffer (e.g. `CudaSlice<i8>` on CUDA).
1269///   - `KvScales`: per-token-per-kv-head scales. For `K = KvFp16` this
1270///     is the unit type `()` (no scales). For `K = KvInt8` / `KvFp8`
1271///     it is a backend-specific FP16 buffer.
1272///
1273/// Models that want INT8 KV use:
1274///   `where B: BackendKvDtype<KvInt8>`
1275/// — the buffers in `KvCache<B, KvInt8>` are then `CudaSlice<i8>` and
1276/// `CudaSlice<f16>`, distinct from the FP16 path's `Self::Buffer`.
1277pub trait BackendKvDtype<K: KvDtypeKind>: BackendPagedKv {
1278    /// Per-layer K/V element storage.
1279    type KvBuffer: Send + Sync;
1280    /// Per-token per-kv-head scale storage. `()` for FP16 (no scales).
1281    type KvScales: Send + Sync + Default;
1282}
1283
1284/// INT8 KV cache operations (Dim 5).
1285///
1286/// `BackendKvDtype<KvInt8>` only declares the storage types; it does not
1287/// know how to write INT8 K/V into a paged pool or run paged decode
1288/// attention against an INT8 cache. Those launchers live here so the
1289/// model layer can call them through a single `B: BackendInt8KvOps` bound
1290/// without dropping into backend-specific code.
1291///
1292/// Today only `CudaBackend` provides a real implementation (delegating to
1293/// [`crate::int8_kv::launch_int8_kv_cache_append`] and
1294/// [`crate::int8_kv::launch_int8_paged_decode_attention`]). Other backends
1295/// inherit the default `unimplemented!()` body — the registry factory
1296/// rejects `(Device::CPU/Metal, KvCacheDtype::Int8)` before the model
1297/// gets a chance to call into these.
1298#[allow(clippy::too_many_arguments)]
1299pub trait BackendInt8KvOps: Backend + BackendKvDtype<KvInt8> {
1300    /// Allocate the per-layer INT8 paged cache for one sequence.
1301    /// Default panics — backends without INT8 support never reach this
1302    /// path (factory rejects (Cpu/Metal, Int8) before ensure_kv runs).
1303    fn alloc_paged_int8_layer(
1304        _max_blocks_per_seq: usize,
1305        _block_size: usize,
1306        _num_kv_heads: usize,
1307        _head_dim: usize,
1308    ) -> KvCacheQuant<Self, KvInt8> {
1309        unimplemented!("alloc_paged_int8_layer not supported on this backend")
1310    }
1311
1312    /// Append `tokens` FP16 K/V values into the paged INT8 pool.
1313    /// `paged_block_indices` is the host-side mirror of the per-seq
1314    /// logical→physical block table (already populated at `ensure_kv` time
1315    /// — see `KvCacheQuant::paged_block_indices`). Passing the host slice
1316    /// avoids a per-token D2H + sync barrier; backend computes the slot
1317    /// mapping host-side, async-H2D's it, and chains the append kernel
1318    /// on the same stream — fully overlapping with prior work.
1319    /// `cache_len_before` is the current number of valid tokens; the
1320    /// backend quantizes FP16 → INT8 with per-(token, kv-head) FP16 scale
1321    /// and writes both into the layer's INT8 / scale buffers.
1322    fn int8_kv_append_paged(
1323        _ctx: &mut Self::Context,
1324        _k_in: &Self::Buffer,
1325        _v_in: &Self::Buffer,
1326        _layer_k: &mut <Self as BackendKvDtype<KvInt8>>::KvBuffer,
1327        _layer_v: &mut <Self as BackendKvDtype<KvInt8>>::KvBuffer,
1328        _layer_k_scales: &mut <Self as BackendKvDtype<KvInt8>>::KvScales,
1329        _layer_v_scales: &mut <Self as BackendKvDtype<KvInt8>>::KvScales,
1330        _paged_block_indices: &[u32],
1331        _cache_len_before: usize,
1332        _tokens: usize,
1333        _block_size: usize,
1334        _num_kv_heads: usize,
1335        _head_dim: usize,
1336    ) -> Result<()> {
1337        Err(FerrumError::unsupported(
1338            "int8_kv_append_paged not implemented for this backend",
1339        ))
1340    }
1341
1342    /// Run paged decode attention reading from an INT8 cache. Q is FP16,
1343    /// output is FP16; the kernel dequantizes K/V on the fly using the
1344    /// per-token scales. `valid_kv_len` is the post-append cache length
1345    /// (i.e. the kernel attends over `[0, valid_kv_len)` tokens).
1346    fn int8_paged_decode_attention(
1347        _ctx: &mut Self::Context,
1348        _q: &Self::Buffer,
1349        _layer_k: &<Self as BackendKvDtype<KvInt8>>::KvBuffer,
1350        _layer_v: &<Self as BackendKvDtype<KvInt8>>::KvBuffer,
1351        _layer_k_scales: &<Self as BackendKvDtype<KvInt8>>::KvScales,
1352        _layer_v_scales: &<Self as BackendKvDtype<KvInt8>>::KvScales,
1353        _block_table: &Self::Buffer,
1354        _output: &mut Self::Buffer,
1355        _num_q_heads: usize,
1356        _num_kv_heads: usize,
1357        _head_dim: usize,
1358        _valid_kv_len: usize,
1359        _block_size: usize,
1360        _scale: f32,
1361    ) -> Result<()> {
1362        Err(FerrumError::unsupported(
1363            "int8_paged_decode_attention not implemented for this backend",
1364        ))
1365    }
1366}
1367
1368// Cpu/Metal NOT impl `BackendInt8KvOps` — the trait pivot to
1369// `KvLayer<B>` means `KvInt8: KvLayer<B>` only holds where
1370// `B: BackendInt8KvOps`, so `LlamaFamilyModel<CpuBackend, KvInt8>` is a
1371// compile error (no INT8 KvLayer impl satisfies it). Type system
1372// enforces the constraint without runtime stubs.