ferrum_kernels/backend/traits.rs
1//! Core Backend trait — the single abstraction over CUDA / Metal / CPU.
2
3use ferrum_types::{FerrumError, Result};
4use half::{bf16, f16};
5
/// On-disk element type of a weight tensor that is read straight out of a
/// memory-mapped safetensors file.
///
/// Handed to `Backend::from_weight_bytes` so that every backend can decide
/// for itself whether to upcast into its compute dtype or keep the data in
/// its original encoding.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SrcDtype {
    /// 32-bit float, 4 bytes per element.
    F32,
    /// IEEE half precision (`half::f16`), 2 bytes per element.
    F16,
    /// bfloat16 (`half::bf16`), 2 bytes per element.
    BF16,
}
16
17impl SrcDtype {
18 /// Number of bytes per element in the raw on-disk representation.
19 pub const fn bytes_per_elem(self) -> usize {
20 match self {
21 SrcDtype::F32 => 4,
22 SrcDtype::F16 | SrcDtype::BF16 => 2,
23 }
24 }
25
26 /// Materialise the raw byte slice into a `Vec<f32>`. Used by the default
27 /// `Backend::from_weight_bytes` impl; fp16-preferring backends bypass it.
28 pub fn to_f32_vec(self, raw: &[u8]) -> Vec<f32> {
29 match self {
30 SrcDtype::F32 => {
31 debug_assert_eq!(raw.len() % 4, 0);
32 let n = raw.len() / 4;
33 let mut out = vec![0f32; n];
34 for i in 0..n {
35 let b = [raw[i * 4], raw[i * 4 + 1], raw[i * 4 + 2], raw[i * 4 + 3]];
36 out[i] = f32::from_le_bytes(b);
37 }
38 out
39 }
40 SrcDtype::F16 => {
41 debug_assert_eq!(raw.len() % 2, 0);
42 let n = raw.len() / 2;
43 let mut out = vec![0f32; n];
44 for i in 0..n {
45 out[i] = f16::from_le_bytes([raw[i * 2], raw[i * 2 + 1]]).to_f32();
46 }
47 out
48 }
49 SrcDtype::BF16 => {
50 debug_assert_eq!(raw.len() % 2, 0);
51 let n = raw.len() / 2;
52 let mut out = vec![0f32; n];
53 for i in 0..n {
54 out[i] = bf16::from_le_bytes([raw[i * 2], raw[i * 2 + 1]]).to_f32();
55 }
56 out
57 }
58 }
59 }
60}
61
/// Quantization flavour discriminator (GPTQ / AWQ / GGUF).
///
/// Distinct schemes need distinct kernels. Carrying the scheme as a value
/// keeps the `Backend` trait from growing one method per quantization type.
///
/// NOTE(review): the visible `Backend::gemm_quant` takes `&Self::QuantStore`
/// and `Backend::load_quant` takes a [`GgufQuantType`], not this enum —
/// confirm which call sites still consume `QuantKind`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum QuantKind {
    /// GPTQ: group-wise int4/int8 with scales + zeros (asymmetric) + optional g_idx.
    Gptq {
        bits: u32,
        group_size: usize,
        desc_act: bool,
    },
    /// AWQ: activation-aware int4 with scales + zeros, different packing from GPTQ.
    Awq { bits: u32, group_size: usize },
    /// GGUF: one of k-quants / legacy quants, fully specified by the inner type.
    Gguf { quant_type: GgufQuantType },
}

/// GGUF quantization sub-type (expand as kernels are added).
///
/// Comparable and hashable so callers can test `kind == GgufQuantType::Q4K`
/// or key per-kind tables without `matches!` boilerplate.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GgufQuantType {
    /// GGUF `Q4_0` (legacy quant).
    Q4_0,
    /// GGUF `Q4_1` (legacy quant).
    Q4_1,
    /// GGUF `Q4_K` (k-quant).
    Q4K,
    /// GGUF `Q5_K` (k-quant).
    Q5K,
    /// GGUF `Q6_K` (k-quant).
    Q6K,
    /// GGUF `Q8_0` (legacy quant).
    Q8_0,
}
90
91/// Packed quantized weight buffers passed to `Backend::gemm_quant`.
92///
93/// Not every field is used by every `QuantKind` — e.g. GGUF packs scales
94/// inside `qweight`, so `scales` / `zeros` may be dummies. The Backend
95/// implementation is expected to validate the shape for the kind it handles.
96pub struct QuantWeights<'a, B: Backend> {
97 pub qweight: &'a B::Buffer,
98 pub scales: Option<&'a B::Buffer>,
99 pub zeros: Option<&'a B::Buffer>,
100 pub g_idx: Option<&'a B::Buffer>,
101}
102
/// Reduction operator for a tensor-parallel `all_reduce` collective.
///
/// Comparable and hashable so callers can branch on (`op == ReduceOp::Sum`)
/// or key lookup tables by the operator.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Sum the participants' values.
    Sum,
    /// Take the maximum of the participants' values.
    Max,
    /// Take the minimum of the participants' values.
    Min,
}
110
/// Parameters that steer attention dispatch.
#[derive(Clone, Debug)]
pub struct AttnConfig {
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Per-head feature dimension.
    pub head_dim: usize,
    /// Whether causal masking is applied.
    pub causal: bool,
    /// Scaling factor for the attention scores (default `1.0`).
    pub scale: f32,
    /// Row stride between head blocks in the KV buffer.
    ///
    /// `0` means the buffer is contiguous and `kv_len` is used as the stride
    /// (legacy behaviour). Set it to `cache_capacity` when attending over a
    /// pre-allocated cache in which only `kv_len` of `cache_capacity` slots
    /// are valid.
    pub kv_seq_stride: usize,
    /// Sliding-window attention size (Mistral v0.1, Gemma).
    ///
    /// `0` disables the window (full causal attention). With `w > 0` each
    /// query position attends to the previous `w` KV positions, still capped
    /// above by `causal` + `pos_offset + qi + 1`.
    pub sliding_window: usize,
}

impl Default for AttnConfig {
    /// All sizes zero, non-causal, unit scale, contiguous KV, no window.
    fn default() -> Self {
        AttnConfig {
            scale: 1.0,
            causal: false,
            sliding_window: 0,
            kv_seq_stride: 0,
            head_dim: 0,
            num_kv_heads: 0,
            num_heads: 0,
        }
    }
}
144
145// Note: `TransformerConfig` / `AttnType` / `MlpType` / `RopeConfig` used to
146// live here when `ModelRunner` needed a generic model config. They're now
147// per-model (e.g. `Qwen3Config` in `ferrum-models::models::qwen3`) so each
148// model can carry exactly the architecture parameters it cares about.
149// Backend trait stays model-agnostic.
150
151/// Per-layer KV cache. Each model owns its own `Vec<KvCache<B>>` per sequence.
152///
153/// Two layouts are supported, selected at allocation time:
154/// 1. **Contiguous** (default): `k`/`v` are `[num_kv_heads, capacity, head_dim]`
155/// f32 buffers. `block_size == 0` and `block_table` / `context_lens` are
156/// `None`. Original ferrum layout — used when `FERRUM_METAL_PAGED_KV` is
157/// unset.
158/// 2. **Paged** (vLLM-style): `k`/`v` are `[num_blocks, num_kv_heads,
159/// block_size, head_dim]` block pools. `block_size > 0` and
160/// `block_table` (`u32[max_num_blocks_per_seq]`) + `context_lens`
161/// (`u32[1]` single-seq for now) are populated. Multi-seq sharing
162/// is a Phase 4 concern; today every paged cache_id has its own
163/// pool but the kernel-level indirection works.
164pub struct KvCache<B: Backend> {
165 pub k: B::Buffer,
166 pub v: B::Buffer,
167 pub len: usize,
168 pub capacity: usize,
169 pub num_kv_heads: usize,
170 pub head_dim: usize,
171 /// Paged: KV positions per physical block. `0` ⇒ contiguous layout.
172 pub block_size: usize,
173 /// Paged: `[max_num_blocks_per_seq]` u32 — logical → physical block.
174 pub block_table: Option<B::Buffer>,
175 /// Paged: `[1]` u32 — current context length for the kernel to read.
176 pub context_lens: Option<B::Buffer>,
177 /// Paged: host-side mirror of the physical block indices owned by
178 /// this cache. Lets the model's release path return blocks to the
179 /// shared allocator without reading them back from device.
180 pub paged_block_indices: Vec<u32>,
181}
182
183/// The core abstraction over CUDA / Metal / CPU.
184///
185/// Key design: operations take a `&mut Self::Context` which accumulates work.
186/// - **CPU**: Context is `()` — ops execute immediately.
187/// - **Metal**: Context is a `CommandBuffer` — ops encode into it, flushed on `sync()`.
188/// - **CUDA**: Context is a `CudaStream` — ops launch on the stream, synced on `sync()`.
189///
190/// `layer_forward` passes the context through all ops in a layer.
191/// `ModelRunner` calls `sync()` only when it needs results (e.g., reading logits).
192pub trait Backend: Send + Sync + Sized + 'static {
/// Backend-specific tensor storage handle (e.g. CPU's `Vec<f32>`, Metal's
/// untyped `MTLBuffer`). Must be `Send + Sync`.
type Buffer: Send + Sync;

/// Execution context that accumulates GPU work until `sync` is called.
/// - CPU: `()` (no-op, ops execute inline)
/// - Metal: wraps a CommandBuffer
/// - CUDA: wraps a CudaStream
type Context;

/// Opaque per-backend GPTQ weight representation, produced by `load_gptq`
/// and consumed by `gemm_gptq`.
/// - CPU: dequantized f32 weights (run as regular GEMM)
/// - Metal: `()` — unsupported; `gemm_gptq` errors
/// - CUDA: `MarlinWeight` — pre-repacked tiles + permuted scales
///
/// Each backend repacks the raw GPTQ tensors (qweight / qzeros as `i32`,
/// scales as `f32` — see `load_gptq`) into its preferred format at model
/// load time, so inference doesn't pay the repack cost per forward pass.
type GptqStore: Send + Sync;

/// Single backend-specific store for **all GGUF quant flavours**
/// (Q4_K_M today; Q5_K_M / Q6_K / Q8_0 etc. become enum variants
/// without changing the trait shape).
///
/// Each backend's `QuantStore` is typically an enum dispatching on the
/// on-disk quant type. The loaders (`load_quant`, `load_quant_fused`,
/// `load_quant_experts`) take a [`GgufQuantType`] discriminator, and the
/// GEMM entry points (`gemm_quant` and the `*_moe_id*` family) dispatch on
/// the store itself, so callers don't see the variant boilerplate.
///
/// **GPTQ stays on the older [`Self::GptqStore`] path** because its
/// load inputs are split arrays (qweight / scales / qzeros), not
/// the contiguous byte payload GGUF quants ship as. A future PR can
/// fold GPTQ into `QuantStore` once an input-shape unification is
/// agreed.
type QuantStore: Send + Sync;
226
/// Create a new execution context; subsequent ops accumulate work into it.
fn new_context() -> Self::Context;

/// Flush all work accumulated in `ctx` and wait for it to complete.
/// CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.
fn sync(ctx: &mut Self::Context);
233
// ── Graph capture / replay (CUDA only) ──────────────────────────────
//
// Decode-loop optimization: eliminate per-kernel launch overhead by
// capturing the full step as a CUDA graph and replaying it. CPU/Metal
// have no equivalent, so the defaults degrade gracefully:
// `begin_graph_capture` / `end_graph_capture` return `unsupported`,
// `replay_last_graph` returns `Ok(false)`, and the other hooks are no-ops.
//
// Flow per decode step:
//   1. Caller: `set_decode_state(ctx, token, step)` — memcpy to dev bufs
//   2. Try `replay_last_graph(ctx)`:
//        - Ok(true):  graph replayed, skip eager forward
//        - Ok(false): no captured graph yet, run eager
//        - Err(_):    not supported, run eager
//   3. If running eager and in capture window:
//        - `set_dev_state_mode(ctx, true)` so kernels use _dyn variants
//        - `begin_graph_capture(ctx)`
//        - run forward
//        - `end_graph_capture(ctx)` — stores graph on ctx internally
//        - `set_dev_state_mode(ctx, false)` — restore scalar kernels
252
253 /// Update per-step dynamic state (token id, step/pos). Fast (3x memcpy).
254 fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32) {}
255
256 /// Toggle between scalar-arg kernels (normal) and `_dyn` kernels that
257 /// read their dynamic scalar args from device memory (graph-friendly).
258 fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool) {}
259
260 /// Begin stream capture. Subsequent kernel launches are recorded into
261 /// a pending graph instead of executing eagerly.
262 fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
263 Err(FerrumError::unsupported("graph capture not supported"))
264 }
265
266 /// End stream capture and install the captured graph as this context's
267 /// "last graph" for future `replay_last_graph` calls.
268 fn end_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
269 Err(FerrumError::unsupported("graph capture not supported"))
270 }
271
272 /// Replay the last captured graph. Returns `Ok(false)` if no graph
273 /// is cached; caller should run eager.
274 fn replay_last_graph(_ctx: &mut Self::Context) -> Result<bool> {
275 Ok(false)
276 }
277
278 /// Drop the cached decode graph — required when the KV cache it
279 /// was captured against is about to be freed (e.g. request release),
280 /// since the graph holds raw device pointers into that cache.
281 fn reset_graph(_ctx: &mut Self::Context) {}
282
283 // ── GPTQ (INT4 quantization) ────────────────────────────────────────
284 //
285 // Two-step: load (once per weight) → gemm (per forward). The store
286 // holds whatever backend-specific format is fastest; caller code
287 // (GptqLinear) is dtype-agnostic.
288
289 /// Repack raw GPTQ tensors into the backend's preferred format.
290 /// Called once per layer at model load time.
291 ///
292 /// Inputs are host-side slices (CPU memory) — the loader reads from
293 /// safetensors and hands them off; each backend uploads + repacks
294 /// per its own strategy. `bits` is typically 4; `group_size` is
295 /// typically 128.
296 #[allow(clippy::too_many_arguments)]
297 fn load_gptq(
298 _qweight: &[i32],
299 _scales: &[f32],
300 _qzeros: &[i32],
301 _g_idx: Option<&[i32]>,
302 _bits: u32,
303 _group_size: usize,
304 _k: usize,
305 _n: usize,
306 ) -> Result<Self::GptqStore> {
307 Err(FerrumError::unsupported(
308 "load_gptq not implemented for this backend",
309 ))
310 }
311
312 /// GEMM with pre-loaded GPTQ weights.
313 /// `out[m, n] = a[m, k] @ dequant(weight)^T`
314 fn gemm_gptq(
315 _ctx: &mut Self::Context,
316 _a: &Self::Buffer,
317 _weight: &Self::GptqStore,
318 _out: &mut Self::Buffer,
319 _m: usize,
320 ) -> Result<()> {
321 Err(FerrumError::unsupported(
322 "gemm_gptq not implemented for this backend",
323 ))
324 }
325
326 /// Load GGUF k-quant weights into the backend's preferred format.
327 ///
328 /// `kind` discriminates Q4_K / Q5_K / Q6_K / Q8_0 etc. The CPU path
329 /// typically eager-dequants to fp32; the Metal path keeps raw block
330 /// bytes in MTLBuffer and dequants per matmul into a transient fp16
331 /// buffer. Adding a new k-quant flavour is a matched pair of
332 /// `QuantStore` variant + `match` arm, not a new trait method.
333 ///
334 /// `bytes`: contiguous on-disk payload — `n_blocks × block_size`.
335 /// `n_rows`: out_features. `n_cols`: in_features. The block count
336 /// is derived per-kind from these dims.
337 fn load_quant(
338 _kind: GgufQuantType,
339 _bytes: &[u8],
340 _n_rows: usize,
341 _n_cols: usize,
342 ) -> Result<Self::QuantStore> {
343 Err(FerrumError::unsupported(
344 "load_quant not implemented for this backend",
345 ))
346 }
347
348 /// Build a fused `QuantStore` from multiple `(kind, bytes, n_rows)`
349 /// parts that share `n_cols`. Used by `GgufLoader::load_fused` when
350 /// parts have heterogeneous quant kinds (e.g. Qwen3 qkv_proj where
351 /// q+k are Q4_K but v is Q6_K) — byte-concatenation isn't possible,
352 /// so each part stays as its own QuantStore and the gemm dispatches
353 /// one matvec per part with output offsets.
354 ///
355 /// Default: not supported. Backends that have a `Fused`-like variant
356 /// override.
357 fn load_quant_fused(
358 _parts: &[(GgufQuantType, &[u8], usize)],
359 _n_cols: usize,
360 ) -> Result<Self::QuantStore> {
361 Err(FerrumError::unsupported(
362 "load_quant_fused not implemented for this backend",
363 ))
364 }
365
366 /// GEMM with k-quant weights. Mirrors `gemm` / `gemm_gptq` shape:
367 /// `out[m, n] = a[m, k] @ dequant(weight)^T`. The dispatch on the
368 /// quant flavour happens inside the backend's `QuantStore` enum.
369 fn gemm_quant(
370 _ctx: &mut Self::Context,
371 _a: &Self::Buffer,
372 _weight: &Self::QuantStore,
373 _out: &mut Self::Buffer,
374 _m: usize,
375 ) -> Result<()> {
376 Err(FerrumError::unsupported(
377 "gemm_quant not implemented for this backend",
378 ))
379 }
380
381 /// Build a stacked-experts `QuantStore` from a contiguous 3-D weight
382 /// payload `[num_experts, n_rows, n_cols/256]` super-blocks.
383 /// Used for the MoE indirect-dispatch fast path; backends without
384 /// such a kernel return `Err(unsupported)` and the model code falls
385 /// back to the per-expert loop.
386 ///
387 /// Default: not supported. Override on backends with batched MoE
388 /// kernels (e.g. Metal `gemv_q*kw_moe_id_f32`).
389 fn load_quant_experts(
390 _kind: GgufQuantType,
391 _bytes: &[u8],
392 _num_experts: usize,
393 _n_rows: usize,
394 _n_cols: usize,
395 ) -> Result<Self::QuantStore> {
396 Err(FerrumError::unsupported(
397 "load_quant_experts not implemented for this backend",
398 ))
399 }
400
401 /// MoE 2-D indirect-dispatch GEMM (prefill m > 1).
402 ///
403 /// Computes per (token, expert_slot) pair, batched across all
404 /// experts in one launch:
405 ///
406 /// `out[token, slot, :] = a[token, slot_or_0, :] @ dequant(weight[expert(token, slot), :])^T`
407 ///
408 /// `ids[expert][slot] = pair_id` encodes `(token_idx, slot_within_token)`
409 /// so the kernel reads activations indirectly (src1 row for the
410 /// pair) and writes outputs directly to the natural
411 /// `[batch, top_k, M]` layout. `tpe[expert]` gives the count of
412 /// pairs assigned to each expert — threadgroups past `tpe[e]`
413 /// early-exit.
414 ///
415 /// `ne11` selects the src1 inner-batch shape:
416 /// - `1` for `gate` / `up` (broadcast — all slots read the same
417 /// activation row per token).
418 /// - `top_k` for `down` (per-slot — each pair reads its own row in
419 /// the upstream silu·gate output).
420 ///
421 /// Closes the prefill MoE gap: the per-token gemv loop becomes one
422 /// batched gemm where each expert's slab handles m ≈ batch·top_k /
423 /// num_experts pairs in parallel via simdgroup_half8x8 matmul.
424 #[allow(clippy::too_many_arguments)]
425 fn gemm_quant_moe_id(
426 _ctx: &mut Self::Context,
427 _a: &Self::Buffer,
428 _weight: &Self::QuantStore,
429 _ids: &Self::Buffer,
430 _tpe: &Self::Buffer,
431 _out: &mut Self::Buffer,
432 _ne11: usize,
433 _top_k: usize,
434 _max_per_expert: usize,
435 _batch: usize,
436 ) -> Result<()> {
437 Err(FerrumError::unsupported(
438 "gemm_quant_moe_id not implemented for this backend",
439 ))
440 }
441
442 /// GPU-side MoE router: `[batch, num_experts]` logits → `[batch, top_k]`
443 /// expert IDs (i32) + `[batch, top_k]` combine weights (f32).
444 ///
445 /// Replaces the per-layer `B::sync + B::to_vec(router_logits) + host route()`
446 /// round trip. The output buffers stay device-side for downstream
447 /// `gemv_quant_moe_id` / `gemm_quant_moe_id` consumption — no host
448 /// pipeline drain in the inner loop.
449 ///
450 /// `norm_topk_prob`: if true, divide each row's K weights by their
451 /// sum so they total 1.0 (Qwen3-MoE / Mixtral default).
452 #[allow(clippy::too_many_arguments)]
453 fn route_topk_softmax(
454 _ctx: &mut Self::Context,
455 _logits: &Self::Buffer,
456 _out_ids: &mut Self::Buffer,
457 _out_weights: &mut Self::Buffer,
458 _batch: usize,
459 _num_experts: usize,
460 _top_k: usize,
461 _norm_topk_prob: bool,
462 ) -> Result<()> {
463 Err(FerrumError::unsupported(
464 "route_topk_softmax not implemented for this backend",
465 ))
466 }
467
468 /// GPU-side bucket sort: turn `[batch, top_k]` selected expert IDs
469 /// (from [`Self::route_topk_softmax`]) into `tpe[num_experts]` /
470 /// `ids[num_experts * row_stride]` arrays consumed by the batched
471 /// MoE GEMM, and emit indirect-dispatch args for the consumer GEMM.
472 ///
473 /// The `ids` buffer's row stride is `batch * top_k` (worst case);
474 /// only the first `tpe[e]` entries of each row are populated. The
475 /// consumer GEMM kernel early-exits at `r1 >= tpe[e]`, so the over-
476 /// strided indices cost nothing in the inner loop. The grid size,
477 /// however, would still be worst-case unless we tighten it — this
478 /// is what the `gate_up_args` / `down_args` outputs do: a 12-byte
479 /// `(grid_x, grid_y, grid_z)` u32 triple per shape, ready for
480 /// `dispatch_thread_groups_indirect`. `grid_x` is shared (depends
481 /// only on `max(tpe[e])`); `grid_y` differs because gate/up has
482 /// `M = m_gate_up` while down has `M = m_down`.
483 ///
484 /// All five output buffers are written in one kernel; no host
485 /// roundtrip and no per-layer pipeline drain.
486 #[allow(clippy::too_many_arguments)]
487 fn compute_ids_tpe_gpu(
488 _ctx: &mut Self::Context,
489 _selected_ids: &Self::Buffer,
490 _tpe: &mut Self::Buffer,
491 _ids: &mut Self::Buffer,
492 _gate_up_args: &mut Self::Buffer,
493 _down_args: &mut Self::Buffer,
494 _batch: usize,
495 _num_experts: usize,
496 _top_k: usize,
497 _m_gate_up: usize,
498 _m_down: usize,
499 ) -> Result<()> {
500 Err(FerrumError::unsupported(
501 "compute_ids_tpe_gpu not implemented for this backend",
502 ))
503 }
504
505 /// Indirect-dispatch variant of `gemm_quant_moe_id`.
506 ///
507 /// Identical inputs except the grid is read from `args_buf` (a 12-
508 /// byte u32 triple written by `compute_ids_tpe_gpu`) instead of
509 /// being computed from `max_per_expert`. `max_per_expert` is still
510 /// the kernel parameter used as the row stride for `ids` indexing
511 /// (= `batch * top_k`, worst case); only the dispatched grid
512 /// shrinks to cover `max(tpe[e])` columns.
513 #[allow(clippy::too_many_arguments)]
514 fn gemm_quant_moe_id_indirect(
515 _ctx: &mut Self::Context,
516 _src1: &Self::Buffer,
517 _weights: &Self::QuantStore,
518 _ids: &Self::Buffer,
519 _tpe: &Self::Buffer,
520 _out: &mut Self::Buffer,
521 _args_buf: &Self::Buffer,
522 _ne11: usize,
523 _top_k: usize,
524 _max_per_expert: usize,
525 _batch: usize,
526 ) -> Result<()> {
527 Err(FerrumError::unsupported(
528 "gemm_quant_moe_id_indirect not implemented for this backend",
529 ))
530 }
531
532 /// Stacked SiLU·gate over `[batch * top_k, ffn]` rows (prefill version
533 /// of `silu_mul_stacked`).
534 fn silu_mul_batched(
535 _ctx: &mut Self::Context,
536 _gate: &Self::Buffer,
537 _up: &Self::Buffer,
538 _out: &mut Self::Buffer,
539 _total_pairs: usize,
540 _ffn: usize,
541 ) -> Result<()> {
542 Err(FerrumError::unsupported(
543 "silu_mul_batched not implemented for this backend",
544 ))
545 }
546
547 /// Fused weighted-sum + residual-add: `residual[i] += Σ_k weights[k] · slots[k, i]`.
548 /// Single dispatch replaces the (weighted_sum → moe_out) +
549 /// (add_inplace residual += moe_out) pair on the decode hot path.
550 fn weighted_sum_residual_stacked(
551 _ctx: &mut Self::Context,
552 _slots: &Self::Buffer,
553 _weights: &Self::Buffer,
554 _residual: &mut Self::Buffer,
555 _n_slots: usize,
556 _hidden: usize,
557 ) -> Result<()> {
558 Err(FerrumError::unsupported(
559 "weighted_sum_residual_stacked not implemented for this backend",
560 ))
561 }
562
563 /// Fused weighted-sum-residual + RMSNorm: combines this layer's
564 /// `weighted_sum_residual_stacked` with the next layer's leading
565 /// `rms_norm` into a single dispatch.
566 ///
567 /// Computes
568 /// `residual[i] += Σ_s w[s] · slots[s, i]`
569 /// `normed_out[i] = residual[i] · (1 / sqrt(Σ residual² / hidden + eps)) · next_norm_w[i]`
570 ///
571 /// Caller is responsible for skipping the next layer's standalone
572 /// `rms_norm` — `normed_out` IS that layer's `norm_out` input.
573 /// Default returns Unsupported.
574 #[allow(clippy::too_many_arguments)]
575 fn weighted_sum_residual_norm_stacked(
576 _ctx: &mut Self::Context,
577 _slots: &Self::Buffer,
578 _weights: &Self::Buffer,
579 _residual: &mut Self::Buffer,
580 _next_norm_w: &Self::Buffer,
581 _normed_out: &mut Self::Buffer,
582 _n_slots: usize,
583 _hidden: usize,
584 _eps: f32,
585 ) -> Result<()> {
586 Err(FerrumError::unsupported(
587 "weighted_sum_residual_norm_stacked not implemented for this backend",
588 ))
589 }
590
591 /// Per-batch weighted sum: `out[b, h] = Σ_k weights[b, k] · slots[b, k, h]`.
592 /// Single dispatch covers the whole batch (prefill version of
593 /// `weighted_sum_stacked` which only handled one token).
594 fn weighted_sum_batched(
595 _ctx: &mut Self::Context,
596 _slots: &Self::Buffer,
597 _weights: &Self::Buffer,
598 _out: &mut Self::Buffer,
599 _batch: usize,
600 _top_k: usize,
601 _hidden: usize,
602 ) -> Result<()> {
603 Err(FerrumError::unsupported(
604 "weighted_sum_batched not implemented for this backend",
605 ))
606 }
607
608 /// Offset-aware variant of [`Self::weighted_sum_batched`] —
609 /// `weights` reads from `weights_offset` (in elements, points at
610 /// the start of `[batch, top_k]`), `out` writes from `out_offset`
611 /// (in elements, points at start of `[batch, hidden]`). Used by
612 /// the per-item batched-decode path to skip `copy_slice` round-trips.
613 /// Default falls back to the non-offset variant via two copies.
614 #[allow(clippy::too_many_arguments)]
615 fn weighted_sum_batched_offset(
616 ctx: &mut Self::Context,
617 slots: &Self::Buffer,
618 weights: &Self::Buffer,
619 weights_offset: usize,
620 out: &mut Self::Buffer,
621 out_offset: usize,
622 batch: usize,
623 top_k: usize,
624 hidden: usize,
625 ) -> Result<()> {
626 // Default: stage through scratch — backends override for zero-copy.
627 let _ = (
628 ctx,
629 slots,
630 weights,
631 weights_offset,
632 out,
633 out_offset,
634 batch,
635 top_k,
636 hidden,
637 );
638 Err(FerrumError::unsupported(
639 "weighted_sum_batched_offset not implemented for this backend",
640 ))
641 }
642
643 /// MoE indirect-dispatch GEMV: `out[i, :] = a[i, :] @ dequant(weight[ids[i], :])^T`
644 /// for each `i ∈ [0, n_selected)`. Single backend dispatch covers
645 /// all selected (token, expert) pairs.
646 ///
647 /// `weight` must be a stacked-experts variant produced by
648 /// [`Self::load_quant_experts`]. `ids` is a backend-side buffer of
649 /// `n_selected` i32 expert IDs. `out` is sized `[n_selected, n_rows]`.
650 /// `src1_stride` is the per-slot activation stride in **elements**:
651 /// `0` ⇒ every slot reads the same activation row (broadcast — for
652 /// `gate` / `up` projections); `n_cols` ⇒ each slot reads its own
653 /// activation row (for `down` projections, where each expert
654 /// consumes its own silu(gate)·up output).
655 fn gemv_quant_moe_id(
656 _ctx: &mut Self::Context,
657 _a: &Self::Buffer,
658 _weight: &Self::QuantStore,
659 _ids: &Self::Buffer,
660 _out: &mut Self::Buffer,
661 _n_selected: usize,
662 _src1_stride: usize,
663 ) -> Result<()> {
664 Err(FerrumError::unsupported(
665 "gemv_quant_moe_id not implemented for this backend",
666 ))
667 }
668
669 /// Offset-aware variant of [`Self::gemv_quant_moe_id`] — reads `a`
670 /// from `a_offset` (in elements; meaningful only when src1_stride=0
671 /// for the broadcast case, or as the start of an `n_selected × K`
672 /// strided read when src1_stride≥K), reads `ids` from `ids_offset`
673 /// (the i-th `top_k` block in a stacked-batch `[M, top_k]` ids
674 /// buffer), and writes `out` from offset 0 (output stays per-iter
675 /// scratch). Used by the per-item batched-decode path so the M=N
676 /// concurrent decodes can read directly from the M-batch
677 /// `selected_ids_buf` / `norm_out` without materialising
678 /// per-iteration copies.
679 #[allow(clippy::too_many_arguments)]
680 fn gemv_quant_moe_id_offset(
681 ctx: &mut Self::Context,
682 a: &Self::Buffer,
683 a_offset: usize,
684 weight: &Self::QuantStore,
685 ids: &Self::Buffer,
686 ids_offset: usize,
687 out: &mut Self::Buffer,
688 n_selected: usize,
689 src1_stride: usize,
690 ) -> Result<()> {
691 let _ = (
692 ctx,
693 a,
694 a_offset,
695 weight,
696 ids,
697 ids_offset,
698 out,
699 n_selected,
700 src1_stride,
701 );
702 Err(FerrumError::unsupported(
703 "gemv_quant_moe_id_offset not implemented for this backend",
704 ))
705 }
706
707 /// Allocate a backend buffer of i32-typed values for kernels that
708 /// need integer indices (MoE expert IDs, scatter indices, etc.).
709 ///
710 /// Default impl bit-casts the i32s to f32s and uploads via
711 /// `from_slice` — useful on backends where the buffer type is type-
712 /// erased (CPU's `Vec<f32>`, Metal's untyped MTLBuffer). Backends
713 /// that use a strongly-typed buffer override.
714 fn from_slice_i32(data: &[i32]) -> Self::Buffer {
715 let f: Vec<f32> = data.iter().map(|&i| f32::from_bits(i as u32)).collect();
716 Self::from_slice(&f)
717 }
718
719 /// Overwrite an existing i32 buffer's contents in place. Used on
720 /// the MoE decode hot path: per-layer expert-id updates do an
721 /// in-place memcpy instead of allocating a fresh device buffer
722 /// (48 layers × 128 tokens = 6144 fresh allocations per decode
723 /// run otherwise — allocator pressure dominates the secondary cost).
724 ///
725 /// Default impl falls back to `from_slice_i32` + drop. Backends
726 /// with shared CPU↔GPU memory (Metal `StorageModeShared`, CPU's
727 /// `Vec<f32>`) override with a direct write.
728 fn write_i32_into(buf: &mut Self::Buffer, data: &[i32]) {
729 *buf = Self::from_slice_i32(data);
730 }
731
732 /// Overwrite an existing f32 buffer's contents in place. Counterpart
733 /// to `write_i32_into` for f32 data — used to update the per-token
734 /// MoE combine weights into a pre-allocated scratch buffer instead
735 /// of allocating a fresh `from_slice` buffer 6144 times per decode
736 /// run.
737 fn write_f32_into(buf: &mut Self::Buffer, data: &[f32]) {
738 *buf = Self::from_slice(data);
739 }
740
741 /// Stacked SiLU·gate over `[n_slots, ffn]` rows.
742 ///
743 /// Computes `out[s, i] = silu(gate[s, i]) * up[s, i]` for each slot
744 /// `s`, element `i`. Single dispatch covers all slots — cuts the
745 /// MoE decode silu staging from `top_k * (3 copy_slice + 1 silu)`
746 /// = 32 dispatches per layer to 1.
747 fn silu_mul_stacked(
748 _ctx: &mut Self::Context,
749 _gate: &Self::Buffer,
750 _up: &Self::Buffer,
751 _out: &mut Self::Buffer,
752 _n_slots: usize,
753 _ffn: usize,
754 ) -> Result<()> {
755 Err(FerrumError::unsupported(
756 "silu_mul_stacked not implemented for this backend",
757 ))
758 }
759
760 /// Fused gate+up MoE GEMV with in-register `SiLU(gate) * up`.
761 ///
762 /// Folds the three back-to-back dispatches that the stacked MoE
763 /// FFN decode path emitted per layer:
764 /// 1. `gemv_quant_moe_id` (gate) → gate_out_stacked
765 /// 2. `gemv_quant_moe_id` (up) → up_out_stacked
766 /// 3. `silu_mul_stacked` → silu_stacked
767 /// into a single dispatch that writes `silu_stacked` directly.
768 /// Saves 2 dispatches per layer plus the entire round-trip through
769 /// the gate_out / up_out scratch buffers (≈4× `[top_k, ffn]` of
770 /// intermediate traffic). The activation read is also halved
771 /// because the inner Q4_K reduction reuses one register-file load
772 /// across both weight matrices.
773 ///
774 /// Both `gate_w` and `up_w` must be `Q4KExperts` stacks with
775 /// matching `(num_experts, n_rows, n_cols)` (true for Qwen3-MoE
776 /// GGUFs). Backends without the fused kernel can fall back to the
777 /// 3-dispatch path; callers should gate via
778 /// [`Self::supports_fused_moe_gate_up_silu`] to avoid the
779 /// `Unsupported` String-allocating error round trip on the decode
780 /// hot path.
781 #[allow(clippy::too_many_arguments)]
782 fn gemv_quant_moe_id_gate_up_silu(
783 _ctx: &mut Self::Context,
784 _a: &Self::Buffer,
785 _gate_w: &Self::QuantStore,
786 _up_w: &Self::QuantStore,
787 _ids: &Self::Buffer,
788 _silu_out: &mut Self::Buffer,
789 _n_selected: usize,
790 ) -> Result<()> {
791 Err(FerrumError::unsupported(
792 "gemv_quant_moe_id_gate_up_silu not implemented for this backend",
793 ))
794 }
795
796 /// Capability probe for [`Self::gemv_quant_moe_id_gate_up_silu`].
797 ///
798 /// `true` ⇒ the fused kernel is wired in and the caller should
799 /// prefer it on the MoE decode hot path. `false` ⇒ caller must use
800 /// the 3-dispatch fallback (gate gemv + up gemv + silu_mul_stacked).
801 /// Lets callers branch without paying the cost of an `Err(Unsupported)`
802 /// allocation per (layer, step).
803 fn supports_fused_moe_gate_up_silu() -> bool {
804 false
805 }
806
807 /// Batched MoE indirect-dispatch GEMV — one Metal launch covers
808 /// **all** `m * top_k` (token, expert) pairs at once.
809 ///
810 /// This is the symmetric counterpart of
811 /// [`Self::gemv_quant_moe_id`]: same Q4_K decode loop, same
812 /// per-pair output, but the grid Z-axis spans `m * top_k` instead
813 /// of just `top_k`. Eliminates the engine-level per-token outer
814 /// loop that emits ~16× the dispatches llama.cpp emits at c=16
815 /// (their `kernel_mul_mv_id` already handles the M batch in one
816 /// dispatch).
817 ///
818 /// `a` : activation buffer; pair `p` reads
819 /// `(p / top_k) * src1_outer_stride
820 /// + (p % top_k) * src1_inner_stride` floats.
821 /// gate / up: src1 = `norm_out [m, K]`,
822 /// outer = K, inner = 0
823 /// (slots within a token broadcast).
824 /// down: src1 = `silu_stacked [m, top_k, K]`,
825 /// outer = top_k * K, inner = K.
826 /// `weight` : Q4KExperts stacked weights, common across
827 /// selected experts.
828 /// `ids` : flat `[m * top_k]` selected-expert IDs (i32).
829 /// `out` : `[m * top_k, n_rows]` outputs.
830 /// `m` : token batch size.
831 /// `top_k` : selected experts per token.
832 /// `src1_outer_stride`, `src1_inner_stride`: in **floats**.
833 #[allow(clippy::too_many_arguments)]
834 fn gemv_quant_moe_id_batched(
835 _ctx: &mut Self::Context,
836 _a: &Self::Buffer,
837 _weight: &Self::QuantStore,
838 _ids: &Self::Buffer,
839 _out: &mut Self::Buffer,
840 _m: usize,
841 _top_k: usize,
842 _src1_outer_stride: usize,
843 _src1_inner_stride: usize,
844 ) -> Result<()> {
845 Err(FerrumError::unsupported(
846 "gemv_quant_moe_id_batched not implemented for this backend",
847 ))
848 }
849
850 /// Capability probe for [`Self::gemv_quant_moe_id_batched`].
851 fn supports_batched_moe_gemv() -> bool {
852 false
853 }
854
855 /// Batched fused gate+up MoE GEMV with in-register `SiLU(gate) * up`.
856 ///
857 /// Counterpart of [`Self::gemv_quant_moe_id_gate_up_silu`] for the
858 /// batched-decode path: same in-register fusion, but the grid Z
859 /// dimension covers all `m * top_k` (token, expert) pairs in one
860 /// dispatch. Folds the three batched MoE FFN dispatches per layer
861 /// (gate gemv + up gemv + silu_mul_batched) into one — the missing
862 /// fusion that left the m≥2 batched-decode path slower than the
863 /// per-token loop (which already had this fusion at m=1).
864 ///
865 /// Both `gate_w` and `up_w` must be `Q4KExperts` stacks with
866 /// matching `(num_experts, n_rows, n_cols)`.
867 #[allow(clippy::too_many_arguments)]
868 fn gemv_quant_moe_id_gate_up_silu_batched(
869 _ctx: &mut Self::Context,
870 _a: &Self::Buffer,
871 _gate_w: &Self::QuantStore,
872 _up_w: &Self::QuantStore,
873 _ids: &Self::Buffer,
874 _silu_out: &mut Self::Buffer,
875 _m: usize,
876 _top_k: usize,
877 _src1_outer_stride: usize,
878 _src1_inner_stride: usize,
879 ) -> Result<()> {
880 Err(FerrumError::unsupported(
881 "gemv_quant_moe_id_gate_up_silu_batched not implemented for this backend",
882 ))
883 }
884
885 /// Capability probe for [`Self::gemv_quant_moe_id_gate_up_silu_batched`].
886 fn supports_batched_moe_gate_up_silu() -> bool {
887 false
888 }
889
890 /// Weighted sum across `n_slots` rows of `[hidden]`.
891 ///
892 /// Computes `out[i] = Σ_s weights[s] * slots[s, i]`. Single
893 /// dispatch replaces the per-slot `(copy_slice + scaled_add)`
894 /// loop in the MoE decode path (16 dispatches per layer → 1).
895 fn weighted_sum_stacked(
896 _ctx: &mut Self::Context,
897 _slots: &Self::Buffer,
898 _weights: &Self::Buffer,
899 _out: &mut Self::Buffer,
900 _n_slots: usize,
901 _hidden: usize,
902 ) -> Result<()> {
903 Err(FerrumError::unsupported(
904 "weighted_sum_stacked not implemented for this backend",
905 ))
906 }
907
908 // ── GEMM ────────────────────────────────────────────────────────────
909
    /// General matrix multiply writing into `out`.
    ///
    /// NOTE(review): operand layouts are not stated in this trait. Presumably
    /// `a` is `[m, k]` and `out` is `[m, n]`, with `b` holding the `k × n`
    /// operand, possibly stored transposed as `[n, k]` as linear-layer weights
    /// often are. Confirm against a backend impl.
    fn gemm(
        ctx: &mut Self::Context,
        a: &Self::Buffer,
        b: &Self::Buffer,
        out: &mut Self::Buffer,
        m: usize,
        n: usize,
        k: usize,
    );
919
920 // ── Norms ───────────────────────────────────────────────────────────
921
    /// RMS normalisation (the Llama-family norm; contrast with `layer_norm`).
    ///
    /// NOTE(review): presumably normalises `tokens` rows of `dim` elements
    /// from `x` as `out[r, c] = x[r, c] / sqrt(mean_c(x[r, :]²) + eps) * w[c]`,
    /// with the mean taken over `dim`. Confirm against a backend impl.
    fn rms_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        w: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );
931
932 fn fused_add_rms_norm(
933 ctx: &mut Self::Context,
934 residual: &mut Self::Buffer,
935 x: &Self::Buffer,
936 w: &Self::Buffer,
937 eps: f32,
938 out: &mut Self::Buffer,
939 tokens: usize,
940 dim: usize,
941 );
942
943 // ── Attention ───────────────────────────────────────────────────────
944
945 fn flash_attention(
946 ctx: &mut Self::Context,
947 q: &Self::Buffer,
948 k: &Self::Buffer,
949 v: &Self::Buffer,
950 out: &mut Self::Buffer,
951 batch: usize,
952 q_len: usize,
953 kv_len: usize,
954 pos_offset: usize,
955 cfg: &AttnConfig,
956 );
957
958 /// Multi-Head Latent Attention — DeepSeek V2 / V3's compressed-KV
959 /// attention variant. Extension point only; no backend implements it
960 /// yet. DeepSeek V3 landing in Phase D/E will fill this in.
961 ///
962 /// `q`: full Q `[batch, num_heads, q_len, head_dim]`
963 /// `kv_compressed`: latent KV `[batch, kv_len, kv_lora_rank]`
964 /// `kv_rope`: per-position rope-applied key heads `[batch, kv_len, qk_rope_head_dim]`
965 /// `out`: `[batch, num_heads, q_len, head_dim]`
966 #[allow(clippy::too_many_arguments)]
967 fn mla_attention(
968 _ctx: &mut Self::Context,
969 _q: &Self::Buffer,
970 _kv_compressed: &Self::Buffer,
971 _kv_rope: &Self::Buffer,
972 _out: &mut Self::Buffer,
973 _batch: usize,
974 _q_len: usize,
975 _kv_len: usize,
976 _pos_offset: usize,
977 _cfg: &AttnConfig,
978 _kv_lora_rank: usize,
979 _qk_rope_head_dim: usize,
980 ) -> Result<()> {
981 Err(FerrumError::unsupported(
982 "mla_attention not implemented for this backend; required by \
983 DeepSeek V2/V3 (Phase D/E)",
984 ))
985 }
986
987 // ── Element-wise ────────────────────────────────────────────────────
988 //
989 // Models use `add_inplace` for residual updates and `copy_slice` for the
990 // row-extraction step in prefill. Offset-free copy / non-inplace add are
991 // not needed by the current Model-as-Code path; they can return later if
992 // a model actually requires them.
993
    /// Copy `len` floats from `src[src_offset..]` to `dst[dst_offset..]`.
    ///
    /// Needed for Qwen3Model::prefill to pluck the last token's hidden state
    /// out of `residual[seq_len, h]` without round-tripping through host RAM.
    /// Passing `src_offset = 0` and `dst_offset = 0` gives a plain prefix copy.
    ///
    /// NOTE(review): this doc used to call `Backend::copy` "the offset-free
    /// variant". The section note above says the offset-free copy was dropped
    /// from the trait. Confirm that no such method remains and drop any
    /// remaining references.
    fn copy_slice(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        src_offset: usize,
        dst: &mut Self::Buffer,
        dst_offset: usize,
        len: usize,
    );
1008
1009 // ── Embedding ───────────────────────────────────────────────────────
1010
    /// Token-embedding lookup driven by host-side token `ids`.
    ///
    /// NOTE(review): presumably copies, for each entry of `ids` in order, the
    /// matching `dim`-wide row of `table` (assumed `[vocab, dim]`) into `out`
    /// (assumed `[ids.len(), dim]`). Confirm against a backend impl.
    fn embedding_lookup(
        ctx: &mut Self::Context,
        table: &Self::Buffer,
        ids: &[u32],
        out: &mut Self::Buffer,
        dim: usize,
    );
1018
1019 // ── Transformer-specific fused ops ─────────────────────────────────
1020 // These avoid CPU round-trips for data layout transformations.
1021
1022 /// Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers.
1023 /// Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]
1024 fn split_qkv(
1025 ctx: &mut Self::Context,
1026 qkv: &Self::Buffer,
1027 q: &mut Self::Buffer,
1028 k: &mut Self::Buffer,
1029 v: &mut Self::Buffer,
1030 tokens: usize,
1031 q_dim: usize,
1032 kv_dim: usize,
1033 );
1034
    /// Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im],
    /// then compute SiLU(gate) * up → out [tokens, im].
    ///
    /// NOTE(review): presumably the gate half occupies the first `im` columns
    /// of each row and the up half the last `im`, and `im` is the MLP
    /// intermediate size. Confirm against a backend impl.
    fn fused_silu_mul_split(
        ctx: &mut Self::Context,
        gate_up: &Self::Buffer,
        out: &mut Self::Buffer,
        tokens: usize,
        im: usize,
    );
1044
    /// Fused QK-norm + RoPE + transpose-to-head-major.
    ///
    /// `mode` selects the operation:
    ///   0 = transpose only (typical for V, which needs no norm and no RoPE)
    ///   1 = per-head RMS norm + RoPE + transpose (Q/K with QK-norm, Qwen3)
    ///   2 = RoPE + transpose (Q/K without QK-norm, Llama/Mistral)
    ///
    /// input:  `[tokens, heads, head_dim]` (token-major, output of split_qkv)
    /// output: `[heads, tokens, head_dim]` (head-major, ready for flash_attn / kv_cache_append)
    ///
    /// `pos_offset` is the position of token 0 (decode uses current seq len;
    /// prefill uses 0). Within the batch, positions are taken as `pos_offset + i`.
    ///
    /// NOTE(review): the other parameters are presumably:
    /// - `norm_w`: the per-head RMS-norm weight, with `eps` its epsilon
    ///   (mode 1 only).
    /// - `cos` / `sin`: RoPE tables indexed by absolute position (modes 1
    ///   and 2).
    ///
    /// Confirm which buffers each backend actually reads in mode 0.
    ///
    /// This is the primary attention-input preparation op. Backends that have a
    /// fused kernel (Metal's `qk_norm_rope_transpose_f32`) will be dramatically
    /// faster than composing norm + rope + transpose separately; the CPU
    /// fallback lowers to the individual ops.
    #[allow(clippy::too_many_arguments)]
    fn qk_norm_rope(
        ctx: &mut Self::Context,
        input: &Self::Buffer,
        norm_w: &Self::Buffer,
        cos: &Self::Buffer,
        sin: &Self::Buffer,
        output: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        head_dim: usize,
        pos_offset: usize,
        eps: f32,
        mode: i32,
    );
1077
1078 /// Fused split-QKV + QK-norm + RoPE + head-major transpose.
1079 ///
1080 /// Single-dispatch replacement for the (`split_qkv` → 3× `qk_norm_rope`)
1081 /// chain on the decode-attention prelude. Reads the linear-layer
1082 /// fused-QKV output once and writes head-major Q/K/V directly into
1083 /// attention scratch.
1084 ///
1085 /// `qkv` layout: `[tokens, q_heads*hd + 2*kv_heads*hd]`.
1086 /// `q_out`: `[q_heads, tokens, hd]`. `k_out`/`v_out`: `[kv_heads, tokens, hd]`.
1087 /// `qk_mode`: 1 = norm + RoPE for Q/K (Qwen3 with QK-norm),
1088 /// 2 = RoPE only for Q/K (no QK-norm; Llama-style).
1089 /// V always falls through to transpose-only.
1090 ///
1091 /// Default returns Unsupported. Backends that implement it are
1092 /// expected to be dramatically faster than the four-dispatch chain.
1093 #[allow(clippy::too_many_arguments)]
1094 fn split_qkv_norm_rope(
1095 _ctx: &mut Self::Context,
1096 _qkv: &Self::Buffer,
1097 _q_norm_w: &Self::Buffer,
1098 _k_norm_w: &Self::Buffer,
1099 _cos: &Self::Buffer,
1100 _sin: &Self::Buffer,
1101 _q_out: &mut Self::Buffer,
1102 _k_out: &mut Self::Buffer,
1103 _v_out: &mut Self::Buffer,
1104 _tokens: usize,
1105 _q_heads: usize,
1106 _kv_heads: usize,
1107 _head_dim: usize,
1108 _pos_offset: usize,
1109 _eps: f32,
1110 _qk_mode: i32,
1111 ) -> Result<()> {
1112 Err(FerrumError::unsupported(
1113 "split_qkv_norm_rope not implemented for this backend",
1114 ))
1115 }
1116
1117 /// Variant of [`Backend::split_qkv_norm_rope`] that writes the new
1118 /// K and V directly into pre-allocated head-major KV cache buffers
1119 /// at slot `[kv_heads, cache_len .. cache_len + tokens, hd]`.
1120 /// Eliminates the trailing `kv_cache_append_head_major` dispatch on
1121 /// the decode hot path. Q still lands in per-token head-major
1122 /// scratch (flash-attention reads it as the query).
1123 ///
1124 /// Default returns Unsupported. Backends without the fused kernel
1125 /// can keep using `split_qkv_norm_rope` + `kv_cache_append_head_major`.
1126 #[allow(clippy::too_many_arguments)]
1127 fn split_qkv_norm_rope_into_cache(
1128 _ctx: &mut Self::Context,
1129 _qkv: &Self::Buffer,
1130 _q_norm_w: &Self::Buffer,
1131 _k_norm_w: &Self::Buffer,
1132 _cos: &Self::Buffer,
1133 _sin: &Self::Buffer,
1134 _q_out: &mut Self::Buffer,
1135 _cache_k: &mut Self::Buffer,
1136 _cache_v: &mut Self::Buffer,
1137 _tokens: usize,
1138 _q_heads: usize,
1139 _kv_heads: usize,
1140 _head_dim: usize,
1141 _pos_offset: usize,
1142 _eps: f32,
1143 _qk_mode: i32,
1144 _cache_len: usize,
1145 _cache_capacity: usize,
1146 ) -> Result<()> {
1147 Err(FerrumError::unsupported(
1148 "split_qkv_norm_rope_into_cache not implemented for this backend",
1149 ))
1150 }
1151
1152 /// Paged-KV variant of [`Self::split_qkv_norm_rope_into_cache`].
1153 ///
1154 /// Same fused split + qk-norm + RoPE, but K/V are written into a
1155 /// paged pool `[num_blocks, kv_heads, block_size, head_dim]`
1156 /// indexed via `block_table[logical_block]` → physical_block.
1157 /// Q still goes to head-major scratch.
1158 ///
1159 /// Default returns Unsupported. Backends that lack a paged kernel
1160 /// keep using the contiguous variant.
1161 /// `qkv_byte_offset` / `q_out_byte_offset` let the caller pass a
1162 /// slice of a larger batched buffer (used by the multi-seq paged
1163 /// path in `decode_batch_internal`). For single-seq dispatch they
1164 /// should be 0.
1165 #[allow(clippy::too_many_arguments)]
1166 fn split_qkv_norm_rope_into_paged_cache(
1167 _ctx: &mut Self::Context,
1168 _qkv: &Self::Buffer,
1169 _qkv_byte_offset: u64,
1170 _q_norm_w: &Self::Buffer,
1171 _k_norm_w: &Self::Buffer,
1172 _cos: &Self::Buffer,
1173 _sin: &Self::Buffer,
1174 _q_out: &mut Self::Buffer,
1175 _q_out_byte_offset: u64,
1176 _cache_k: &mut Self::Buffer,
1177 _cache_v: &mut Self::Buffer,
1178 _block_table: &Self::Buffer,
1179 _tokens: usize,
1180 _q_heads: usize,
1181 _kv_heads: usize,
1182 _head_dim: usize,
1183 _pos_offset: usize,
1184 _eps: f32,
1185 _qk_mode: i32,
1186 _cache_len: usize,
1187 _block_size: usize,
1188 _max_num_blocks_per_seq: usize,
1189 ) -> Result<()> {
1190 Err(FerrumError::unsupported(
1191 "split_qkv_norm_rope_into_paged_cache not implemented for this backend",
1192 ))
1193 }
1194
1195 /// Paged-KV variant of [`Self::flash_attention`].
1196 ///
1197 /// Decode (`q_len == 1`):
1198 /// `q`/`out`: `[num_seqs, num_heads, head_dim]` (token-major)
1199 ///
1200 /// Causal prefill (`q_len > 1`, single seq):
1201 /// `q`/`out`: `[num_heads, q_len, head_dim]` (head-major — the
1202 /// layout produced by `split_qkv_norm_rope_into_paged_cache`)
1203 /// The kernel applies a per-q-token causal mask using
1204 /// `context_lens[seq]` as the FINAL kv_len (= `pos_offset + q_len`):
1205 /// token i sees positions `[0, context_lens - q_len + 1 + i)`.
1206 ///
1207 /// Common to both:
1208 /// `k_pool`/`v_pool`: `[num_blocks, num_kv_heads, block_size, head_dim]`
1209 /// `block_tables`: `[num_seqs, max_num_blocks_per_seq]` u32
1210 /// `context_lens`: `[num_seqs]` u32
1211 ///
1212 /// Backends without a paged kernel return Unsupported; callers are
1213 /// expected to fall back to contiguous KV.
1214 #[allow(clippy::too_many_arguments)]
1215 fn paged_decode_attention(
1216 _ctx: &mut Self::Context,
1217 _q: &Self::Buffer,
1218 _k_pool: &Self::Buffer,
1219 _v_pool: &Self::Buffer,
1220 _out: &mut Self::Buffer,
1221 _block_tables: &Self::Buffer,
1222 _context_lens: &Self::Buffer,
1223 _num_seqs: usize,
1224 _num_heads: usize,
1225 _num_kv_heads: usize,
1226 _head_dim: usize,
1227 _block_size: usize,
1228 _max_num_blocks_per_seq: usize,
1229 _q_len: usize,
1230 ) -> Result<()> {
1231 Err(FerrumError::unsupported(
1232 "paged_decode_attention not implemented for this backend",
1233 ))
1234 }
1235
1236 /// Allocate a u32 buffer of length `n` for paged-KV bookkeeping
1237 /// (block tables, context lens). Default uses the existing
1238 /// `from_slice_i32` route then bit-casts; backends with a faster
1239 /// path can override.
1240 fn alloc_u32(n: usize) -> Self::Buffer {
1241 // Reinterpret as i32 — same 4-byte word; the kernel reads
1242 // bytes via `device const uint32_t *`.
1243 Self::from_slice_i32(&vec![0i32; n])
1244 }
1245
1246 /// Write a u32 slice into a buffer previously allocated via
1247 /// [`Self::alloc_u32`]. Used for live block_tables / context_lens
1248 /// updates between decode steps.
1249 ///
1250 /// Default: reads back, mutates host-side, writes back. Metal
1251 /// backend overrides with a direct memcpy on the StorageModeShared
1252 /// buffer.
1253 fn write_u32(_ctx: &mut Self::Context, _dst: &mut Self::Buffer, _data: &[u32]) {
1254 // No-op default — most backends won't exercise this path until
1255 // they implement paged_decode_attention.
1256 }
1257
    /// Append new K/V into a pre-allocated head-major cache buffer.
    ///
    /// `cache_k` / `cache_v`: `[nkv, capacity, hd]` (head-major, pre-allocated)
    /// `new_k_head_major` / `new_v_head_major`: `[nkv, new_tokens, hd]`
    ///   — produced directly by `qk_norm_rope`, no extra transpose needed.
    ///
    /// In-place append at slot `[nkv, cache_len..cache_len+new_tokens, hd]`.
    /// Caller owns `cache_len` bookkeeping.
    ///
    /// NOTE(review): callers presumably must keep
    /// `cache_len + new_tokens <= cache_capacity`, since `cache_capacity` is
    /// the per-head stride of the cache. Confirm whether backends check this.
    #[allow(clippy::too_many_arguments)]
    fn kv_cache_append_head_major(
        ctx: &mut Self::Context,
        cache_k: &mut Self::Buffer,
        cache_v: &mut Self::Buffer,
        cache_len: usize,
        cache_capacity: usize,
        new_k_head_major: &Self::Buffer,
        new_v_head_major: &Self::Buffer,
        new_tokens: usize,
        nkv: usize,
        hd: usize,
    );
1279
    /// Transpose [heads, tokens, dim] → [tokens, heads, dim].
    /// Called after `flash_attention` to restore token-major layout for O-proj.
    ///
    /// This undoes the head-major layout that `qk_norm_rope` produces for
    /// attention. `src` is read as `[heads, tokens, dim]` and `dst` is written
    /// as `[tokens, heads, dim]`.
    fn transpose_head_to_token(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        dst: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        dim: usize,
    );
1290
    /// residual[i] += x[i] (in-place)
    ///
    /// Applied over the first `len` elements. Models use this for
    /// residual-stream updates.
    fn add_inplace(
        ctx: &mut Self::Context,
        residual: &mut Self::Buffer,
        x: &Self::Buffer,
        len: usize,
    );
1298
1299 /// `dst[i] += scale * src[i]` — scalar-broadcast scaled add, in place.
1300 ///
1301 /// MoE per-token combine writes `out[b] += weight_k * expert_k(x[b])`
1302 /// for each top-K expert; this primitive is the per-call accumulate.
1303 /// Backends without a dedicated kernel can fall back to the default
1304 /// implementation, which round-trips through host memory — correct,
1305 /// but slow on a hot path. Override on any backend you actually
1306 /// dispatch MoE on.
1307 fn scaled_add_inplace(
1308 _ctx: &mut Self::Context,
1309 dst: &mut Self::Buffer,
1310 src: &Self::Buffer,
1311 scale: f32,
1312 len: usize,
1313 ) {
1314 let mut dst_v = Self::to_vec(dst, len);
1315 let src_v = Self::to_vec(src, len);
1316 for i in 0..len {
1317 dst_v[i] += scale * src_v[i];
1318 }
1319 // Move the new buffer into the slot pointed to by `dst`. Safe
1320 // because `Self::Buffer: Send + Sync` and the old buffer is
1321 // dropped here when overwritten.
1322 *dst = Self::from_slice(&dst_v);
1323 }
1324
    /// Broadcast bias add: `data[r, c] += bias[c]` for every row.
    /// Required by Bert / Clip / Whisper whose linear projections carry a bias.
    ///
    /// `data` is `[rows, cols]`, updated in place. `bias` supplies one value
    /// per column (`cols` elements).
    fn add_bias(
        ctx: &mut Self::Context,
        data: &mut Self::Buffer,
        bias: &Self::Buffer,
        rows: usize,
        cols: usize,
    );
1334
    /// Full LayerNorm (mean + variance normalisation + affine), distinct from
    /// the `rms_norm` used by Llama-family decoders.
    /// `out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]`
    /// Where `mean` and `var` are reduced over the last dim (cols).
    ///
    /// Rows correspond to `tokens` and columns to `dim`. `gamma` and `beta`
    /// hold one value per column.
    #[allow(clippy::too_many_arguments)]
    fn layer_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        gamma: &Self::Buffer,
        beta: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );
1350
    /// Element-wise GELU activation (erf-based, matches PyTorch default).
    ///
    /// Reads `len` elements of `x` and writes `len` results to `out`.
    fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize);
1353
1354 // ── Buffer management (context-free) ────────────────────────────────
1355
    /// Allocate a buffer of `len` elements.
    /// NOTE(review): initial contents are not specified by this trait. Do not
    /// assume zero-initialisation without checking the backend.
    fn alloc(len: usize) -> Self::Buffer;
    /// Read `len` elements of `buf` back into host memory as `f32`.
    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;
    /// Create a new buffer initialised from the host `f32` slice `data`.
    fn from_slice(data: &[f32]) -> Self::Buffer;
1359
1360 /// Load a weight tensor straight from its on-disk byte representation,
1361 /// letting the backend pick its preferred storage dtype.
1362 ///
1363 /// Default impl upcasts bf16/f16 to f32 via an intermediate Vec, matching
1364 /// pre-existing loader behaviour. Backends override this to go straight
1365 /// from raw bytes into a native half-precision buffer (e.g. Metal with
1366 /// `FERRUM_METAL_DTYPE=f16`), avoiding the transient 2× RAM spike.
1367 fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer {
1368 let data = src_dtype.to_f32_vec(raw);
1369 Self::from_slice(&data)
1370 }
1371
1372 // (The Phase A3 unified `gemm_quant(QuantWeights, QuantKind)` stub
1373 // that used to live here is superseded by the `load_quant` /
1374 // `gemm_quant(QuantStore)` pair earlier in this trait — same idea,
1375 // but the store hides the per-kind buffer layout so callers don't
1376 // have to construct a per-kind `QuantWeights<'_, Self>` packet.)
1377
1378 // ── TP collective ops (Phase A3 stubs) ──────────────────────────────
1379 //
1380 // Default impl is single-rank no-op: `world_size = 1`, `rank = 0`, and
1381 // the collective ops are identity. Multi-GPU backends (future
1382 // CudaBackend + NCCL) override these. Model code can call
1383 // `B::all_reduce_sum(...)` unconditionally; single-GPU paths pay zero.
1384
1385 fn world_size(_ctx: &Self::Context) -> usize {
1386 1
1387 }
1388 fn rank(_ctx: &Self::Context) -> usize {
1389 0
1390 }
1391 fn all_reduce(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _op: ReduceOp) {
1392 // single-rank: no-op
1393 }
1394 fn all_gather(
1395 _ctx: &mut Self::Context,
1396 _local: &Self::Buffer,
1397 _global: &mut Self::Buffer,
1398 _local_len: usize,
1399 ) {
1400 // single-rank: no-op (caller is expected to handle the degenerate
1401 // case or arrange for `local == global`)
1402 }
1403 fn broadcast(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _src_rank: usize) {
1404 // single-rank: no-op
1405 }
1406}