Skip to main content

ferrum_kernels/backend/
traits.rs

//! Core Backend trait — the single abstraction over CUDA / Metal / CPU.
2
use ferrum_types::{FerrumError, Result};
4
/// Quantization flavour discriminator for `Backend::gemm_quant`.
///
/// Distinct schemes need distinct kernels. Carried as a parameter so the
/// Backend trait does not explode with one method per quantization type.
///
/// Comparable and hashable so callers can match kinds against each other or
/// key per-kind lookup tables without hand-written `match` helpers.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum QuantKind {
    /// GPTQ: group-wise int4/int8 with scales + zeros (asymmetric) + optional g_idx.
    Gptq {
        /// Bit width of each packed weight.
        bits: u32,
        /// Number of weights sharing one scale/zero pair.
        group_size: usize,
        /// Whether the checkpoint was quantized with activation reordering.
        /// NOTE(review): presumably this is what makes `g_idx` mandatory — confirm.
        desc_act: bool,
    },
    /// AWQ: activation-aware int4 with scales + zeros, different packing from GPTQ.
    Awq {
        /// Bit width of each packed weight.
        bits: u32,
        /// Number of weights sharing one scale/zero pair.
        group_size: usize,
    },
    /// GGUF: one of k-quants / legacy quants, fully specified by the inner type.
    Gguf {
        /// Concrete GGUF block format.
        quant_type: GgufQuantType,
    },
}

/// GGUF quantization sub-type (expand as kernels are added).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GgufQuantType {
    Q4_0,
    Q4_1,
    Q4K,
    Q5K,
    Q6K,
    Q8_0,
}
33
34/// Packed quantized weight buffers passed to `Backend::gemm_quant`.
35///
36/// Not every field is used by every `QuantKind` — e.g. GGUF packs scales
37/// inside `qweight`, so `scales` / `zeros` may be dummies. The Backend
38/// implementation is expected to validate the shape for the kind it handles.
39pub struct QuantWeights<'a, B: Backend> {
40    pub qweight: &'a B::Buffer,
41    pub scales: Option<&'a B::Buffer>,
42    pub zeros: Option<&'a B::Buffer>,
43    pub g_idx: Option<&'a B::Buffer>,
44}
45
/// Collective-op reduction kind for TP all_reduce.
///
/// Passed to `Backend::all_reduce` to select how per-rank values are combined.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Element-wise sum across ranks.
    Sum,
    /// Element-wise maximum across ranks.
    Max,
    /// Element-wise minimum across ranks.
    Min,
}
53
/// Configuration for attention dispatch.
#[derive(Clone, Debug, PartialEq)]
pub struct AttnConfig {
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Width of a single head.
    pub head_dim: usize,
    /// Whether causal masking is applied.
    pub causal: bool,
    /// Scalar applied to the attention scores.
    /// NOTE(review): presumably `1 / sqrt(head_dim)` in practice — confirm against callers.
    pub scale: f32,
    /// Stride (in rows) between head blocks in the KV buffer.
    /// `0` means contiguous (use `kv_len`, legacy behaviour).
    /// Set to `cache_capacity` when flashing against a pre-allocated cache
    /// that only has `kv_len` valid slots out of `cache_capacity`.
    pub kv_seq_stride: usize,
    /// Sliding-window attention size (Mistral v0.1, Gemma).
    /// `0` = disabled (full causal attention).
    /// `w > 0` = each query position attends to the previous `w` KV positions
    ///            (still bounded by `causal` + `pos_offset + qi + 1` as the upper end).
    pub sliding_window: usize,
}

impl Default for AttnConfig {
    /// Returns a configuration with every size set to `0`, causal masking
    /// off, and `scale` set to `1.0`.
    ///
    /// `Default` is implemented by hand rather than derived because the
    /// derived value for `scale` would be `0.0`. The zeroed head counts make
    /// this a base for struct-update syntax (`AttnConfig { .., ..Default::default() }`)
    /// rather than a ready-to-use configuration.
    fn default() -> Self {
        Self {
            num_heads: 0,
            num_kv_heads: 0,
            head_dim: 0,
            causal: false,
            scale: 1.0,
            kv_seq_stride: 0,
            sliding_window: 0,
        }
    }
}
87
// Note: `TransformerConfig` / `AttnType` / `MlpType` / `RopeConfig` used to
// live here when `ModelRunner` needed a generic model config. They're now
// per-model (e.g. `Qwen3Config` in `ferrum-models::models::qwen3`) so each
// model can carry exactly the architecture parameters it cares about.
// The Backend trait stays model-agnostic.
93
94/// Per-layer KV cache. Each model owns its own `Vec<KvCache<B>>` per sequence.
95pub struct KvCache<B: Backend> {
96    pub k: B::Buffer,
97    pub v: B::Buffer,
98    pub len: usize,
99    pub capacity: usize,
100    pub num_kv_heads: usize,
101    pub head_dim: usize,
102}
103
104/// The core abstraction over CUDA / Metal / CPU.
105///
106/// Key design: operations take a `&mut Self::Context` which accumulates work.
107///   - **CPU**: Context is `()` — ops execute immediately.
108///   - **Metal**: Context is a `CommandBuffer` — ops encode into it, flushed on `sync()`.
109///   - **CUDA**: Context is a `CudaStream` — ops launch on the stream, synced on `sync()`.
110///
111/// `layer_forward` passes the context through all ops in a layer.
112/// `ModelRunner` calls `sync()` only when it needs results (e.g., reading logits).
113pub trait Backend: Send + Sync + Sized + 'static {
114    type Buffer: Send + Sync;
115
116    /// Execution context that accumulates GPU work.
117    ///   - CPU: `()` (no-op, ops execute inline)
118    ///   - Metal: wraps a CommandBuffer
119    ///   - CUDA: wraps a CudaStream
120    type Context;
121
122    /// Opaque per-backend GPTQ weight representation.
123    ///   - CPU: dequantized f32 weights (run as regular GEMM)
124    ///   - Metal: `()` — unsupported; `gemm_gptq` errors
125    ///   - CUDA: `MarlinWeight` — pre-repacked tiles + permuted scales
126    ///
127    /// Each backend repacks raw GPTQ tensors (qweight/scales/qzeros, all
128    /// i32/f16) into its preferred format at model load time, so inference
129    /// doesn't pay the repack cost per forward pass.
130    type GptqStore: Send + Sync;
131
132    /// Create a new execution context (begin accumulating work).
133    fn new_context() -> Self::Context;
134
135    /// Flush accumulated work and wait for completion.
136    /// CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.
137    fn sync(ctx: &mut Self::Context);
138
139    // ── Graph capture / replay (CUDA only) ──────────────────────────────
140    //
141    // Decode-loop optimization: eliminate per-kernel launch overhead by
142    // capturing the full step as a CUDA graph and replaying. CPU/Metal
143    // have no equivalent — defaults return `unsupported`.
144    //
145    // Flow per decode step:
146    //   1. Caller: `set_decode_state(ctx, token, step)` — memcpy to dev bufs
147    //   2. Try `replay_last_graph(ctx)`:
148    //        - Ok(true):  graph replayed, skip eager forward
149    //        - Ok(false): no captured graph yet, run eager
150    //        - Err(_):    not supported, run eager
151    //   3. If running eager and in capture window:
152    //      - `set_dev_state_mode(ctx, true)` so kernels use _dyn variants
153    //      - `begin_graph_capture(ctx)`
154    //      - run forward
155    //      - `end_graph_capture(ctx)` — stores graph on ctx internally
156    //      - `set_dev_state_mode(ctx, false)` — restore scalar kernels
157
158    /// Update per-step dynamic state (token id, step/pos). Fast (3x memcpy).
159    fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32) {}
160
161    /// Toggle between scalar-arg kernels (normal) and `_dyn` kernels that
162    /// read their dynamic scalar args from device memory (graph-friendly).
163    fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool) {}
164
165    /// Begin stream capture. Subsequent kernel launches are recorded into
166    /// a pending graph instead of executing eagerly.
167    fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
168        Err(FerrumError::unsupported("graph capture not supported"))
169    }
170
171    /// End stream capture and install the captured graph as this context's
172    /// "last graph" for future `replay_last_graph` calls.
173    fn end_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
174        Err(FerrumError::unsupported("graph capture not supported"))
175    }
176
177    /// Replay the last captured graph. Returns `Ok(false)` if no graph
178    /// is cached; caller should run eager.
179    fn replay_last_graph(_ctx: &mut Self::Context) -> Result<bool> {
180        Ok(false)
181    }
182
183    /// Drop the cached decode graph — required when the KV cache it
184    /// was captured against is about to be freed (e.g. request release),
185    /// since the graph holds raw device pointers into that cache.
186    fn reset_graph(_ctx: &mut Self::Context) {}
187
188    // ── GPTQ (INT4 quantization) ────────────────────────────────────────
189    //
190    // Two-step: load (once per weight) → gemm (per forward). The store
191    // holds whatever backend-specific format is fastest; caller code
192    // (GptqLinear) is dtype-agnostic.
193
194    /// Repack raw GPTQ tensors into the backend's preferred format.
195    /// Called once per layer at model load time.
196    ///
197    /// Inputs are host-side slices (CPU memory) — the loader reads from
198    /// safetensors and hands them off; each backend uploads + repacks
199    /// per its own strategy. `bits` is typically 4; `group_size` is
200    /// typically 128.
201    #[allow(clippy::too_many_arguments)]
202    fn load_gptq(
203        _qweight: &[i32],
204        _scales: &[f32],
205        _qzeros: &[i32],
206        _g_idx: Option<&[i32]>,
207        _bits: u32,
208        _group_size: usize,
209        _k: usize,
210        _n: usize,
211    ) -> Result<Self::GptqStore> {
212        Err(FerrumError::unsupported(
213            "load_gptq not implemented for this backend",
214        ))
215    }
216
217    /// GEMM with pre-loaded GPTQ weights.
218    /// `out[m, n] = a[m, k] @ dequant(weight)^T`
219    fn gemm_gptq(
220        _ctx: &mut Self::Context,
221        _a: &Self::Buffer,
222        _weight: &Self::GptqStore,
223        _out: &mut Self::Buffer,
224        _m: usize,
225    ) -> Result<()> {
226        Err(FerrumError::unsupported(
227            "gemm_gptq not implemented for this backend",
228        ))
229    }
230
231    // ── GEMM ────────────────────────────────────────────────────────────
232
233    fn gemm(
234        ctx: &mut Self::Context,
235        a: &Self::Buffer,
236        b: &Self::Buffer,
237        out: &mut Self::Buffer,
238        m: usize,
239        n: usize,
240        k: usize,
241    );
242
243    // ── Norms ───────────────────────────────────────────────────────────
244
245    fn rms_norm(
246        ctx: &mut Self::Context,
247        x: &Self::Buffer,
248        w: &Self::Buffer,
249        eps: f32,
250        out: &mut Self::Buffer,
251        tokens: usize,
252        dim: usize,
253    );
254
255    fn fused_add_rms_norm(
256        ctx: &mut Self::Context,
257        residual: &mut Self::Buffer,
258        x: &Self::Buffer,
259        w: &Self::Buffer,
260        eps: f32,
261        out: &mut Self::Buffer,
262        tokens: usize,
263        dim: usize,
264    );
265
266    // ── Attention ───────────────────────────────────────────────────────
267
268    fn flash_attention(
269        ctx: &mut Self::Context,
270        q: &Self::Buffer,
271        k: &Self::Buffer,
272        v: &Self::Buffer,
273        out: &mut Self::Buffer,
274        batch: usize,
275        q_len: usize,
276        kv_len: usize,
277        pos_offset: usize,
278        cfg: &AttnConfig,
279    );
280
281    /// Multi-Head Latent Attention — DeepSeek V2 / V3's compressed-KV
282    /// attention variant. Extension point only; no backend implements it
283    /// yet. DeepSeek V3 landing in Phase D/E will fill this in.
284    ///
285    /// `q`: full Q `[batch, num_heads, q_len, head_dim]`
286    /// `kv_compressed`: latent KV `[batch, kv_len, kv_lora_rank]`
287    /// `kv_rope`: per-position rope-applied key heads `[batch, kv_len, qk_rope_head_dim]`
288    /// `out`: `[batch, num_heads, q_len, head_dim]`
289    #[allow(clippy::too_many_arguments)]
290    fn mla_attention(
291        _ctx: &mut Self::Context,
292        _q: &Self::Buffer,
293        _kv_compressed: &Self::Buffer,
294        _kv_rope: &Self::Buffer,
295        _out: &mut Self::Buffer,
296        _batch: usize,
297        _q_len: usize,
298        _kv_len: usize,
299        _pos_offset: usize,
300        _cfg: &AttnConfig,
301        _kv_lora_rank: usize,
302        _qk_rope_head_dim: usize,
303    ) -> Result<()> {
304        Err(FerrumError::unsupported(
305            "mla_attention not implemented for this backend; required by \
306             DeepSeek V2/V3 (Phase D/E)",
307        ))
308    }
309
310    // ── Element-wise ────────────────────────────────────────────────────
311    //
312    // Models use `add_inplace` for residual updates and `copy_slice` for the
313    // row-extraction step in prefill. Offset-free copy / non-inplace add are
314    // not needed by the current Model-as-Code path; they can return later if
315    // a model actually requires them.
316
317    /// Copy `len` floats from `src[src_offset..]` to `dst[dst_offset..]`.
318    ///
319    /// Needed for Qwen3Model::prefill to pluck the last token's hidden state
320    /// out of `residual[seq_len, h]` without round-tripping through host RAM.
321    /// `Backend::copy` is the offset-free variant; `copy_slice` additionally
322    /// supports non-zero source and destination offsets.
323    fn copy_slice(
324        ctx: &mut Self::Context,
325        src: &Self::Buffer,
326        src_offset: usize,
327        dst: &mut Self::Buffer,
328        dst_offset: usize,
329        len: usize,
330    );
331
332    // ── Embedding ───────────────────────────────────────────────────────
333
334    fn embedding_lookup(
335        ctx: &mut Self::Context,
336        table: &Self::Buffer,
337        ids: &[u32],
338        out: &mut Self::Buffer,
339        dim: usize,
340    );
341
342    // ── Transformer-specific fused ops ─────────────────────────────────
343    // These avoid CPU round-trips for data layout transformations.
344
345    /// Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers.
346    /// Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]
347    fn split_qkv(
348        ctx: &mut Self::Context,
349        qkv: &Self::Buffer,
350        q: &mut Self::Buffer,
351        k: &mut Self::Buffer,
352        v: &mut Self::Buffer,
353        tokens: usize,
354        q_dim: usize,
355        kv_dim: usize,
356    );
357
358    /// Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im],
359    /// then compute SiLU(gate) * up → out [tokens, im].
360    fn fused_silu_mul_split(
361        ctx: &mut Self::Context,
362        gate_up: &Self::Buffer,
363        out: &mut Self::Buffer,
364        tokens: usize,
365        im: usize,
366    );
367
368    /// Fused QK-norm + RoPE + transpose-to-head-major.
369    ///
370    /// `mode` selects the operation:
371    ///   0 = transpose only (typical for V, which needs no norm and no RoPE)
372    ///   1 = per-head RMS norm + RoPE + transpose  (Q/K with QK-norm, Qwen3)
373    ///   2 = RoPE + transpose                       (Q/K without QK-norm, Llama/Mistral)
374    ///
375    /// input:   `[tokens, heads, head_dim]`  (token-major, output of split_qkv)
376    /// output:  `[heads, tokens, head_dim]`  (head-major, ready for flash_attn / kv_cache_append)
377    ///
378    /// `pos_offset` is the position of token 0 (decode uses current seq len;
379    /// prefill uses 0). Within the batch, positions are taken as `pos_offset + i`.
380    ///
381    /// This is the primary attention-input preparation op. Backends that have a
382    /// fused kernel (Metal's `qk_norm_rope_transpose_f32`) will be dramatically
383    /// faster than composing norm + rope + transpose separately; the CPU
384    /// fallback lowers to the individual ops.
385    #[allow(clippy::too_many_arguments)]
386    fn qk_norm_rope(
387        ctx: &mut Self::Context,
388        input: &Self::Buffer,
389        norm_w: &Self::Buffer,
390        cos: &Self::Buffer,
391        sin: &Self::Buffer,
392        output: &mut Self::Buffer,
393        tokens: usize,
394        heads: usize,
395        head_dim: usize,
396        pos_offset: usize,
397        eps: f32,
398        mode: i32,
399    );
400
401    /// Append new K/V into a pre-allocated head-major cache buffer.
402    ///
403    /// `cache_k` / `cache_v`: `[nkv, capacity, hd]` (head-major, pre-allocated)
404    /// `new_k_head_major` / `new_v_head_major`: `[nkv, new_tokens, hd]`
405    ///   — produced directly by `qk_norm_rope`, no extra transpose needed.
406    ///
407    /// In-place append at slot `[nkv, cache_len..cache_len+new_tokens, hd]`.
408    /// Caller owns `cache_len` bookkeeping.
409    #[allow(clippy::too_many_arguments)]
410    fn kv_cache_append_head_major(
411        ctx: &mut Self::Context,
412        cache_k: &mut Self::Buffer,
413        cache_v: &mut Self::Buffer,
414        cache_len: usize,
415        cache_capacity: usize,
416        new_k_head_major: &Self::Buffer,
417        new_v_head_major: &Self::Buffer,
418        new_tokens: usize,
419        nkv: usize,
420        hd: usize,
421    );
422
423    /// Transpose [heads, tokens, dim] → [tokens, heads, dim].
424    /// Called after `flash_attention` to restore token-major layout for O-proj.
425    fn transpose_head_to_token(
426        ctx: &mut Self::Context,
427        src: &Self::Buffer,
428        dst: &mut Self::Buffer,
429        tokens: usize,
430        heads: usize,
431        dim: usize,
432    );
433
434    /// residual[i] += x[i] (in-place)
435    fn add_inplace(
436        ctx: &mut Self::Context,
437        residual: &mut Self::Buffer,
438        x: &Self::Buffer,
439        len: usize,
440    );
441
442    /// Broadcast bias add: `data[r, c] += bias[c]` for every row.
443    /// Required by Bert / Clip / Whisper whose linear projections carry a bias.
444    fn add_bias(
445        ctx: &mut Self::Context,
446        data: &mut Self::Buffer,
447        bias: &Self::Buffer,
448        rows: usize,
449        cols: usize,
450    );
451
452    /// Full LayerNorm (mean + variance normalisation + affine), distinct from
453    /// the `rms_norm` used by Llama-family decoders.
454    ///   `out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]`
455    /// Where `mean` and `var` are reduced over the last dim (cols).
456    #[allow(clippy::too_many_arguments)]
457    fn layer_norm(
458        ctx: &mut Self::Context,
459        x: &Self::Buffer,
460        gamma: &Self::Buffer,
461        beta: &Self::Buffer,
462        eps: f32,
463        out: &mut Self::Buffer,
464        tokens: usize,
465        dim: usize,
466    );
467
468    /// Element-wise GELU activation (erf-based, matches PyTorch default).
469    fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize);
470
471    // ── Buffer management (context-free) ────────────────────────────────
472
473    fn alloc(len: usize) -> Self::Buffer;
474    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;
475    fn from_slice(data: &[f32]) -> Self::Buffer;
476
477    // ── Quantized GEMM (Phase A3 stubs) ─────────────────────────────────
478    //
479    // Backends override the kinds they actually support (e.g. Metal will
480    // implement Gptq first; CUDA will implement Gptq + Awq via Marlin).
481    // Default impl returns an `unsupported` error so missing kernels surface
482    // as clean runtime errors instead of silent wrong output.
483
484    /// GEMM with packed-quantized B matrix. `m`/`n`/`k` describe the dense
485    /// equivalent (`[m,n] = [m,k] @ [k,n]^T`).
486    #[allow(clippy::too_many_arguments)]
487    fn gemm_quant(
488        _ctx: &mut Self::Context,
489        _a: &Self::Buffer,
490        _weights: &QuantWeights<'_, Self>,
491        _out: &mut Self::Buffer,
492        _m: usize,
493        _n: usize,
494        _k: usize,
495        kind: &QuantKind,
496    ) -> Result<()> {
497        Err(FerrumError::unsupported(format!(
498            "gemm_quant({kind:?}) not implemented for this backend"
499        )))
500    }
501
502    // ── TP collective ops (Phase A3 stubs) ──────────────────────────────
503    //
504    // Default impl is single-rank no-op: `world_size = 1`, `rank = 0`, and
505    // the collective ops are identity. Multi-GPU backends (future
506    // CudaBackend + NCCL) override these. Model code can call
507    // `B::all_reduce_sum(...)` unconditionally; single-GPU paths pay zero.
508
509    fn world_size(_ctx: &Self::Context) -> usize {
510        1
511    }
512    fn rank(_ctx: &Self::Context) -> usize {
513        0
514    }
515    fn all_reduce(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _op: ReduceOp) {
516        // single-rank: no-op
517    }
518    fn all_gather(
519        _ctx: &mut Self::Context,
520        _local: &Self::Buffer,
521        _global: &mut Self::Buffer,
522        _local_len: usize,
523    ) {
524        // single-rank: no-op (caller is expected to handle the degenerate
525        // case or arrange for `local == global`)
526    }
527    fn broadcast(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _src_rank: usize) {
528        // single-rank: no-op
529    }
530}