ferrum_kernels/backend/traits.rs
1//! Core Backend trait — the single abstraction over CUDA / Metal / CPU.
2
3use ferrum_types::{FerrumError, Result};
4
5/// Quantization flavour discriminator for `Backend::gemm_quant`.
6///
7/// Distinct schemes need distinct kernels. Carried as a parameter so the
8/// Backend trait does not explode with one method per quantization type.
9#[derive(Clone, Debug)]
10pub enum QuantKind {
11 /// GPTQ: group-wise int4/int8 with scales + zeros (asymmetric) + optional g_idx.
12 Gptq {
13 bits: u32,
14 group_size: usize,
15 desc_act: bool,
16 },
17 /// AWQ: activation-aware int4 with scales + zeros, different packing from GPTQ.
18 Awq { bits: u32, group_size: usize },
19 /// GGUF: one of k-quants / legacy quants, fully specified by the inner type.
20 Gguf { quant_type: GgufQuantType },
21}
22
/// GGUF quantization sub-type carried by `QuantKind::Gguf`.
///
/// Only formats with a kernel behind them are listed; extend this enum as
/// new kernels land. Variant order is significant for `as` casts, so append
/// new variants rather than inserting.
#[derive(Clone, Copy, Debug)]
pub enum GgufQuantType {
    /// Legacy 4-bit quant, `_0` flavour.
    Q4_0,
    /// Legacy 4-bit quant, `_1` flavour.
    Q4_1,
    /// 4-bit k-quant.
    Q4K,
    /// 5-bit k-quant.
    Q5K,
    /// 6-bit k-quant.
    Q6K,
    /// Legacy 8-bit quant, `_0` flavour.
    Q8_0,
}
33
34/// Packed quantized weight buffers passed to `Backend::gemm_quant`.
35///
36/// Not every field is used by every `QuantKind` — e.g. GGUF packs scales
37/// inside `qweight`, so `scales` / `zeros` may be dummies. The Backend
38/// implementation is expected to validate the shape for the kind it handles.
39pub struct QuantWeights<'a, B: Backend> {
40 pub qweight: &'a B::Buffer,
41 pub scales: Option<&'a B::Buffer>,
42 pub zeros: Option<&'a B::Buffer>,
43 pub g_idx: Option<&'a B::Buffer>,
44}
45
/// Reduction applied element-wise across ranks by `Backend::all_reduce`
/// (tensor-parallel collectives).
#[derive(Clone, Copy, Debug)]
pub enum ReduceOp {
    /// Element-wise sum over all ranks.
    Sum,
    /// Element-wise maximum over all ranks.
    Max,
    /// Element-wise minimum over all ranks.
    Min,
}
53
/// Parameters controlling an attention dispatch (`Backend::flash_attention`
/// and related kernels).
#[derive(Clone, Debug)]
pub struct AttnConfig {
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads (fewer than `num_heads` for GQA/MQA).
    pub num_kv_heads: usize,
    /// Per-head dimension.
    pub head_dim: usize,
    /// Whether to apply a causal mask.
    pub causal: bool,
    /// Multiplier applied to the attention scores.
    pub scale: f32,
    /// Row stride between consecutive head blocks in the KV buffer.
    ///
    /// `0` means the buffer is contiguous and `kv_len` is used as the stride
    /// (legacy behaviour). When running attention against a pre-allocated
    /// cache in which only `kv_len` of `cache_capacity` slots are valid, set
    /// this to `cache_capacity`.
    pub kv_seq_stride: usize,
    /// Sliding-window attention size (Mistral v0.1, Gemma).
    ///
    /// `0` disables the window (full causal attention). With `w > 0`, each
    /// query position attends to the previous `w` KV positions; the upper
    /// end is still bounded by `causal` + `pos_offset + qi + 1`.
    pub sliding_window: usize,
}
73
74impl Default for AttnConfig {
75 fn default() -> Self {
76 Self {
77 num_heads: 0,
78 num_kv_heads: 0,
79 head_dim: 0,
80 causal: false,
81 scale: 1.0,
82 kv_seq_stride: 0,
83 sliding_window: 0,
84 }
85 }
86}
87
88// Note: `TransformerConfig` / `AttnType` / `MlpType` / `RopeConfig` used to
89// live here when `ModelRunner` needed a generic model config. They're now
90// per-model (e.g. `Qwen3Config` in `ferrum-models::models::qwen3`) so each
91// model can carry exactly the architecture parameters it cares about.
92// Backend trait stays model-agnostic.
93
94/// Per-layer KV cache. Each model owns its own `Vec<KvCache<B>>` per sequence.
95pub struct KvCache<B: Backend> {
96 pub k: B::Buffer,
97 pub v: B::Buffer,
98 pub len: usize,
99 pub capacity: usize,
100 pub num_kv_heads: usize,
101 pub head_dim: usize,
102}
103
104/// The core abstraction over CUDA / Metal / CPU.
105///
106/// Key design: operations take a `&mut Self::Context` which accumulates work.
107/// - **CPU**: Context is `()` — ops execute immediately.
108/// - **Metal**: Context is a `CommandBuffer` — ops encode into it, flushed on `sync()`.
109/// - **CUDA**: Context is a `CudaStream` — ops launch on the stream, synced on `sync()`.
110///
111/// `layer_forward` passes the context through all ops in a layer.
112/// `ModelRunner` calls `sync()` only when it needs results (e.g., reading logits).
113pub trait Backend: Send + Sync + Sized + 'static {
114 type Buffer: Send + Sync;
115
116 /// Execution context that accumulates GPU work.
117 /// - CPU: `()` (no-op, ops execute inline)
118 /// - Metal: wraps a CommandBuffer
119 /// - CUDA: wraps a CudaStream
120 type Context;
121
122 /// Opaque per-backend GPTQ weight representation.
123 /// - CPU: dequantized f32 weights (run as regular GEMM)
124 /// - Metal: `()` — unsupported; `gemm_gptq` errors
125 /// - CUDA: `MarlinWeight` — pre-repacked tiles + permuted scales
126 ///
127 /// Each backend repacks raw GPTQ tensors (qweight/scales/qzeros, all
128 /// i32/f16) into its preferred format at model load time, so inference
129 /// doesn't pay the repack cost per forward pass.
130 type GptqStore: Send + Sync;
131
132 /// Create a new execution context (begin accumulating work).
133 fn new_context() -> Self::Context;
134
135 /// Flush accumulated work and wait for completion.
136 /// CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.
137 fn sync(ctx: &mut Self::Context);
138
139 // ── Graph capture / replay (CUDA only) ──────────────────────────────
140 //
141 // Decode-loop optimization: eliminate per-kernel launch overhead by
142 // capturing the full step as a CUDA graph and replaying. CPU/Metal
143 // have no equivalent — defaults return `unsupported`.
144 //
145 // Flow per decode step:
146 // 1. Caller: `set_decode_state(ctx, token, step)` — memcpy to dev bufs
147 // 2. Try `replay_last_graph(ctx)`:
148 // - Ok(true): graph replayed, skip eager forward
149 // - Ok(false): no captured graph yet, run eager
150 // - Err(_): not supported, run eager
151 // 3. If running eager and in capture window:
152 // - `set_dev_state_mode(ctx, true)` so kernels use _dyn variants
153 // - `begin_graph_capture(ctx)`
154 // - run forward
155 // - `end_graph_capture(ctx)` — stores graph on ctx internally
156 // - `set_dev_state_mode(ctx, false)` — restore scalar kernels
157
158 /// Update per-step dynamic state (token id, step/pos). Fast (3x memcpy).
159 fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32) {}
160
161 /// Toggle between scalar-arg kernels (normal) and `_dyn` kernels that
162 /// read their dynamic scalar args from device memory (graph-friendly).
163 fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool) {}
164
165 /// Begin stream capture. Subsequent kernel launches are recorded into
166 /// a pending graph instead of executing eagerly.
167 fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
168 Err(FerrumError::unsupported("graph capture not supported"))
169 }
170
171 /// End stream capture and install the captured graph as this context's
172 /// "last graph" for future `replay_last_graph` calls.
173 fn end_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
174 Err(FerrumError::unsupported("graph capture not supported"))
175 }
176
177 /// Replay the last captured graph. Returns `Ok(false)` if no graph
178 /// is cached; caller should run eager.
179 fn replay_last_graph(_ctx: &mut Self::Context) -> Result<bool> {
180 Ok(false)
181 }
182
183 /// Drop the cached decode graph — required when the KV cache it
184 /// was captured against is about to be freed (e.g. request release),
185 /// since the graph holds raw device pointers into that cache.
186 fn reset_graph(_ctx: &mut Self::Context) {}
187
188 // ── GPTQ (INT4 quantization) ────────────────────────────────────────
189 //
190 // Two-step: load (once per weight) → gemm (per forward). The store
191 // holds whatever backend-specific format is fastest; caller code
192 // (GptqLinear) is dtype-agnostic.
193
194 /// Repack raw GPTQ tensors into the backend's preferred format.
195 /// Called once per layer at model load time.
196 ///
197 /// Inputs are host-side slices (CPU memory) — the loader reads from
198 /// safetensors and hands them off; each backend uploads + repacks
199 /// per its own strategy. `bits` is typically 4; `group_size` is
200 /// typically 128.
201 #[allow(clippy::too_many_arguments)]
202 fn load_gptq(
203 _qweight: &[i32],
204 _scales: &[f32],
205 _qzeros: &[i32],
206 _g_idx: Option<&[i32]>,
207 _bits: u32,
208 _group_size: usize,
209 _k: usize,
210 _n: usize,
211 ) -> Result<Self::GptqStore> {
212 Err(FerrumError::unsupported(
213 "load_gptq not implemented for this backend",
214 ))
215 }
216
217 /// GEMM with pre-loaded GPTQ weights.
218 /// `out[m, n] = a[m, k] @ dequant(weight)^T`
219 fn gemm_gptq(
220 _ctx: &mut Self::Context,
221 _a: &Self::Buffer,
222 _weight: &Self::GptqStore,
223 _out: &mut Self::Buffer,
224 _m: usize,
225 ) -> Result<()> {
226 Err(FerrumError::unsupported(
227 "gemm_gptq not implemented for this backend",
228 ))
229 }
230
231 // ── GEMM ────────────────────────────────────────────────────────────
232
233 fn gemm(
234 ctx: &mut Self::Context,
235 a: &Self::Buffer,
236 b: &Self::Buffer,
237 out: &mut Self::Buffer,
238 m: usize,
239 n: usize,
240 k: usize,
241 );
242
243 // ── Norms ───────────────────────────────────────────────────────────
244
245 fn rms_norm(
246 ctx: &mut Self::Context,
247 x: &Self::Buffer,
248 w: &Self::Buffer,
249 eps: f32,
250 out: &mut Self::Buffer,
251 tokens: usize,
252 dim: usize,
253 );
254
255 fn fused_add_rms_norm(
256 ctx: &mut Self::Context,
257 residual: &mut Self::Buffer,
258 x: &Self::Buffer,
259 w: &Self::Buffer,
260 eps: f32,
261 out: &mut Self::Buffer,
262 tokens: usize,
263 dim: usize,
264 );
265
266 // ── Attention ───────────────────────────────────────────────────────
267
268 fn flash_attention(
269 ctx: &mut Self::Context,
270 q: &Self::Buffer,
271 k: &Self::Buffer,
272 v: &Self::Buffer,
273 out: &mut Self::Buffer,
274 batch: usize,
275 q_len: usize,
276 kv_len: usize,
277 pos_offset: usize,
278 cfg: &AttnConfig,
279 );
280
281 /// Multi-Head Latent Attention — DeepSeek V2 / V3's compressed-KV
282 /// attention variant. Extension point only; no backend implements it
283 /// yet. DeepSeek V3 landing in Phase D/E will fill this in.
284 ///
285 /// `q`: full Q `[batch, num_heads, q_len, head_dim]`
286 /// `kv_compressed`: latent KV `[batch, kv_len, kv_lora_rank]`
287 /// `kv_rope`: per-position rope-applied key heads `[batch, kv_len, qk_rope_head_dim]`
288 /// `out`: `[batch, num_heads, q_len, head_dim]`
289 #[allow(clippy::too_many_arguments)]
290 fn mla_attention(
291 _ctx: &mut Self::Context,
292 _q: &Self::Buffer,
293 _kv_compressed: &Self::Buffer,
294 _kv_rope: &Self::Buffer,
295 _out: &mut Self::Buffer,
296 _batch: usize,
297 _q_len: usize,
298 _kv_len: usize,
299 _pos_offset: usize,
300 _cfg: &AttnConfig,
301 _kv_lora_rank: usize,
302 _qk_rope_head_dim: usize,
303 ) -> Result<()> {
304 Err(FerrumError::unsupported(
305 "mla_attention not implemented for this backend; required by \
306 DeepSeek V2/V3 (Phase D/E)",
307 ))
308 }
309
310 // ── Element-wise ────────────────────────────────────────────────────
311 //
312 // Models use `add_inplace` for residual updates and `copy_slice` for the
313 // row-extraction step in prefill. Offset-free copy / non-inplace add are
314 // not needed by the current Model-as-Code path; they can return later if
315 // a model actually requires them.
316
317 /// Copy `len` floats from `src[src_offset..]` to `dst[dst_offset..]`.
318 ///
319 /// Needed for Qwen3Model::prefill to pluck the last token's hidden state
320 /// out of `residual[seq_len, h]` without round-tripping through host RAM.
321 /// `Backend::copy` is the offset-free variant; `copy_slice` additionally
322 /// supports non-zero source and destination offsets.
323 fn copy_slice(
324 ctx: &mut Self::Context,
325 src: &Self::Buffer,
326 src_offset: usize,
327 dst: &mut Self::Buffer,
328 dst_offset: usize,
329 len: usize,
330 );
331
332 // ── Embedding ───────────────────────────────────────────────────────
333
334 fn embedding_lookup(
335 ctx: &mut Self::Context,
336 table: &Self::Buffer,
337 ids: &[u32],
338 out: &mut Self::Buffer,
339 dim: usize,
340 );
341
342 // ── Transformer-specific fused ops ─────────────────────────────────
343 // These avoid CPU round-trips for data layout transformations.
344
345 /// Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers.
346 /// Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]
347 fn split_qkv(
348 ctx: &mut Self::Context,
349 qkv: &Self::Buffer,
350 q: &mut Self::Buffer,
351 k: &mut Self::Buffer,
352 v: &mut Self::Buffer,
353 tokens: usize,
354 q_dim: usize,
355 kv_dim: usize,
356 );
357
358 /// Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im],
359 /// then compute SiLU(gate) * up → out [tokens, im].
360 fn fused_silu_mul_split(
361 ctx: &mut Self::Context,
362 gate_up: &Self::Buffer,
363 out: &mut Self::Buffer,
364 tokens: usize,
365 im: usize,
366 );
367
368 /// Fused QK-norm + RoPE + transpose-to-head-major.
369 ///
370 /// `mode` selects the operation:
371 /// 0 = transpose only (typical for V, which needs no norm and no RoPE)
372 /// 1 = per-head RMS norm + RoPE + transpose (Q/K with QK-norm, Qwen3)
373 /// 2 = RoPE + transpose (Q/K without QK-norm, Llama/Mistral)
374 ///
375 /// input: `[tokens, heads, head_dim]` (token-major, output of split_qkv)
376 /// output: `[heads, tokens, head_dim]` (head-major, ready for flash_attn / kv_cache_append)
377 ///
378 /// `pos_offset` is the position of token 0 (decode uses current seq len;
379 /// prefill uses 0). Within the batch, positions are taken as `pos_offset + i`.
380 ///
381 /// This is the primary attention-input preparation op. Backends that have a
382 /// fused kernel (Metal's `qk_norm_rope_transpose_f32`) will be dramatically
383 /// faster than composing norm + rope + transpose separately; the CPU
384 /// fallback lowers to the individual ops.
385 #[allow(clippy::too_many_arguments)]
386 fn qk_norm_rope(
387 ctx: &mut Self::Context,
388 input: &Self::Buffer,
389 norm_w: &Self::Buffer,
390 cos: &Self::Buffer,
391 sin: &Self::Buffer,
392 output: &mut Self::Buffer,
393 tokens: usize,
394 heads: usize,
395 head_dim: usize,
396 pos_offset: usize,
397 eps: f32,
398 mode: i32,
399 );
400
401 /// Append new K/V into a pre-allocated head-major cache buffer.
402 ///
403 /// `cache_k` / `cache_v`: `[nkv, capacity, hd]` (head-major, pre-allocated)
404 /// `new_k_head_major` / `new_v_head_major`: `[nkv, new_tokens, hd]`
405 /// — produced directly by `qk_norm_rope`, no extra transpose needed.
406 ///
407 /// In-place append at slot `[nkv, cache_len..cache_len+new_tokens, hd]`.
408 /// Caller owns `cache_len` bookkeeping.
409 #[allow(clippy::too_many_arguments)]
410 fn kv_cache_append_head_major(
411 ctx: &mut Self::Context,
412 cache_k: &mut Self::Buffer,
413 cache_v: &mut Self::Buffer,
414 cache_len: usize,
415 cache_capacity: usize,
416 new_k_head_major: &Self::Buffer,
417 new_v_head_major: &Self::Buffer,
418 new_tokens: usize,
419 nkv: usize,
420 hd: usize,
421 );
422
423 /// Transpose [heads, tokens, dim] → [tokens, heads, dim].
424 /// Called after `flash_attention` to restore token-major layout for O-proj.
425 fn transpose_head_to_token(
426 ctx: &mut Self::Context,
427 src: &Self::Buffer,
428 dst: &mut Self::Buffer,
429 tokens: usize,
430 heads: usize,
431 dim: usize,
432 );
433
434 /// residual[i] += x[i] (in-place)
435 fn add_inplace(
436 ctx: &mut Self::Context,
437 residual: &mut Self::Buffer,
438 x: &Self::Buffer,
439 len: usize,
440 );
441
442 /// Broadcast bias add: `data[r, c] += bias[c]` for every row.
443 /// Required by Bert / Clip / Whisper whose linear projections carry a bias.
444 fn add_bias(
445 ctx: &mut Self::Context,
446 data: &mut Self::Buffer,
447 bias: &Self::Buffer,
448 rows: usize,
449 cols: usize,
450 );
451
452 /// Full LayerNorm (mean + variance normalisation + affine), distinct from
453 /// the `rms_norm` used by Llama-family decoders.
454 /// `out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]`
455 /// Where `mean` and `var` are reduced over the last dim (cols).
456 #[allow(clippy::too_many_arguments)]
457 fn layer_norm(
458 ctx: &mut Self::Context,
459 x: &Self::Buffer,
460 gamma: &Self::Buffer,
461 beta: &Self::Buffer,
462 eps: f32,
463 out: &mut Self::Buffer,
464 tokens: usize,
465 dim: usize,
466 );
467
468 /// Element-wise GELU activation (erf-based, matches PyTorch default).
469 fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize);
470
471 // ── Buffer management (context-free) ────────────────────────────────
472
473 fn alloc(len: usize) -> Self::Buffer;
474 fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;
475 fn from_slice(data: &[f32]) -> Self::Buffer;
476
477 // ── Quantized GEMM (Phase A3 stubs) ─────────────────────────────────
478 //
479 // Backends override the kinds they actually support (e.g. Metal will
480 // implement Gptq first; CUDA will implement Gptq + Awq via Marlin).
481 // Default impl returns an `unsupported` error so missing kernels surface
482 // as clean runtime errors instead of silent wrong output.
483
484 /// GEMM with packed-quantized B matrix. `m`/`n`/`k` describe the dense
485 /// equivalent (`[m,n] = [m,k] @ [k,n]^T`).
486 #[allow(clippy::too_many_arguments)]
487 fn gemm_quant(
488 _ctx: &mut Self::Context,
489 _a: &Self::Buffer,
490 _weights: &QuantWeights<'_, Self>,
491 _out: &mut Self::Buffer,
492 _m: usize,
493 _n: usize,
494 _k: usize,
495 kind: &QuantKind,
496 ) -> Result<()> {
497 Err(FerrumError::unsupported(format!(
498 "gemm_quant({kind:?}) not implemented for this backend"
499 )))
500 }
501
502 // ── TP collective ops (Phase A3 stubs) ──────────────────────────────
503 //
504 // Default impl is single-rank no-op: `world_size = 1`, `rank = 0`, and
505 // the collective ops are identity. Multi-GPU backends (future
506 // CudaBackend + NCCL) override these. Model code can call
507 // `B::all_reduce_sum(...)` unconditionally; single-GPU paths pay zero.
508
509 fn world_size(_ctx: &Self::Context) -> usize {
510 1
511 }
512 fn rank(_ctx: &Self::Context) -> usize {
513 0
514 }
515 fn all_reduce(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _op: ReduceOp) {
516 // single-rank: no-op
517 }
518 fn all_gather(
519 _ctx: &mut Self::Context,
520 _local: &Self::Buffer,
521 _global: &mut Self::Buffer,
522 _local_len: usize,
523 ) {
524 // single-rank: no-op (caller is expected to handle the degenerate
525 // case or arrange for `local == global`)
526 }
527 fn broadcast(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _src_rank: usize) {
528 // single-rank: no-op
529 }
530}