ferrum_kernels/backend/traits.rs
1//! Core Backend trait — the single abstraction over CUDA / Metal / CPU.
2
3use ferrum_types::{FerrumError, Result};
4
5pub use super::capabilities::{
6 BackendCollective, BackendGraph, BackendMoeFused, BackendQuantGguf, BackendQuantMarlin,
7};
8pub use super::types::MoeRouting;
9use super::types::{AttnConfig, KvCacheQuant, SrcDtype};
10
/// Maximum decode-graph layer count. Per-layer call sites that share
/// graph-captured host staging arrays use this as the stride between
/// distinct slots. CUDA-only invariant (other backends ignore the
/// `slot` argument).
///
/// 64 is enough for the single-GPU (e.g. 4090) targets of v0.2. It does
/// NOT cover 80-layer models such as Llama-3-70B — those don't fit on a
/// single 4090 anyway — so raise this constant before enabling graph
/// capture for any model deeper than 64 layers.
pub const MAX_LAYERS_FOR_GRAPH: usize = 64;
18
19// Note: `TransformerConfig` / `AttnType` / `MlpType` / `RopeConfig` used to
20// live here when `ModelRunner` needed a generic model config. They're now
21// per-model (e.g. `Qwen3Config` in `ferrum-models::models::qwen3`) so each
22// model can carry exactly the architecture parameters it cares about.
23// Backend trait stays model-agnostic.
24
25/// The core abstraction over CUDA / Metal / CPU.
26///
27/// Key design: operations take a `&mut Self::Context` which accumulates work.
28/// - **CPU**: Context is `()` — ops execute immediately.
29/// - **Metal**: Context is a `CommandBuffer` — ops encode into it, flushed on `sync()`.
30/// - **CUDA**: Context is a `CudaStream` — ops launch on the stream, synced on `sync()`.
31///
32/// `layer_forward` passes the context through all ops in a layer.
33/// `ModelRunner` calls `sync()` only when it needs results (e.g., reading logits).
34pub trait Backend: Send + Sync + Sized + 'static {
35 type Buffer: Send + Sync;
36
37 /// Execution context that accumulates GPU work.
38 /// - CPU: `()` (no-op, ops execute inline)
39 /// - Metal: wraps a CommandBuffer
40 /// - CUDA: wraps a CudaStream
41 type Context;
42
43 /// GPU-side timer scoped to this backend. See `super::timer` —
44 /// CPU: `Instant`; Metal: sync-wrap; CUDA: `cuEvent`.
45 /// PLAYBOOK § 1.1.
46 type Timer: super::timer::BackendTimer<Self>;
47
48 /// Factory for `Self::Timer` — exists so call sites that have a
49 /// `<B: Backend>` parameter can spawn a timer without importing the
50 /// concrete impl. PLAYBOOK § 1.2.
51 fn make_timer() -> Self::Timer;
52
53 /// Opaque per-backend GPTQ weight representation.
54 /// - CPU: dequantized f32 weights (run as regular GEMM)
55 /// - Metal: `()` — unsupported; `gemm_gptq` errors
56 // Note (Phase 3e/4 + Phase C):
57 // - `type QuantStore` (GGUF k-quant storage) was removed in Phase 3e/4
58 // — stacked-expert MoE GGUF goes through Box<dyn StackedExpertGgufLinear<Self>>
59 // returned by `load_quant_experts`.
60 // - `type GptqStore` (Marlin/dequant GPTQ storage) was removed in Phase C
61 // step 4e — stacked-expert Marlin MoE goes through
62 // Arc<dyn MarlinExpertStack<Self>> returned by `load_gptq_stacked`,
63 // and single-tensor GPTQ goes through Box<dyn Linear<Self>> returned
64 // by `load_gptq`. Adding a new Marlin-capable backend is purely a
65 // new MarlinExpertStack<NewBackend> impl — no Backend trait edits.
66
67 /// Create a new execution context (begin accumulating work).
68 fn new_context() -> Self::Context;
69
70 /// Flush accumulated work and wait for completion.
71 /// CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.
72 fn sync(ctx: &mut Self::Context);
73
74 /// Prepare pending GPU work for a following host readback.
75 ///
76 /// Most backends either execute eagerly or synchronize as part of their
77 /// device-to-host copy. Metal shared-buffer reads use the CPU pointer
78 /// directly, so Metal must flush its command buffer before `to_vec`.
79 fn sync_before_host_readback(_ctx: &mut Self::Context) {}
80
81 /// Byte width of buffers returned by [`Self::alloc`].
82 ///
83 /// CUDA activation scratch is fp16, while Metal and CPU scratch are fp32.
84 /// Generic model code uses this for byte offsets into batched scratch
85 /// buffers without checking concrete backend types.
86 fn activation_elem_size_bytes() -> usize {
87 std::mem::size_of::<half::f16>()
88 }
89
90 /// Whether `LlamaFamilyModel::decode_batch_internal` may use its optimized
91 /// batched decode path on this backend.
92 ///
93 /// Backends that do not yet produce correct follow-up logits under
94 /// concurrent dense decode should override this to force the per-item
95 /// fallback until the optimized path is fixed.
96 fn supports_llama_family_batched_decode() -> bool {
97 true
98 }
99
100 // Graph capability moved to the `BackendGraph` supertrait at the end
101 // of this file. CUDA implements its overrides; Metal/CPU inherit
102 // unsupported defaults via empty `impl BackendGraph for X {}` blocks.
103
104 // ── GPTQ (INT4 quantization) ────────────────────────────────────────
105 //
106 // Two-step: load (once per weight) → gemm (per forward). The store
107 // holds whatever backend-specific format is fastest; caller code
108 // (GptqLinear) is dtype-agnostic.
109
110 /// Zero the first `len` elements of a Self::Buffer. CUDA path uses
111 /// cuMemsetD16Async; default returns unsupported.
112 fn zero_buffer(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize) -> Result<()> {
113 Err(FerrumError::unsupported(
114 "zero_buffer not implemented for this backend",
115 ))
116 }
117
118 /// Phase D step 2+3: unified typed allocator. Replaces per-dtype
119 /// `alloc_u32` / `alloc_typed_i32` / etc. The buffer is dtype-
120 /// tagged at the wrapper level (`CudaBuf::U32`, `MetalBuf` with
121 /// `Dtype::U32`, `CpuBuf::U32`), so reads/writes through `.as_<T>()`
122 /// accessors get the correct byte count automatically.
123 fn alloc_typed(dtype: super::Dtype, n: usize) -> Self::Buffer;
124
125 /// Upload typed host data — replaces `from_slice_i32` /
126 /// `from_slice_u32` etc. The host element type `T` carries its
127 /// `Dtype` via the `HostDtype` marker so dispatch in the impl
128 /// is a one-line `match T::DTYPE`.
129 fn from_slice_typed<T: super::HostDtype>(data: &[T]) -> Self::Buffer;
130
131 /// In-place typed write — replaces `write_u32` / `write_i32_into`
132 /// / `write_f32_into`. The buffer must already be dtype-tagged
133 /// matching `T::DTYPE` (typically alloc'd via `alloc_typed` or
134 /// `from_slice_typed`).
135 fn write_typed<T: super::HostDtype>(
136 ctx: &mut Self::Context,
137 dst: &mut Self::Buffer,
138 data: &[T],
139 );
140
141 // ── GEMM ────────────────────────────────────────────────────────────
142
143 fn gemm(
144 ctx: &mut Self::Context,
145 a: &Self::Buffer,
146 b: &Self::Buffer,
147 out: &mut Self::Buffer,
148 m: usize,
149 n: usize,
150 k: usize,
151 );
152
153 // ── Norms ───────────────────────────────────────────────────────────
154
155 fn rms_norm(
156 ctx: &mut Self::Context,
157 x: &Self::Buffer,
158 w: &Self::Buffer,
159 eps: f32,
160 out: &mut Self::Buffer,
161 tokens: usize,
162 dim: usize,
163 );
164
165 fn fused_add_rms_norm(
166 ctx: &mut Self::Context,
167 residual: &mut Self::Buffer,
168 x: &Self::Buffer,
169 w: &Self::Buffer,
170 eps: f32,
171 out: &mut Self::Buffer,
172 tokens: usize,
173 dim: usize,
174 );
175
176 // ── Attention ───────────────────────────────────────────────────────
177
178 fn flash_attention(
179 ctx: &mut Self::Context,
180 q: &Self::Buffer,
181 k: &Self::Buffer,
182 v: &Self::Buffer,
183 out: &mut Self::Buffer,
184 batch: usize,
185 q_len: usize,
186 kv_len: usize,
187 pos_offset: usize,
188 cfg: &AttnConfig,
189 );
190
191 /// Multi-Head Latent Attention — DeepSeek V2 / V3's compressed-KV
192 /// attention variant. Extension point only; no backend implements it
193 /// yet. DeepSeek V3 landing in Phase D/E will fill this in.
194 ///
195 /// `q`: full Q `[batch, num_heads, q_len, head_dim]`
196 /// `kv_compressed`: latent KV `[batch, kv_len, kv_lora_rank]`
197 /// `kv_rope`: per-position rope-applied key heads `[batch, kv_len, qk_rope_head_dim]`
198 /// `out`: `[batch, num_heads, q_len, head_dim]`
199 #[allow(clippy::too_many_arguments)]
200 fn mla_attention(
201 _ctx: &mut Self::Context,
202 _q: &Self::Buffer,
203 _kv_compressed: &Self::Buffer,
204 _kv_rope: &Self::Buffer,
205 _out: &mut Self::Buffer,
206 _batch: usize,
207 _q_len: usize,
208 _kv_len: usize,
209 _pos_offset: usize,
210 _cfg: &AttnConfig,
211 _kv_lora_rank: usize,
212 _qk_rope_head_dim: usize,
213 ) -> Result<()> {
214 Err(FerrumError::unsupported(
215 "mla_attention not implemented for this backend; required by \
216 DeepSeek V2/V3 (Phase D/E)",
217 ))
218 }
219
220 // ── Element-wise ────────────────────────────────────────────────────
221 //
222 // Models use `add_inplace` for residual updates and `copy_slice` for the
223 // row-extraction step in prefill. Offset-free copy / non-inplace add are
224 // not needed by the current Model-as-Code path; they can return later if
225 // a model actually requires them.
226
227 /// Copy `len` floats from `src[src_offset..]` to `dst[dst_offset..]`.
228 ///
229 /// Needed for Qwen3Model::prefill to pluck the last token's hidden state
230 /// out of `residual[seq_len, h]` without round-tripping through host RAM.
231 /// `Backend::copy` is the offset-free variant; `copy_slice` additionally
232 /// supports non-zero source and destination offsets.
233 fn copy_slice(
234 ctx: &mut Self::Context,
235 src: &Self::Buffer,
236 src_offset: usize,
237 dst: &mut Self::Buffer,
238 dst_offset: usize,
239 len: usize,
240 );
241
242 // ── Embedding ───────────────────────────────────────────────────────
243
244 fn embedding_lookup(
245 ctx: &mut Self::Context,
246 table: &Self::Buffer,
247 ids: &[u32],
248 out: &mut Self::Buffer,
249 dim: usize,
250 );
251
252 /// Device-buffer variant of `embedding_lookup` for graph-capturable
253 /// MoE routing — the gather step before phase-1 GEMM in
254 /// `moe_forward_bucketed`. The host-slice `embedding_lookup` does
255 /// `clone_htod(ids)` internally, which records stale host pointers
256 /// under CUDA Graph capture replay.
257 ///
258 /// `ids: &Self::Buffer` must be a device I32 buffer of `batch`
259 /// elements (e.g. `Qwen3MoeScratch::route_packed_idx_dev`).
260 /// `batch` is passed explicitly since a typed CudaBuf carries
261 /// its element count but the caller often wants a partial gather.
262 ///
263 /// Default impl: round-trip via `to_vec` + dispatch the host-slice
264 /// variant. CUDA overrides.
265 fn embedding_lookup_dev(
266 ctx: &mut Self::Context,
267 table: &Self::Buffer,
268 ids: &Self::Buffer,
269 out: &mut Self::Buffer,
270 batch: usize,
271 dim: usize,
272 ) {
273 // Default: round-trip. CUDA overrides with a direct device-arg
274 // kernel launch (no clone_htod).
275 let ids_host_f32 = Self::to_vec(ids, batch);
276 let ids_host_u32: Vec<u32> = ids_host_f32.iter().map(|x| x.to_bits()).collect();
277 Self::embedding_lookup(ctx, table, &ids_host_u32, out, dim);
278 }
279
280 // ── Transformer-specific fused ops ─────────────────────────────────
281 // These avoid CPU round-trips for data layout transformations.
282
283 /// Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers.
284 /// Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]
285 fn split_qkv(
286 ctx: &mut Self::Context,
287 qkv: &Self::Buffer,
288 q: &mut Self::Buffer,
289 k: &mut Self::Buffer,
290 v: &mut Self::Buffer,
291 tokens: usize,
292 q_dim: usize,
293 kv_dim: usize,
294 );
295
296 /// Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im],
297 /// then compute SiLU(gate) * up → out [tokens, im].
298 fn fused_silu_mul_split(
299 ctx: &mut Self::Context,
300 gate_up: &Self::Buffer,
301 out: &mut Self::Buffer,
302 tokens: usize,
303 im: usize,
304 );
305
306 /// Fused QK-norm + RoPE + transpose-to-head-major.
307 ///
308 /// `mode` selects the operation:
309 /// 0 = transpose only (typical for V, which needs no norm and no RoPE)
310 /// 1 = per-head RMS norm + RoPE + transpose (Q/K with QK-norm, Qwen3)
311 /// 2 = RoPE + transpose (Q/K without QK-norm, Llama/Mistral)
312 ///
313 /// input: `[tokens, heads, head_dim]` (token-major, output of split_qkv)
314 /// output: `[heads, tokens, head_dim]` (head-major, ready for flash_attn / kv_cache_append)
315 ///
316 /// `pos_offset` is the position of token 0 (decode uses current seq len;
317 /// prefill uses 0). Within the batch, positions are taken as `pos_offset + i`.
318 ///
319 /// This is the primary attention-input preparation op. Backends that have a
320 /// fused kernel (Metal's `qk_norm_rope_transpose_f32`) will be dramatically
321 /// faster than composing norm + rope + transpose separately; the CPU
322 /// fallback lowers to the individual ops.
323 #[allow(clippy::too_many_arguments)]
324 fn qk_norm_rope(
325 ctx: &mut Self::Context,
326 input: &Self::Buffer,
327 norm_w: &Self::Buffer,
328 cos: &Self::Buffer,
329 sin: &Self::Buffer,
330 output: &mut Self::Buffer,
331 tokens: usize,
332 heads: usize,
333 head_dim: usize,
334 pos_offset: usize,
335 eps: f32,
336 mode: i32,
337 );
338
339 /// Batched kv_cache_append across M caches in one launch. Each item
340 /// writes its (head-major) K-or-V row into its own cache at offset
341 /// read from `cache_lens[i]`. Replaces M sequential
342 /// `kv_cache_append_head_major` calls with a single dispatch.
343 ///
344 /// `new_data` layout: `[m, nkv, hd]` item-major (each item's slice
345 /// is contiguous, identical to the `k/v_normed_batched` produced by
346 /// `qk_norm_rope_batched_per_item`).
347 /// `caches`: per-cache `[nkv, capacity, hd]` head-major.
348 /// `cache_lens`: device buffer (u32 storage, length ≥ m). Caller
349 /// fills via `B::write_u32_into` BEFORE the call. Required for
350 /// CUDA-graph capture: the kernel reads from this stable device
351 /// buffer, so a captured graph can be replayed with new lens by
352 /// just rewriting the buffer between launches.
353 fn kv_cache_append_batched_per_cache(
354 _ctx: &mut Self::Context,
355 _caches: &[&Self::Buffer],
356 _new_data: &Self::Buffer,
357 _cache_lens: &Self::Buffer,
358 _capacity: usize,
359 _m: usize,
360 _nkv: usize,
361 _hd: usize,
362 _slot: usize,
363 ) -> Result<()> {
364 Err(FerrumError::unsupported(
365 "kv_cache_append_batched_per_cache not implemented for this backend",
366 ))
367 }
368
369 /// Batched flash_attention across M decode caches in one launch.
370 /// Replaces the per-item `flash_attention(q_len=1, ...)` × M
371 /// loop in the non-paged batched-decode path.
372 ///
373 /// API takes Vec<&Buffer> for the per-cache K/V buffers (each
374 /// `[nkv, capacity, hd]` head-major) plus host-side `kv_lens`.
375 /// Backends that implement it must extract per-cache device
376 /// pointers, build the device arrays the kernel needs, and launch
377 /// one kernel covering all M items.
378 ///
379 /// `q` layout: [m, nq, hd] item-major (matches the
380 /// `qk_norm_rope_batched_per_item` output for q_len=1).
381 /// `out` layout: [m, nq, hd] item-major — written directly into
382 /// the caller's batched attn_out buffer, no per-item copy needed.
383 ///
384 /// CUDA-only for now (kernel `batched_decode_attention` exists in
385 /// `kernels/batched_decode_attention.cu`).
386 /// `kv_lens`: device buffer (u32 storage, length ≥ m) — same
387 /// design as `kv_cache_append_batched_per_cache::cache_lens`.
388 fn flash_attention_batched_per_cache(
389 _ctx: &mut Self::Context,
390 _q: &Self::Buffer,
391 _k_caches: &[&Self::Buffer],
392 _v_caches: &[&Self::Buffer],
393 _kv_lens: &Self::Buffer,
394 _out: &mut Self::Buffer,
395 _nq: usize,
396 _nkv: usize,
397 _hd: usize,
398 _scale: f32,
399 _max_valid_kv: usize,
400 _capacity: usize,
401 _slot: usize,
402 ) -> Result<()> {
403 Err(FerrumError::unsupported(
404 "flash_attention_batched_per_cache not implemented for this backend",
405 ))
406 }
407
408 /// Batched per-item-position variant of `qk_norm_rope` for the
409 /// non-paged batched-decode path. Each of the `m` items has its own
410 /// absolute RoPE position (read from a device i32 buffer of length
411 /// `m`). Layout is item-major in *both* input and output:
412 ///
413 /// input [m, heads, head_dim]
414 /// output [m, heads, head_dim] (no head-major transpose)
415 ///
416 /// Item-major output keeps the per-item flash_attention slice
417 /// contiguous (`output[i * heads * head_dim ..]` is item i's whole
418 /// Q tensor in head-major-equivalent layout for q_len=1).
419 ///
420 /// Replaces the M sequential single-item launches in the existing
421 /// `forward_layer_batched_decode` path with one batched dispatch.
422 /// CUDA-only for now; other backends fall through to the default
423 /// `unsupported` and the caller falls back to the per-item loop.
424 fn qk_norm_rope_batched_per_item(
425 _ctx: &mut Self::Context,
426 _input: &Self::Buffer,
427 _norm_w: &Self::Buffer,
428 _cos: &Self::Buffer,
429 _sin: &Self::Buffer,
430 _output: &mut Self::Buffer,
431 _positions: &Self::Buffer,
432 _m: usize,
433 _heads: usize,
434 _head_dim: usize,
435 _eps: f32,
436 _mode: i32,
437 ) -> Result<()> {
438 Err(FerrumError::unsupported(
439 "qk_norm_rope_batched_per_item not implemented for this backend",
440 ))
441 }
442
443 /// Fused split-QKV + QK-norm + RoPE + head-major transpose.
444 ///
445 /// Single-dispatch replacement for the (`split_qkv` → 3× `qk_norm_rope`)
446 /// chain on the decode-attention prelude. Reads the linear-layer
447 /// fused-QKV output once and writes head-major Q/K/V directly into
448 /// attention scratch.
449 ///
450 /// `qkv` layout: `[tokens, q_heads*hd + 2*kv_heads*hd]`.
451 /// `q_out`: `[q_heads, tokens, hd]`. `k_out`/`v_out`: `[kv_heads, tokens, hd]`.
452 /// `qk_mode`: 1 = norm + half-split RoPE for Q/K (Qwen3 with QK-norm),
453 /// 2 = half-split RoPE only for Q/K,
454 /// 3 = interleaved RoPE only for Q/K (GGUF LLaMA / llama.cpp layout).
455 /// V always falls through to transpose-only.
456 ///
457 /// Default returns Unsupported. Backends that implement it are
458 /// expected to be dramatically faster than the four-dispatch chain.
459 #[allow(clippy::too_many_arguments)]
460 fn split_qkv_norm_rope(
461 _ctx: &mut Self::Context,
462 _qkv: &Self::Buffer,
463 _q_norm_w: &Self::Buffer,
464 _k_norm_w: &Self::Buffer,
465 _cos: &Self::Buffer,
466 _sin: &Self::Buffer,
467 _q_out: &mut Self::Buffer,
468 _k_out: &mut Self::Buffer,
469 _v_out: &mut Self::Buffer,
470 _tokens: usize,
471 _q_heads: usize,
472 _kv_heads: usize,
473 _head_dim: usize,
474 _pos_offset: usize,
475 _eps: f32,
476 _qk_mode: i32,
477 ) -> Result<()> {
478 Err(FerrumError::unsupported(
479 "split_qkv_norm_rope not implemented for this backend",
480 ))
481 }
482
483 /// Variant of [`Backend::split_qkv_norm_rope`] that writes the new
484 /// K and V directly into pre-allocated head-major KV cache buffers
485 /// at slot `[kv_heads, cache_len .. cache_len + tokens, hd]`.
486 /// Eliminates the trailing `kv_cache_append_head_major` dispatch on
487 /// the decode hot path. Q still lands in per-token head-major
488 /// scratch (flash-attention reads it as the query).
489 ///
490 /// Default returns Unsupported. Backends without the fused kernel
491 /// can keep using `split_qkv_norm_rope` + `kv_cache_append_head_major`.
492 #[allow(clippy::too_many_arguments)]
493 fn split_qkv_norm_rope_into_cache(
494 _ctx: &mut Self::Context,
495 _qkv: &Self::Buffer,
496 _q_norm_w: &Self::Buffer,
497 _k_norm_w: &Self::Buffer,
498 _cos: &Self::Buffer,
499 _sin: &Self::Buffer,
500 _q_out: &mut Self::Buffer,
501 _cache_k: &mut Self::Buffer,
502 _cache_v: &mut Self::Buffer,
503 _tokens: usize,
504 _q_heads: usize,
505 _kv_heads: usize,
506 _head_dim: usize,
507 _pos_offset: usize,
508 _eps: f32,
509 _qk_mode: i32,
510 _cache_len: usize,
511 _cache_capacity: usize,
512 ) -> Result<()> {
513 Err(FerrumError::unsupported(
514 "split_qkv_norm_rope_into_cache not implemented for this backend",
515 ))
516 }
517
518 // Phase D step 2: alloc_u32 / write_u32 deleted. Callers use the
519 // unified `alloc_typed(Dtype::U32, n)` + `write_typed(&[u32])` API
520 // declared above.
521
522 /// Append new K/V into a pre-allocated head-major cache buffer.
523 ///
524 /// `cache_k` / `cache_v`: `[nkv, capacity, hd]` (head-major, pre-allocated)
525 /// `new_k_head_major` / `new_v_head_major`: `[nkv, new_tokens, hd]`
526 /// — produced directly by `qk_norm_rope`, no extra transpose needed.
527 ///
528 /// In-place append at slot `[nkv, cache_len..cache_len+new_tokens, hd]`.
529 /// Caller owns `cache_len` bookkeeping.
530 #[allow(clippy::too_many_arguments)]
531 fn kv_cache_append_head_major(
532 ctx: &mut Self::Context,
533 cache_k: &mut Self::Buffer,
534 cache_v: &mut Self::Buffer,
535 cache_len: usize,
536 cache_capacity: usize,
537 new_k_head_major: &Self::Buffer,
538 new_v_head_major: &Self::Buffer,
539 new_tokens: usize,
540 nkv: usize,
541 hd: usize,
542 );
543
544 /// Transpose [heads, tokens, dim] → [tokens, heads, dim].
545 /// Called after `flash_attention` to restore token-major layout for O-proj.
546 fn transpose_head_to_token(
547 ctx: &mut Self::Context,
548 src: &Self::Buffer,
549 dst: &mut Self::Buffer,
550 tokens: usize,
551 heads: usize,
552 dim: usize,
553 );
554
555 /// Inverse of `transpose_head_to_token`: [tokens, heads, dim] →
556 /// [heads, tokens, dim]. Used by the CUDA `paged_decode_attention`
557 /// wrapper to convert `paged_varlen_attention`'s token-major output
558 /// back to the head-major layout that Qwen3MoeModel expects.
559 /// Default panics — backends without a paged-KV CUDA path don't
560 /// hit this code.
561 fn transpose_token_to_head(
562 _ctx: &mut Self::Context,
563 _src: &Self::Buffer,
564 _dst: &mut Self::Buffer,
565 _tokens: usize,
566 _heads: usize,
567 _dim: usize,
568 ) {
569 panic!("transpose_token_to_head not implemented for this backend");
570 }
571
572 /// residual[i] += x[i] (in-place)
573 fn add_inplace(
574 ctx: &mut Self::Context,
575 residual: &mut Self::Buffer,
576 x: &Self::Buffer,
577 len: usize,
578 );
579
580 /// `dst[i] += scale * src[i]` — scalar-broadcast scaled add, in place.
581 ///
582 /// MoE per-token combine writes `out[b] += weight_k * expert_k(x[b])`
583 /// for each top-K expert; this primitive is the per-call accumulate.
584 /// Backends without a dedicated kernel can fall back to the default
585 /// implementation, which round-trips through host memory — correct,
586 /// but slow on a hot path. Override on any backend you actually
587 /// dispatch MoE on.
588 fn scaled_add_inplace(
589 _ctx: &mut Self::Context,
590 dst: &mut Self::Buffer,
591 src: &Self::Buffer,
592 scale: f32,
593 len: usize,
594 ) {
595 let mut dst_v = Self::to_vec(dst, len);
596 let src_v = Self::to_vec(src, len);
597 for i in 0..len {
598 dst_v[i] += scale * src_v[i];
599 }
600 // Move the new buffer into the slot pointed to by `dst`. Safe
601 // because `Self::Buffer: Send + Sync` and the old buffer is
602 // dropped here when overwritten.
603 *dst = Self::from_slice(&dst_v);
604 }
605
606 /// Strided variant of [`Backend::fused_silu_mul_split`] for the
607 /// bucketed MoE path: reads `gate_up` rows starting at
608 /// `in_row_offset`, writes `out` rows starting at `out_row_offset`.
609 #[allow(clippy::too_many_arguments)]
610 fn fused_silu_mul_split_strided(
611 _ctx: &mut Self::Context,
612 _gate_up: &Self::Buffer,
613 _in_row_offset: usize,
614 _out: &mut Self::Buffer,
615 _out_row_offset: usize,
616 _tokens: usize,
617 _intermediate: usize,
618 ) {
619 unimplemented!("fused_silu_mul_split_strided default impl missing");
620 }
621
622 /// Broadcast bias add: `data[r, c] += bias[c]` for every row.
623 /// Required by Bert / Clip / Whisper whose linear projections carry a bias.
624 fn add_bias(
625 ctx: &mut Self::Context,
626 data: &mut Self::Buffer,
627 bias: &Self::Buffer,
628 rows: usize,
629 cols: usize,
630 );
631
632 /// Full LayerNorm (mean + variance normalisation + affine), distinct from
633 /// the `rms_norm` used by Llama-family decoders.
634 /// `out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]`
635 /// Where `mean` and `var` are reduced over the last dim (cols).
636 #[allow(clippy::too_many_arguments)]
637 fn layer_norm(
638 ctx: &mut Self::Context,
639 x: &Self::Buffer,
640 gamma: &Self::Buffer,
641 beta: &Self::Buffer,
642 eps: f32,
643 out: &mut Self::Buffer,
644 tokens: usize,
645 dim: usize,
646 );
647
648 /// Element-wise GELU activation (erf-based, matches PyTorch default).
649 fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize);
650
651 // ── Buffer management (context-free) ────────────────────────────────
652
653 fn alloc(len: usize) -> Self::Buffer;
654 fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;
655 fn from_slice(data: &[f32]) -> Self::Buffer;
656
657 /// Greedy-decode fast path: GPU argmax over each row of a
658 /// `[m, n]` FP16 logits buffer, returning the m token indices on the
659 /// host. Saves `m × n × 2` bytes of D2H per call (e.g. 19.5 MB at
660 /// c=32, vocab=152064) and the host-side argmax scan (~150 µs × m).
661 ///
662 /// Default impl falls back to the slow path: full `to_vec` + host
663 /// argmax. CUDA overrides with a native kernel + tiny D2H (m × 4 B).
664 /// Backends that don't override pay the same cost as
665 /// `to_vec` + host argmax, so callers can call this unconditionally.
666 fn argmax_rows_f16(
667 _ctx: &mut Self::Context,
668 logits: &Self::Buffer,
669 m: usize,
670 n: usize,
671 ) -> Result<Vec<u32>> {
672 let host = Self::to_vec(logits, m * n);
673 let mut out = Vec::with_capacity(m);
674 for row in 0..m {
675 let slice = &host[row * n..(row + 1) * n];
676 let mut max_idx = 0usize;
677 let mut max_val = f32::NEG_INFINITY;
678 for (i, &v) in slice.iter().enumerate() {
679 if v > max_val {
680 max_val = v;
681 max_idx = i;
682 }
683 }
684 out.push(max_idx as u32);
685 }
686 Ok(out)
687 }
688
689 /// Load a weight tensor straight from its on-disk byte representation,
690 /// letting the backend pick its preferred storage dtype.
691 ///
692 /// Default impl upcasts bf16/f16 to f32 via an intermediate Vec, matching
693 /// pre-existing loader behaviour. Backends override this to go straight
694 /// from raw bytes into a native half-precision buffer (e.g. Metal with
695 /// `FERRUM_METAL_DTYPE=f16`), avoiding the transient 2× RAM spike.
696 fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer {
697 let data = src_dtype.to_f32_vec(raw);
698 Self::from_slice(&data)
699 }
700
701 // (The Phase A3 unified `gemm_quant(QuantWeights, QuantKind)` stub
702 // that used to live here is superseded by the `load_quant` /
703 // `gemm_quant(QuantStore)` pair earlier in this trait — same idea,
704 // but the store hides the per-kind buffer layout so callers don't
705 // have to construct a per-kind `QuantWeights<'_, Self>` packet.)
706}
707
708// ════════════════════════════════════════════════════════════════════════
709// BackendPagedKv capability (vLLM-style paged KV cache + paged attention)
710// ════════════════════════════════════════════════════════════════════════
711//
712// Paged KV pool with block-table indirection, plus the paged attention
713// kernel variants that read through that indirection. CUDA + Metal both
714// implement the real kernels; CPU `impl BackendPagedKv for CpuBackend {}`
715// inherits unsupported defaults.
716
717/// Capability-trait for backends that support paged KV cache + paged attention.
718pub trait BackendPagedKv: Backend {
719 /// Whether this backend has a paged-KV decode path
720 /// (`paged_decode_attention` etc.). Currently true for Metal, false
721 /// for CPU. Used to decide the default of `FERRUM_METAL_PAGED_KV` —
722 /// the `serve` path should opt in automatically when supported so
723 /// users get the bench-quality concurrent-decode numbers without
724 /// having to learn the flag.
725 fn supports_paged_kv() -> bool {
726 false
727 }
728 /// Pre-populate the per-slot device-pointer scratch arrays used by
729 /// the batched kernels (`kv_cache_append_batched_per_cache` and
730 /// `flash_attention_batched_per_cache`). Required by the CUDA-graph
731 /// capture path: the captured graph contains only kernel launches
732 /// (no captured `memcpy_htod`), so the device scratch must be fresh
733 /// when the graph replays.
734 ///
735 /// Caller passes flat layer-major slices: `k_caches[li * m + i]` and
736 /// `v_caches[li * m + i]`. Backend extracts each cache's device
737 /// pointer and writes into its corresponding slot in the device
738 /// scratch via SYNCHRONOUS memcpy (not captured by stream capture).
739 ///
740 /// CUDA-only; other backends fall through to the default
741 /// `unsupported` and the caller skips the population call.
742 fn populate_batched_pointers(
743 _ctx: &mut Self::Context,
744 _k_caches: &[&Self::Buffer],
745 _v_caches: &[&Self::Buffer],
746 _num_layers: usize,
747 _m: usize,
748 ) -> Result<()> {
749 Err(FerrumError::unsupported(
750 "populate_batched_pointers not implemented for this backend",
751 ))
752 }
753 /// Paged-KV variant of [`Self::split_qkv_norm_rope_into_cache`].
754 ///
755 /// Same fused split + qk-norm + RoPE, but K/V are written into a
756 /// paged pool `[num_blocks, kv_heads, block_size, head_dim]`
757 /// indexed via `block_table[logical_block]` → physical_block.
758 /// Q still goes to head-major scratch.
759 ///
760 /// Default returns Unsupported. Backends that lack a paged kernel
761 /// keep using the contiguous variant.
762 /// `qkv_byte_offset` / `q_out_byte_offset` let the caller pass a
763 /// slice of a larger batched buffer (used by the multi-seq paged
764 /// path in `decode_batch_internal`). For single-seq dispatch they
765 /// should be 0.
766 #[allow(clippy::too_many_arguments)]
767 fn split_qkv_norm_rope_into_paged_cache(
768 _ctx: &mut Self::Context,
769 _qkv: &Self::Buffer,
770 _qkv_byte_offset: u64,
771 _q_norm_w: &Self::Buffer,
772 _k_norm_w: &Self::Buffer,
773 _cos: &Self::Buffer,
774 _sin: &Self::Buffer,
775 _q_out: &mut Self::Buffer,
776 _q_out_byte_offset: u64,
777 _cache_k: &mut Self::Buffer,
778 _cache_v: &mut Self::Buffer,
779 _block_table: &Self::Buffer,
780 _tokens: usize,
781 _q_heads: usize,
782 _kv_heads: usize,
783 _head_dim: usize,
784 _pos_offset: usize,
785 _eps: f32,
786 _qk_mode: i32,
787 _cache_len: usize,
788 _block_size: usize,
789 _max_num_blocks_per_seq: usize,
790 ) -> Result<()> {
791 Err(FerrumError::unsupported(
792 "split_qkv_norm_rope_into_paged_cache not implemented for this backend",
793 ))
794 }
795 /// Paged-KV variant of [`Self::flash_attention`].
796 ///
797 /// Decode (`q_len == 1`):
798 /// `q`/`out`: `[num_seqs, num_heads, head_dim]` (token-major)
799 ///
800 /// Causal prefill (`q_len > 1`, single seq):
801 /// `q`/`out`: `[num_heads, q_len, head_dim]` (head-major — the
802 /// layout produced by `split_qkv_norm_rope_into_paged_cache`)
803 /// The kernel applies a per-q-token causal mask using
804 /// `context_lens[seq]` as the FINAL kv_len (= `pos_offset + q_len`):
805 /// token i sees positions `[0, context_lens - q_len + 1 + i)`.
806 ///
807 /// Common to both:
808 /// `k_pool`/`v_pool`: `[num_blocks, num_kv_heads, block_size, head_dim]`
809 /// `block_tables`: `[num_seqs, max_num_blocks_per_seq]` u32
810 /// `context_lens`: `[num_seqs]` u32
811 ///
812 /// Backends without a paged kernel return Unsupported; callers are
813 /// expected to fall back to contiguous KV.
814 #[allow(clippy::too_many_arguments)]
815 fn paged_decode_attention(
816 _ctx: &mut Self::Context,
817 _q: &Self::Buffer,
818 _k_pool: &Self::Buffer,
819 _v_pool: &Self::Buffer,
820 _out: &mut Self::Buffer,
821 _block_tables: &Self::Buffer,
822 _context_lens: &Self::Buffer,
823 _num_seqs: usize,
824 _num_heads: usize,
825 _num_kv_heads: usize,
826 _head_dim: usize,
827 _block_size: usize,
828 _max_num_blocks_per_seq: usize,
829 _q_len: usize,
830 ) -> Result<()> {
831 Err(FerrumError::unsupported(
832 "paged_decode_attention not implemented for this backend",
833 ))
834 }
835 /// Capability: does this backend implement
836 /// `split_qkv_norm_rope_into_paged_cache_varlen` and
837 /// `paged_varlen_attention`? Required by the unified mixed-batch
838 /// forward path used by `LlamaFamilyModel::unified_forward`. Default
839 /// false; backends that ship the varlen kernels override.
840 fn supports_varlen_qkv() -> bool {
841 false
842 }
843 /// Varlen variant of [`Self::split_qkv_norm_rope_into_paged_cache`].
844 ///
845 /// Single launch covering ALL sequences in the batch. Reads
846 /// `pos_offsets[seq]`, `cu_seqlens_q[seq]`, and the per-seq
847 /// block_table from device buffers — graph-capturable (the per-iter
848 /// state is in buffers, not kernel scalars). Replaces the per-item
849 /// dispatch loop in `unified_forward_layer` with one call.
850 ///
851 /// Layouts:
852 /// - `qkv`: `[m_total, q_dim + 2 * kv_dim]` token-major
853 /// - `q_out`: `[m_total, q_heads, head_dim]` token-major (matches
854 /// what `paged_varlen_attention` reads)
855 /// - `cache_k` / `cache_v`: paged pool same as `paged_varlen_attention`
856 /// - `cu_seqlens_q`: `[num_seqs + 1]` u32 prefix sum
857 /// - `pos_offsets`: `[num_seqs]` u32, starting kv_pos per seq
858 /// - `block_tables`: `[num_seqs, max_blocks_per_seq]` i32 stacked
859 #[allow(clippy::too_many_arguments)]
860 fn split_qkv_norm_rope_into_paged_cache_varlen(
861 _ctx: &mut Self::Context,
862 _qkv: &Self::Buffer,
863 _q_norm_w: &Self::Buffer,
864 _k_norm_w: &Self::Buffer,
865 _cos: &Self::Buffer,
866 _sin: &Self::Buffer,
867 _q_out: &mut Self::Buffer,
868 _cache_k: &mut Self::Buffer,
869 _cache_v: &mut Self::Buffer,
870 _cu_seqlens_q: &Self::Buffer,
871 _pos_offsets: &Self::Buffer,
872 _block_tables: &Self::Buffer,
873 _num_seqs: usize,
874 _m_total: usize,
875 _q_heads: usize,
876 _kv_heads: usize,
877 _head_dim: usize,
878 _eps: f32,
879 _qk_mode: i32,
880 _block_size: usize,
881 _max_blocks_per_seq: usize,
882 ) -> Result<()> {
883 Err(FerrumError::unsupported(
884 "split_qkv_norm_rope_into_paged_cache_varlen not implemented for this backend",
885 ))
886 }
887 /// Variable-length paged attention with GQA + causal mask.
888 ///
889 /// Supports a unified mixed batch where each sequence contributes
890 /// 1 (decode) or N (prefill chunk) query tokens — the workhorse for
891 /// chunked-prefill. See `kernels/paged_varlen_attention.cu` for the
892 /// kernel itself.
893 ///
894 /// Layouts:
895 /// - `q` / `out`: `[total_q_tokens, num_heads, head_dim]` (token-
896 /// major, FP16). `total_q_tokens` = `cu_seqlens_q[num_seqs]`.
897 /// - `k_pool` / `v_pool`: paged block pool, layout matches
898 /// `paged_decode_attention`.
899 /// - `cu_seqlens_q`: `[num_seqs + 1]` u32 prefix sum, with
900 /// `cu_seqlens_q[0] = 0` and `cu_seqlens_q[num_seqs] = total_q_tokens`.
901 /// - `pos_offsets`: `[num_seqs]` u32, the starting absolute KV
902 /// position of each seq's first q token (= prior `kv_len`).
903 /// - `block_tables`: `[num_seqs, max_num_blocks_per_seq]` i32 grid.
904 ///
905 /// Each query token attends causally to all KV positions
906 /// `[0, pos_offsets[s] + local_idx]`.
907 #[allow(clippy::too_many_arguments)]
908 fn paged_varlen_attention(
909 _ctx: &mut Self::Context,
910 _q: &Self::Buffer,
911 _k_pool: &Self::Buffer,
912 _v_pool: &Self::Buffer,
913 _out: &mut Self::Buffer,
914 _cu_seqlens_q: &Self::Buffer,
915 _pos_offsets: &Self::Buffer,
916 _block_tables: &Self::Buffer,
917 _num_seqs: usize,
918 _total_q_tokens: usize,
919 _max_kv_len: usize,
920 _num_heads: usize,
921 _num_kv_heads: usize,
922 _head_dim: usize,
923 _block_size: usize,
924 _max_num_blocks_per_seq: usize,
925 ) -> Result<()> {
926 Err(FerrumError::unsupported(
927 "paged_varlen_attention not implemented for this backend",
928 ))
929 }
930
931 /// Opt-in vLLM FlashAttention-2 FFI path for FA-layout paged KV.
932 ///
933 /// This is intentionally separate from [`Self::paged_varlen_attention`]:
934 /// it needs the final per-sequence KV lengths (`seq_lens`) and an explicit
935 /// LSE scratch buffer because the external FA2 runner writes softmax LSE.
936 /// Default returns Err(unsupported); CUDA overrides when a runtime shim is
937 /// provided via `FERRUM_FA2_DIRECT_FFI_SHIM`.
938 #[allow(clippy::too_many_arguments)]
939 fn paged_varlen_attention_fa2_ffi(
940 _ctx: &mut Self::Context,
941 _q: &Self::Buffer,
942 _k_pool: &Self::Buffer,
943 _v_pool: &Self::Buffer,
944 _out: &mut Self::Buffer,
945 _lse: &mut Self::Buffer,
946 _cu_seqlens_q: &Self::Buffer,
947 _seq_lens: &Self::Buffer,
948 _block_tables: &Self::Buffer,
949 _num_seqs: usize,
950 _total_q_tokens: usize,
951 _max_q_len: usize,
952 _max_kv_len: usize,
953 _num_heads: usize,
954 _num_kv_heads: usize,
955 _head_dim: usize,
956 _block_size: usize,
957 _max_num_blocks_per_seq: usize,
958 ) -> Result<()> {
959 Err(FerrumError::unsupported(
960 "paged_varlen_attention_fa2_ffi not implemented for this backend",
961 ))
962 }
963
964 /// Batched paged decode attention — multi-seq, single token per seq.
965 /// Faster path for the unified_forward layer when m_total == num_seqs
966 /// (every item is a single-token decode). Skips the cu_seqlens_q
967 /// linear scan that `paged_varlen_attention` does in the fully-mixed
968 /// case.
969 ///
970 /// Layouts:
971 /// q : [num_seqs, num_q_heads, head_dim]
972 /// k_pool/v_pool : paged pool (same as paged_varlen)
973 /// block_tables : [num_seqs, max_num_blocks_per_seq]
974 /// valid_kv_lens : [num_seqs] — current kv_len per seq
975 /// out : [num_seqs, num_q_heads, head_dim]
976 ///
977 /// Default returns Err(unsupported); CUDA backend overrides.
978 #[allow(clippy::too_many_arguments)]
979 fn paged_batched_decode_attention(
980 _ctx: &mut Self::Context,
981 _q: &Self::Buffer,
982 _k_pool: &Self::Buffer,
983 _v_pool: &Self::Buffer,
984 _out: &mut Self::Buffer,
985 _block_tables: &Self::Buffer,
986 _valid_kv_lens: &Self::Buffer,
987 _num_seqs: usize,
988 _max_kv_len: usize,
989 _num_heads: usize,
990 _num_kv_heads: usize,
991 _head_dim: usize,
992 _block_size: usize,
993 _max_num_blocks_per_seq: usize,
994 ) -> Result<()> {
995 Err(FerrumError::unsupported(
996 "paged_batched_decode_attention not implemented for this backend",
997 ))
998 }
999
1000 /// Capability: backend has vLLM-layout paged KV write kernels and the
1001 /// `paged_attention_v2` decode kernel. Models that opt into this layout
1002 /// at construction time (via `FERRUM_USE_VLLM_PAGED_ATTN=1`) must
1003 /// dispatch ALL paged writes and reads through the `_vllm` variants —
1004 /// the layouts are not compatible. Default `false`.
1005 fn supports_vllm_paged_attn() -> bool {
1006 false
1007 }
1008
1009 /// vLLM-layout variant of
1010 /// [`Self::split_qkv_norm_rope_into_paged_cache`]. K/V are written in
1011 /// vLLM's `paged_attention_v2` layout: K is
1012 /// `[num_blocks, kv_heads, head_dim/x, block_size, x]` (x = 16/sizeof(elem)),
1013 /// V is `[num_blocks, kv_heads, head_dim, block_size]`. Q output and
1014 /// every other argument matches the non-vllm variant exactly so the
1015 /// model layer can swap dispatchers based on a single flag.
1016 #[allow(clippy::too_many_arguments)]
1017 fn split_qkv_norm_rope_into_paged_cache_vllm(
1018 _ctx: &mut Self::Context,
1019 _qkv: &Self::Buffer,
1020 _qkv_byte_offset: u64,
1021 _q_norm_w: &Self::Buffer,
1022 _k_norm_w: &Self::Buffer,
1023 _cos: &Self::Buffer,
1024 _sin: &Self::Buffer,
1025 _q_out: &mut Self::Buffer,
1026 _q_out_byte_offset: u64,
1027 _cache_k: &mut Self::Buffer,
1028 _cache_v: &mut Self::Buffer,
1029 _block_table: &Self::Buffer,
1030 _tokens: usize,
1031 _q_heads: usize,
1032 _kv_heads: usize,
1033 _head_dim: usize,
1034 _pos_offset: usize,
1035 _eps: f32,
1036 _qk_mode: i32,
1037 _cache_len: usize,
1038 _block_size: usize,
1039 _max_num_blocks_per_seq: usize,
1040 ) -> Result<()> {
1041 Err(FerrumError::unsupported(
1042 "split_qkv_norm_rope_into_paged_cache_vllm not implemented for this backend",
1043 ))
1044 }
1045
1046 /// vLLM-layout variant of
1047 /// [`Self::split_qkv_norm_rope_into_paged_cache_varlen`]. Same signature
1048 /// — only the K/V cache layout changes.
1049 #[allow(clippy::too_many_arguments)]
1050 fn split_qkv_norm_rope_into_paged_cache_varlen_vllm(
1051 _ctx: &mut Self::Context,
1052 _qkv: &Self::Buffer,
1053 _q_norm_w: &Self::Buffer,
1054 _k_norm_w: &Self::Buffer,
1055 _cos: &Self::Buffer,
1056 _sin: &Self::Buffer,
1057 _q_out: &mut Self::Buffer,
1058 _cache_k: &mut Self::Buffer,
1059 _cache_v: &mut Self::Buffer,
1060 _cu_seqlens_q: &Self::Buffer,
1061 _pos_offsets: &Self::Buffer,
1062 _block_tables: &Self::Buffer,
1063 _num_seqs: usize,
1064 _m_total: usize,
1065 _q_heads: usize,
1066 _kv_heads: usize,
1067 _head_dim: usize,
1068 _eps: f32,
1069 _qk_mode: i32,
1070 _block_size: usize,
1071 _max_blocks_per_seq: usize,
1072 ) -> Result<()> {
1073 Err(FerrumError::unsupported(
1074 "split_qkv_norm_rope_into_paged_cache_varlen_vllm not implemented for this backend",
1075 ))
1076 }
1077
1078 /// vLLM `paged_attention_v2` — multi-partition split-K decode attention
1079 /// reading the vLLM K/V layout. `q_len` is implicitly 1 (decode only;
1080 /// vLLM's v2 kernel does not support q_len > 1). `max_seq_len` is the
1081 /// max kv_len across the batch — used to size the partition reduction.
1082 #[allow(clippy::too_many_arguments)]
1083 fn paged_decode_attention_v2(
1084 _ctx: &mut Self::Context,
1085 _q: &Self::Buffer,
1086 _k_pool: &Self::Buffer,
1087 _v_pool: &Self::Buffer,
1088 _out: &mut Self::Buffer,
1089 _block_tables: &Self::Buffer,
1090 _context_lens: &Self::Buffer,
1091 _num_seqs: usize,
1092 _num_heads: usize,
1093 _num_kv_heads: usize,
1094 _head_dim: usize,
1095 _block_size: usize,
1096 _max_num_blocks_per_seq: usize,
1097 _max_seq_len: usize,
1098 ) -> Result<()> {
1099 Err(FerrumError::unsupported(
1100 "paged_decode_attention_v2 not implemented for this backend",
1101 ))
1102 }
1103
1104 /// q_len>1 prefill/chunk-prefill attention over vLLM-layout paged KV.
1105 /// This keeps cache layout consistent when `FERRUM_USE_VLLM_PAGED_ATTN=1`
1106 /// and the prompt path writes K/V in the layout consumed later by
1107 /// `paged_decode_attention_v2`.
1108 #[allow(clippy::too_many_arguments)]
1109 fn paged_varlen_attention_vllm_layout(
1110 _ctx: &mut Self::Context,
1111 _q: &Self::Buffer,
1112 _k_pool: &Self::Buffer,
1113 _v_pool: &Self::Buffer,
1114 _out: &mut Self::Buffer,
1115 _block_tables: &Self::Buffer,
1116 _context_lens: &Self::Buffer,
1117 _num_seqs: usize,
1118 _num_heads: usize,
1119 _num_kv_heads: usize,
1120 _head_dim: usize,
1121 _block_size: usize,
1122 _max_num_blocks_per_seq: usize,
1123 _q_len: usize,
1124 ) -> Result<()> {
1125 Err(FerrumError::unsupported(
1126 "paged_varlen_attention_vllm_layout not implemented for this backend",
1127 ))
1128 }
1129
1130 /// Variable-length paged attention over vLLM-layout paged KV.
1131 ///
1132 /// Unlike [`Self::paged_varlen_attention_vllm_layout`], this accepts the
1133 /// same varlen index tensors as [`Self::paged_varlen_attention`] and writes
1134 /// token-major output directly. It is the unified mixed-batch companion for
1135 /// `split_qkv_norm_rope_into_paged_cache_varlen_vllm`.
1136 #[allow(clippy::too_many_arguments)]
1137 fn paged_varlen_attention_vllm(
1138 _ctx: &mut Self::Context,
1139 _q: &Self::Buffer,
1140 _k_pool: &Self::Buffer,
1141 _v_pool: &Self::Buffer,
1142 _out: &mut Self::Buffer,
1143 _cu_seqlens_q: &Self::Buffer,
1144 _pos_offsets: &Self::Buffer,
1145 _block_tables: &Self::Buffer,
1146 _num_seqs: usize,
1147 _total_q_tokens: usize,
1148 _max_kv_len: usize,
1149 _num_heads: usize,
1150 _num_kv_heads: usize,
1151 _head_dim: usize,
1152 _block_size: usize,
1153 _max_num_blocks_per_seq: usize,
1154 ) -> Result<()> {
1155 Err(FerrumError::unsupported(
1156 "paged_varlen_attention_vllm not implemented for this backend",
1157 ))
1158 }
1159
1160 /// Q-tiled vLLM-layout varlen attention. `tile_seqs` and `tile_starts`
1161 /// describe a compact list of q-token tiles, avoiding empty grid blocks
1162 /// for mixed batches that contain both long prefill items and q_len=1
1163 /// decode items. Semantics match [`Self::paged_varlen_attention_vllm`].
1164 #[allow(clippy::too_many_arguments)]
1165 fn paged_varlen_attention_vllm_tiled_q4(
1166 _ctx: &mut Self::Context,
1167 _q: &Self::Buffer,
1168 _k_pool: &Self::Buffer,
1169 _v_pool: &Self::Buffer,
1170 _out: &mut Self::Buffer,
1171 _cu_seqlens_q: &Self::Buffer,
1172 _pos_offsets: &Self::Buffer,
1173 _block_tables: &Self::Buffer,
1174 _tile_seqs: &Self::Buffer,
1175 _tile_starts: &Self::Buffer,
1176 _num_tiles: usize,
1177 _max_kv_len: usize,
1178 _num_heads: usize,
1179 _num_kv_heads: usize,
1180 _head_dim: usize,
1181 _block_size: usize,
1182 _max_num_blocks_per_seq: usize,
1183 ) -> Result<()> {
1184 Err(FerrumError::unsupported(
1185 "paged_varlen_attention_vllm_tiled_q4 not implemented for this backend",
1186 ))
1187 }
1188}
1189
1190// ════════════════════════════════════════════════════════════════════════
// Capability bundles — alias-like traits over the supertrait set
// ════════════════════════════════════════════════════════════════════════
//
// Models declare what they need through these bundles instead of spelling
// out every supertrait. Stable Rust has no trait aliases, so each bundle is
// an empty trait plus a blanket impl (below). Any backend that satisfies the
// underlying supertraits therefore automatically becomes an `LlmBackend` /
// `QuantLlmBackend` / `MoeLlmBackend`.
1198
1199/// Minimum capability set for a decoder-only LLM: the core compute trait
1200/// plus paged-KV cache + graph-capture support. Every concrete backend
1201/// (CUDA / Metal / CPU) satisfies this.
1202pub trait LlmBackend: Backend + BackendGraph + BackendPagedKv {}
1203impl<T> LlmBackend for T where T: Backend + BackendGraph + BackendPagedKv {}
1204
1205/// LLM backend that also supports quantized weight loading (GPTQ Marlin
1206/// for CUDA; GGUF k-quant for Metal). Required by models that hold
1207/// `Box<dyn Linear<B>>` where the Linear impl might be a quant variant.
1208pub trait QuantLlmBackend: LlmBackend + BackendQuantMarlin + BackendQuantGguf {}
1209impl<T> QuantLlmBackend for T where T: LlmBackend + BackendQuantMarlin + BackendQuantGguf {}
1210
1211/// MoE-capable LLM backend: adds the fused MoE routing + post-op kernels
1212/// to the quant LLM bundle. Required by Qwen3-MoE / future MoE models.
1213pub trait MoeLlmBackend: QuantLlmBackend + BackendMoeFused {}
1214impl<T> MoeLlmBackend for T where T: QuantLlmBackend + BackendMoeFused {}
1215
1216// ════════════════════════════════════════════════════════════════════════
1217// KV cache dtype axis (dim 5 of the 5-dimension architecture)
1218// ════════════════════════════════════════════════════════════════════════
1219//
// Each model's KV cache has its own precision, independent of the model's
// compute precision. vLLM 0.6+ ships INT8 / FP8 KV caches that halve KV
// memory at a small (<1%) accuracy cost. Today ferrum's KV is hardcoded to
// FP16 on CUDA / Metal. Supporting INT8/FP8 KV in a future PR requires an
// explicit axis in the type system.
//
// Phase 4 scope: scaffolding only. All concrete backends impl
// `BackendKvDtype<KvFp16>`, so existing models keep working unchanged.
// Future PR: implement BackendKvDtype<KvInt8> on CUDA, plus a new model
// type-parameter `K: KvDtypeKind` to wire it through.
// NOTE(review): `BackendInt8KvOps` further down documents a real CUDA INT8
// implementation, so the FP16-only / future-PR statements above may be
// stale. Confirm and update.
1230
1231// `KvDtypeKind` + `KvFp16` / `KvBf16` / `KvInt8` / `KvFp8` markers moved
1232// to `ferrum_interfaces::kv_dtype` (no GPU deps, so the right place is
1233// the contract crate). Re-exported here so existing callers keep
1234// compiling against `crate::backend::KvFp16` etc.
1235pub use ferrum_interfaces::kv_dtype::{KvBf16, KvDtypeKind, KvFp16, KvFp8, KvInt8};
1236
1237/// Capability-trait for backends that can store + read a KV cache of
1238/// type `K`.
1239///
1240/// The two associated types carry the K-specific storage shape:
1241/// - `KvBuffer`: per-layer K/V element storage. For `K = KvFp16` it
1242/// is the backend's normal `Self::Buffer` (FP16). For `K = KvInt8`
1243/// it is the backend's INT8 buffer (e.g. `CudaSlice<i8>` on CUDA).
1244/// - `KvScales`: per-token-per-kv-head scales. For `K = KvFp16` this
1245/// is the unit type `()` (no scales). For `K = KvInt8` / `KvFp8`
1246/// it is a backend-specific FP16 buffer.
1247///
1248/// Models that want INT8 KV use:
1249/// `where B: BackendKvDtype<KvInt8>`
1250/// — the buffers in `KvCache<B, KvInt8>` are then `CudaSlice<i8>` and
1251/// `CudaSlice<f16>`, distinct from the FP16 path's `Self::Buffer`.
1252pub trait BackendKvDtype<K: KvDtypeKind>: BackendPagedKv {
1253 /// Per-layer K/V element storage.
1254 type KvBuffer: Send + Sync;
1255 /// Per-token per-kv-head scale storage. `()` for FP16 (no scales).
1256 type KvScales: Send + Sync + Default;
1257}
1258
1259/// INT8 KV cache operations (Dim 5).
1260///
1261/// `BackendKvDtype<KvInt8>` only declares the storage types; it does not
1262/// know how to write INT8 K/V into a paged pool or run paged decode
1263/// attention against an INT8 cache. Those launchers live here so the
1264/// model layer can call them through a single `B: BackendInt8KvOps` bound
1265/// without dropping into backend-specific code.
1266///
1267/// Today only `CudaBackend` provides a real implementation (delegating to
1268/// [`crate::int8_kv::launch_int8_kv_cache_append`] and
1269/// [`crate::int8_kv::launch_int8_paged_decode_attention`]). Other backends
1270/// inherit the default `unimplemented!()` body — the registry factory
1271/// rejects `(Device::CPU/Metal, KvCacheDtype::Int8)` before the model
1272/// gets a chance to call into these.
1273#[allow(clippy::too_many_arguments)]
1274pub trait BackendInt8KvOps: Backend + BackendKvDtype<KvInt8> {
1275 /// Allocate the per-layer INT8 paged cache for one sequence.
1276 /// Default panics — backends without INT8 support never reach this
1277 /// path (factory rejects (Cpu/Metal, Int8) before ensure_kv runs).
1278 fn alloc_paged_int8_layer(
1279 _max_blocks_per_seq: usize,
1280 _block_size: usize,
1281 _num_kv_heads: usize,
1282 _head_dim: usize,
1283 ) -> KvCacheQuant<Self, KvInt8> {
1284 unimplemented!("alloc_paged_int8_layer not supported on this backend")
1285 }
1286
1287 /// Append `tokens` FP16 K/V values into the paged INT8 pool.
1288 /// `paged_block_indices` is the host-side mirror of the per-seq
1289 /// logical→physical block table (already populated at `ensure_kv` time
1290 /// — see `KvCacheQuant::paged_block_indices`). Passing the host slice
1291 /// avoids a per-token D2H + sync barrier; backend computes the slot
1292 /// mapping host-side, async-H2D's it, and chains the append kernel
1293 /// on the same stream — fully overlapping with prior work.
1294 /// `cache_len_before` is the current number of valid tokens; the
1295 /// backend quantizes FP16 → INT8 with per-(token, kv-head) FP16 scale
1296 /// and writes both into the layer's INT8 / scale buffers.
1297 fn int8_kv_append_paged(
1298 _ctx: &mut Self::Context,
1299 _k_in: &Self::Buffer,
1300 _v_in: &Self::Buffer,
1301 _layer_k: &mut <Self as BackendKvDtype<KvInt8>>::KvBuffer,
1302 _layer_v: &mut <Self as BackendKvDtype<KvInt8>>::KvBuffer,
1303 _layer_k_scales: &mut <Self as BackendKvDtype<KvInt8>>::KvScales,
1304 _layer_v_scales: &mut <Self as BackendKvDtype<KvInt8>>::KvScales,
1305 _paged_block_indices: &[u32],
1306 _cache_len_before: usize,
1307 _tokens: usize,
1308 _block_size: usize,
1309 _num_kv_heads: usize,
1310 _head_dim: usize,
1311 ) -> Result<()> {
1312 Err(FerrumError::unsupported(
1313 "int8_kv_append_paged not implemented for this backend",
1314 ))
1315 }
1316
1317 /// Run paged decode attention reading from an INT8 cache. Q is FP16,
1318 /// output is FP16; the kernel dequantizes K/V on the fly using the
1319 /// per-token scales. `valid_kv_len` is the post-append cache length
1320 /// (i.e. the kernel attends over `[0, valid_kv_len)` tokens).
1321 fn int8_paged_decode_attention(
1322 _ctx: &mut Self::Context,
1323 _q: &Self::Buffer,
1324 _layer_k: &<Self as BackendKvDtype<KvInt8>>::KvBuffer,
1325 _layer_v: &<Self as BackendKvDtype<KvInt8>>::KvBuffer,
1326 _layer_k_scales: &<Self as BackendKvDtype<KvInt8>>::KvScales,
1327 _layer_v_scales: &<Self as BackendKvDtype<KvInt8>>::KvScales,
1328 _block_table: &Self::Buffer,
1329 _output: &mut Self::Buffer,
1330 _num_q_heads: usize,
1331 _num_kv_heads: usize,
1332 _head_dim: usize,
1333 _valid_kv_len: usize,
1334 _block_size: usize,
1335 _scale: f32,
1336 ) -> Result<()> {
1337 Err(FerrumError::unsupported(
1338 "int8_paged_decode_attention not implemented for this backend",
1339 ))
1340 }
1341}
1342
// Cpu/Metal deliberately do NOT impl `BackendInt8KvOps`. After the pivot to
// `KvLayer<B>`, `KvInt8: KvLayer<B>` only holds where
// `B: BackendInt8KvOps`. `LlamaFamilyModel<CpuBackend, KvInt8>` is therefore
// a compile error, because no INT8 `KvLayer` impl satisfies it. The type
// system enforces the constraint without runtime stubs.