// ferrum_kernels/backend/traits.rs

//! Core Backend trait — the single abstraction over CUDA / Metal / CPU.

3use ferrum_types::{FerrumError, Result};
4use half::{bf16, f16};
5
6/// Source dtype for a weight tensor read straight from safetensors mmap.
7///
8/// Passed to `Backend::from_weight_bytes` so each backend can choose whether
9/// to upcast to its compute dtype or store as-is.
10#[derive(Clone, Copy, Debug, PartialEq, Eq)]
11pub enum SrcDtype {
12    F32,
13    F16,
14    BF16,
15}
16
17impl SrcDtype {
18    /// Number of bytes per element in the raw on-disk representation.
19    pub const fn bytes_per_elem(self) -> usize {
20        match self {
21            SrcDtype::F32 => 4,
22            SrcDtype::F16 | SrcDtype::BF16 => 2,
23        }
24    }
25
26    /// Materialise the raw byte slice into a `Vec<f32>`. Used by the default
27    /// `Backend::from_weight_bytes` impl; fp16-preferring backends bypass it.
28    pub fn to_f32_vec(self, raw: &[u8]) -> Vec<f32> {
29        match self {
30            SrcDtype::F32 => {
31                debug_assert_eq!(raw.len() % 4, 0);
32                let n = raw.len() / 4;
33                let mut out = vec![0f32; n];
34                for i in 0..n {
35                    let b = [raw[i * 4], raw[i * 4 + 1], raw[i * 4 + 2], raw[i * 4 + 3]];
36                    out[i] = f32::from_le_bytes(b);
37                }
38                out
39            }
40            SrcDtype::F16 => {
41                debug_assert_eq!(raw.len() % 2, 0);
42                let n = raw.len() / 2;
43                let mut out = vec![0f32; n];
44                for i in 0..n {
45                    out[i] = f16::from_le_bytes([raw[i * 2], raw[i * 2 + 1]]).to_f32();
46                }
47                out
48            }
49            SrcDtype::BF16 => {
50                debug_assert_eq!(raw.len() % 2, 0);
51                let n = raw.len() / 2;
52                let mut out = vec![0f32; n];
53                for i in 0..n {
54                    out[i] = bf16::from_le_bytes([raw[i * 2], raw[i * 2 + 1]]).to_f32();
55                }
56                out
57            }
58        }
59    }
60}

/// Quantization scheme descriptor (GPTQ / AWQ / GGUF).
///
/// Distinct schemes need distinct kernels. Carrying the scheme as a value
/// keeps the Backend trait from growing one method per quantization type.
///
/// NOTE(review): the `Backend` methods declared in this file take a
/// [`GgufQuantType`] (`load_quant*`) or no discriminator at all
/// (`gemm_quant`), not a `QuantKind` — confirm which callers still use it.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum QuantKind {
    /// GPTQ: group-wise int4/int8 with scales + zeros (asymmetric) + optional g_idx.
    Gptq {
        bits: u32,
        group_size: usize,
        desc_act: bool,
    },
    /// AWQ: activation-aware int4 with scales + zeros, different packing from GPTQ.
    Awq { bits: u32, group_size: usize },
    /// GGUF: one of k-quants / legacy quants, fully specified by the inner type.
    Gguf { quant_type: GgufQuantType },
}

/// GGUF quantization sub-type (expand as kernels are added).
///
/// `PartialEq` / `Eq` / `Hash` let callers compare or key on kinds directly,
/// e.g. to detect the heterogeneous parts handled by
/// `Backend::load_quant_fused`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GgufQuantType {
    Q4_0,
    Q4_1,
    Q4K,
    Q5K,
    Q6K,
    Q8_0,
}
90
91/// Packed quantized weight buffers passed to `Backend::gemm_quant`.
92///
93/// Not every field is used by every `QuantKind` — e.g. GGUF packs scales
94/// inside `qweight`, so `scales` / `zeros` may be dummies. The Backend
95/// implementation is expected to validate the shape for the kind it handles.
96pub struct QuantWeights<'a, B: Backend> {
97    pub qweight: &'a B::Buffer,
98    pub scales: Option<&'a B::Buffer>,
99    pub zeros: Option<&'a B::Buffer>,
100    pub g_idx: Option<&'a B::Buffer>,
101}

/// Collective-op reduction kind for TP all_reduce.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Element-wise sum of the participants' values.
    Sum,
    /// Element-wise maximum of the participants' values.
    Max,
    /// Element-wise minimum of the participants' values.
    Min,
}

/// Parameters describing a single attention dispatch.
#[derive(Clone, Debug)]
pub struct AttnConfig {
    /// Number of attention (query) heads.
    pub num_heads: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Feature width of one head.
    pub head_dim: usize,
    /// Whether causal masking is enabled.
    pub causal: bool,
    /// Attention score scale factor.
    pub scale: f32,
    /// Row stride between consecutive head blocks in the KV buffer.
    ///
    /// `0` selects the legacy contiguous layout, where the stride is `kv_len`.
    /// Set it to `cache_capacity` when attending over a pre-allocated cache in
    /// which only the first `kv_len` of `cache_capacity` slots are valid.
    pub kv_seq_stride: usize,
    /// Sliding-window attention size (Mistral v0.1, Gemma).
    ///
    /// `0` turns the window off (plain full causal attention). A value `w > 0`
    /// limits each query position to the previous `w` KV positions; the upper
    /// end is still bounded by `causal` and `pos_offset + qi + 1`.
    pub sliding_window: usize,
}

impl Default for AttnConfig {
    /// Zero-sized, non-causal, unit-scale configuration with a contiguous KV
    /// layout and no sliding window.
    fn default() -> Self {
        let (num_heads, num_kv_heads, head_dim) = (0, 0, 0);
        Self {
            scale: 1.0,
            causal: false,
            num_heads,
            num_kv_heads,
            head_dim,
            sliding_window: 0,
            kv_seq_stride: 0,
        }
    }
}

// Note: `TransformerConfig` / `AttnType` / `MlpType` / `RopeConfig` used to
// live here when `ModelRunner` needed a generic model config. They're now
// per-model (e.g. `Qwen3Config` in `ferrum-models::models::qwen3`) so each
// model can carry exactly the architecture parameters it cares about.
// The Backend trait stays model-agnostic.

/// Per-layer KV cache. Each model owns its own `Vec<KvCache<B>>` per sequence.
///
/// Two layouts are supported, selected at allocation time:
/// 1. **Contiguous** (default): `k`/`v` are `[num_kv_heads, capacity, head_dim]`
///    f32 buffers. `block_size == 0` and `block_table` / `context_lens` are
///    `None`. Original ferrum layout — used when `FERRUM_METAL_PAGED_KV` is
///    unset.
/// 2. **Paged** (vLLM-style): `k`/`v` are `[num_blocks, num_kv_heads,
///    block_size, head_dim]` block pools. `block_size > 0` and
///    `block_table` (`u32[max_num_blocks_per_seq]`) + `context_lens`
///    (`u32[1]` single-seq for now) are populated. Multi-seq sharing
///    is a Phase 4 concern; today every paged cache_id has its own
///    pool but the kernel-level indirection works.
pub struct KvCache<B: Backend> {
    /// Key storage; its layout is selected by `block_size` (see type docs).
    pub k: B::Buffer,
    /// Value storage; same layout as `k`.
    pub v: B::Buffer,
    /// Number of KV positions currently populated.
    /// NOTE(review): presumably always `<= capacity`; not enforced here.
    pub len: usize,
    /// Position capacity the cache was allocated for (the middle dimension
    /// of the contiguous layout).
    pub capacity: usize,
    /// The `num_kv_heads` dimension of both layouts.
    pub num_kv_heads: usize,
    /// The `head_dim` (innermost) dimension of both layouts.
    pub head_dim: usize,
    /// Paged: KV positions per physical block. `0` ⇒ contiguous layout.
    pub block_size: usize,
    /// Paged: `[max_num_blocks_per_seq]` u32 — logical → physical block.
    pub block_table: Option<B::Buffer>,
    /// Paged: `[1]` u32 — current context length for the kernel to read.
    pub context_lens: Option<B::Buffer>,
    /// Paged: host-side mirror of the physical block indices owned by
    /// this cache. Lets the model's release path return blocks to the
    /// shared allocator without reading them back from device.
    pub paged_block_indices: Vec<u32>,
}

/// The core abstraction over CUDA / Metal / CPU.
///
/// Key design: operations take a `&mut Self::Context` which accumulates work.
///   - **CPU**: Context is `()` — ops execute immediately.
///   - **Metal**: Context is a `CommandBuffer` — ops encode into it, flushed on `sync()`.
///   - **CUDA**: Context is a `CudaStream` — ops launch on the stream, synced on `sync()`.
///
/// `layer_forward` passes the context through all ops in a layer.
/// `ModelRunner` calls `sync()` only when it needs results (e.g., reading logits).
192pub trait Backend: Send + Sync + Sized + 'static {
193    type Buffer: Send + Sync;
194
195    /// Execution context that accumulates GPU work.
196    ///   - CPU: `()` (no-op, ops execute inline)
197    ///   - Metal: wraps a CommandBuffer
198    ///   - CUDA: wraps a CudaStream
199    type Context;
200
201    /// Opaque per-backend GPTQ weight representation.
202    ///   - CPU: dequantized f32 weights (run as regular GEMM)
203    ///   - Metal: `()` — unsupported; `gemm_gptq` errors
204    ///   - CUDA: `MarlinWeight` — pre-repacked tiles + permuted scales
205    ///
206    /// Each backend repacks raw GPTQ tensors (qweight/scales/qzeros, all
207    /// i32/f16) into its preferred format at model load time, so inference
208    /// doesn't pay the repack cost per forward pass.
209    type GptqStore: Send + Sync;
210
211    /// Single backend-specific store for **all GGUF k-quant flavours**
212    /// (Q4_K_M today; Q5_K_M / Q6_K / Q8_0 etc. become enum variants
213    /// without changing the trait shape).
214    ///
215    /// Each backend's `QuantStore` is typically an enum dispatching on
216    /// the on-disk quant type — the public API (`load_quant`,
217    /// `gemm_quant`) takes a [`QuantKind`] discriminator so callers
218    /// don't see the variant boilerplate.
219    ///
220    /// **GPTQ stays on the older [`Self::GptqStore`] path** because its
221    /// load inputs are split arrays (qweight / scales / qzeros), not
222    /// the contiguous byte payload GGUF quants ship as. A future PR can
223    /// fold GPTQ into `QuantStore` once an input-shape unification is
224    /// agreed.
225    type QuantStore: Send + Sync;
226
227    /// Create a new execution context (begin accumulating work).
228    fn new_context() -> Self::Context;
229
230    /// Flush accumulated work and wait for completion.
231    /// CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.
232    fn sync(ctx: &mut Self::Context);
233
234    // ── Graph capture / replay (CUDA only) ──────────────────────────────
235    //
236    // Decode-loop optimization: eliminate per-kernel launch overhead by
237    // capturing the full step as a CUDA graph and replaying. CPU/Metal
238    // have no equivalent — defaults return `unsupported`.
239    //
240    // Flow per decode step:
241    //   1. Caller: `set_decode_state(ctx, token, step)` — memcpy to dev bufs
242    //   2. Try `replay_last_graph(ctx)`:
243    //        - Ok(true):  graph replayed, skip eager forward
244    //        - Ok(false): no captured graph yet, run eager
245    //        - Err(_):    not supported, run eager
246    //   3. If running eager and in capture window:
247    //      - `set_dev_state_mode(ctx, true)` so kernels use _dyn variants
248    //      - `begin_graph_capture(ctx)`
249    //      - run forward
250    //      - `end_graph_capture(ctx)` — stores graph on ctx internally
251    //      - `set_dev_state_mode(ctx, false)` — restore scalar kernels
252
253    /// Update per-step dynamic state (token id, step/pos). Fast (3x memcpy).
254    fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32) {}
255
256    /// Toggle between scalar-arg kernels (normal) and `_dyn` kernels that
257    /// read their dynamic scalar args from device memory (graph-friendly).
258    fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool) {}
259
260    /// Begin stream capture. Subsequent kernel launches are recorded into
261    /// a pending graph instead of executing eagerly.
262    fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
263        Err(FerrumError::unsupported("graph capture not supported"))
264    }
265
266    /// End stream capture and install the captured graph as this context's
267    /// "last graph" for future `replay_last_graph` calls.
268    fn end_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
269        Err(FerrumError::unsupported("graph capture not supported"))
270    }
271
272    /// Replay the last captured graph. Returns `Ok(false)` if no graph
273    /// is cached; caller should run eager.
274    fn replay_last_graph(_ctx: &mut Self::Context) -> Result<bool> {
275        Ok(false)
276    }
277
278    /// Drop the cached decode graph — required when the KV cache it
279    /// was captured against is about to be freed (e.g. request release),
280    /// since the graph holds raw device pointers into that cache.
281    fn reset_graph(_ctx: &mut Self::Context) {}
282
283    // ── GPTQ (INT4 quantization) ────────────────────────────────────────
284    //
285    // Two-step: load (once per weight) → gemm (per forward). The store
286    // holds whatever backend-specific format is fastest; caller code
287    // (GptqLinear) is dtype-agnostic.
288
289    /// Repack raw GPTQ tensors into the backend's preferred format.
290    /// Called once per layer at model load time.
291    ///
292    /// Inputs are host-side slices (CPU memory) — the loader reads from
293    /// safetensors and hands them off; each backend uploads + repacks
294    /// per its own strategy. `bits` is typically 4; `group_size` is
295    /// typically 128.
296    #[allow(clippy::too_many_arguments)]
297    fn load_gptq(
298        _qweight: &[i32],
299        _scales: &[f32],
300        _qzeros: &[i32],
301        _g_idx: Option<&[i32]>,
302        _bits: u32,
303        _group_size: usize,
304        _k: usize,
305        _n: usize,
306    ) -> Result<Self::GptqStore> {
307        Err(FerrumError::unsupported(
308            "load_gptq not implemented for this backend",
309        ))
310    }
311
312    /// GEMM with pre-loaded GPTQ weights.
313    /// `out[m, n] = a[m, k] @ dequant(weight)^T`
314    fn gemm_gptq(
315        _ctx: &mut Self::Context,
316        _a: &Self::Buffer,
317        _weight: &Self::GptqStore,
318        _out: &mut Self::Buffer,
319        _m: usize,
320    ) -> Result<()> {
321        Err(FerrumError::unsupported(
322            "gemm_gptq not implemented for this backend",
323        ))
324    }
325
326    /// Load GGUF k-quant weights into the backend's preferred format.
327    ///
328    /// `kind` discriminates Q4_K / Q5_K / Q6_K / Q8_0 etc. The CPU path
329    /// typically eager-dequants to fp32; the Metal path keeps raw block
330    /// bytes in MTLBuffer and dequants per matmul into a transient fp16
331    /// buffer. Adding a new k-quant flavour is a matched pair of
332    /// `QuantStore` variant + `match` arm, not a new trait method.
333    ///
334    /// `bytes`: contiguous on-disk payload — `n_blocks × block_size`.
335    /// `n_rows`: out_features. `n_cols`: in_features. The block count
336    /// is derived per-kind from these dims.
337    fn load_quant(
338        _kind: GgufQuantType,
339        _bytes: &[u8],
340        _n_rows: usize,
341        _n_cols: usize,
342    ) -> Result<Self::QuantStore> {
343        Err(FerrumError::unsupported(
344            "load_quant not implemented for this backend",
345        ))
346    }
347
348    /// Build a fused `QuantStore` from multiple `(kind, bytes, n_rows)`
349    /// parts that share `n_cols`. Used by `GgufLoader::load_fused` when
350    /// parts have heterogeneous quant kinds (e.g. Qwen3 qkv_proj where
351    /// q+k are Q4_K but v is Q6_K) — byte-concatenation isn't possible,
352    /// so each part stays as its own QuantStore and the gemm dispatches
353    /// one matvec per part with output offsets.
354    ///
355    /// Default: not supported. Backends that have a `Fused`-like variant
356    /// override.
357    fn load_quant_fused(
358        _parts: &[(GgufQuantType, &[u8], usize)],
359        _n_cols: usize,
360    ) -> Result<Self::QuantStore> {
361        Err(FerrumError::unsupported(
362            "load_quant_fused not implemented for this backend",
363        ))
364    }
365
366    /// GEMM with k-quant weights. Mirrors `gemm` / `gemm_gptq` shape:
367    /// `out[m, n] = a[m, k] @ dequant(weight)^T`. The dispatch on the
368    /// quant flavour happens inside the backend's `QuantStore` enum.
369    fn gemm_quant(
370        _ctx: &mut Self::Context,
371        _a: &Self::Buffer,
372        _weight: &Self::QuantStore,
373        _out: &mut Self::Buffer,
374        _m: usize,
375    ) -> Result<()> {
376        Err(FerrumError::unsupported(
377            "gemm_quant not implemented for this backend",
378        ))
379    }
380
381    /// Build a stacked-experts `QuantStore` from a contiguous 3-D weight
382    /// payload `[num_experts, n_rows, n_cols/256]` super-blocks.
383    /// Used for the MoE indirect-dispatch fast path; backends without
384    /// such a kernel return `Err(unsupported)` and the model code falls
385    /// back to the per-expert loop.
386    ///
387    /// Default: not supported. Override on backends with batched MoE
388    /// kernels (e.g. Metal `gemv_q*kw_moe_id_f32`).
389    fn load_quant_experts(
390        _kind: GgufQuantType,
391        _bytes: &[u8],
392        _num_experts: usize,
393        _n_rows: usize,
394        _n_cols: usize,
395    ) -> Result<Self::QuantStore> {
396        Err(FerrumError::unsupported(
397            "load_quant_experts not implemented for this backend",
398        ))
399    }
400
401    /// MoE 2-D indirect-dispatch GEMM (prefill m > 1).
402    ///
403    /// Computes per (token, expert_slot) pair, batched across all
404    /// experts in one launch:
405    ///
406    ///   `out[token, slot, :] = a[token, slot_or_0, :] @ dequant(weight[expert(token, slot), :])^T`
407    ///
408    /// `ids[expert][slot] = pair_id` encodes `(token_idx, slot_within_token)`
409    /// so the kernel reads activations indirectly (src1 row for the
410    /// pair) and writes outputs directly to the natural
411    /// `[batch, top_k, M]` layout. `tpe[expert]` gives the count of
412    /// pairs assigned to each expert — threadgroups past `tpe[e]`
413    /// early-exit.
414    ///
415    /// `ne11` selects the src1 inner-batch shape:
416    /// - `1` for `gate` / `up` (broadcast — all slots read the same
417    ///   activation row per token).
418    /// - `top_k` for `down` (per-slot — each pair reads its own row in
419    ///   the upstream silu·gate output).
420    ///
421    /// Closes the prefill MoE gap: the per-token gemv loop becomes one
422    /// batched gemm where each expert's slab handles m ≈ batch·top_k /
423    /// num_experts pairs in parallel via simdgroup_half8x8 matmul.
424    #[allow(clippy::too_many_arguments)]
425    fn gemm_quant_moe_id(
426        _ctx: &mut Self::Context,
427        _a: &Self::Buffer,
428        _weight: &Self::QuantStore,
429        _ids: &Self::Buffer,
430        _tpe: &Self::Buffer,
431        _out: &mut Self::Buffer,
432        _ne11: usize,
433        _top_k: usize,
434        _max_per_expert: usize,
435        _batch: usize,
436    ) -> Result<()> {
437        Err(FerrumError::unsupported(
438            "gemm_quant_moe_id not implemented for this backend",
439        ))
440    }
441
442    /// GPU-side MoE router: `[batch, num_experts]` logits → `[batch, top_k]`
443    /// expert IDs (i32) + `[batch, top_k]` combine weights (f32).
444    ///
445    /// Replaces the per-layer `B::sync + B::to_vec(router_logits) + host route()`
446    /// round trip. The output buffers stay device-side for downstream
447    /// `gemv_quant_moe_id` / `gemm_quant_moe_id` consumption — no host
448    /// pipeline drain in the inner loop.
449    ///
450    /// `norm_topk_prob`: if true, divide each row's K weights by their
451    /// sum so they total 1.0 (Qwen3-MoE / Mixtral default).
452    #[allow(clippy::too_many_arguments)]
453    fn route_topk_softmax(
454        _ctx: &mut Self::Context,
455        _logits: &Self::Buffer,
456        _out_ids: &mut Self::Buffer,
457        _out_weights: &mut Self::Buffer,
458        _batch: usize,
459        _num_experts: usize,
460        _top_k: usize,
461        _norm_topk_prob: bool,
462    ) -> Result<()> {
463        Err(FerrumError::unsupported(
464            "route_topk_softmax not implemented for this backend",
465        ))
466    }
467
468    /// GPU-side bucket sort: turn `[batch, top_k]` selected expert IDs
469    /// (from [`Self::route_topk_softmax`]) into `tpe[num_experts]` /
470    /// `ids[num_experts * row_stride]` arrays consumed by the batched
471    /// MoE GEMM, and emit indirect-dispatch args for the consumer GEMM.
472    ///
473    /// The `ids` buffer's row stride is `batch * top_k` (worst case);
474    /// only the first `tpe[e]` entries of each row are populated. The
475    /// consumer GEMM kernel early-exits at `r1 >= tpe[e]`, so the over-
476    /// strided indices cost nothing in the inner loop. The grid size,
477    /// however, would still be worst-case unless we tighten it — this
478    /// is what the `gate_up_args` / `down_args` outputs do: a 12-byte
479    /// `(grid_x, grid_y, grid_z)` u32 triple per shape, ready for
480    /// `dispatch_thread_groups_indirect`. `grid_x` is shared (depends
481    /// only on `max(tpe[e])`); `grid_y` differs because gate/up has
482    /// `M = m_gate_up` while down has `M = m_down`.
483    ///
484    /// All five output buffers are written in one kernel; no host
485    /// roundtrip and no per-layer pipeline drain.
486    #[allow(clippy::too_many_arguments)]
487    fn compute_ids_tpe_gpu(
488        _ctx: &mut Self::Context,
489        _selected_ids: &Self::Buffer,
490        _tpe: &mut Self::Buffer,
491        _ids: &mut Self::Buffer,
492        _gate_up_args: &mut Self::Buffer,
493        _down_args: &mut Self::Buffer,
494        _batch: usize,
495        _num_experts: usize,
496        _top_k: usize,
497        _m_gate_up: usize,
498        _m_down: usize,
499    ) -> Result<()> {
500        Err(FerrumError::unsupported(
501            "compute_ids_tpe_gpu not implemented for this backend",
502        ))
503    }
504
505    /// Indirect-dispatch variant of `gemm_quant_moe_id`.
506    ///
507    /// Identical inputs except the grid is read from `args_buf` (a 12-
508    /// byte u32 triple written by `compute_ids_tpe_gpu`) instead of
509    /// being computed from `max_per_expert`. `max_per_expert` is still
510    /// the kernel parameter used as the row stride for `ids` indexing
511    /// (= `batch * top_k`, worst case); only the dispatched grid
512    /// shrinks to cover `max(tpe[e])` columns.
513    #[allow(clippy::too_many_arguments)]
514    fn gemm_quant_moe_id_indirect(
515        _ctx: &mut Self::Context,
516        _src1: &Self::Buffer,
517        _weights: &Self::QuantStore,
518        _ids: &Self::Buffer,
519        _tpe: &Self::Buffer,
520        _out: &mut Self::Buffer,
521        _args_buf: &Self::Buffer,
522        _ne11: usize,
523        _top_k: usize,
524        _max_per_expert: usize,
525        _batch: usize,
526    ) -> Result<()> {
527        Err(FerrumError::unsupported(
528            "gemm_quant_moe_id_indirect not implemented for this backend",
529        ))
530    }
531
532    /// Stacked SiLU·gate over `[batch * top_k, ffn]` rows (prefill version
533    /// of `silu_mul_stacked`).
534    fn silu_mul_batched(
535        _ctx: &mut Self::Context,
536        _gate: &Self::Buffer,
537        _up: &Self::Buffer,
538        _out: &mut Self::Buffer,
539        _total_pairs: usize,
540        _ffn: usize,
541    ) -> Result<()> {
542        Err(FerrumError::unsupported(
543            "silu_mul_batched not implemented for this backend",
544        ))
545    }
546
547    /// Fused weighted-sum + residual-add: `residual[i] += Σ_k weights[k] · slots[k, i]`.
548    /// Single dispatch replaces the (weighted_sum → moe_out) +
549    /// (add_inplace residual += moe_out) pair on the decode hot path.
550    fn weighted_sum_residual_stacked(
551        _ctx: &mut Self::Context,
552        _slots: &Self::Buffer,
553        _weights: &Self::Buffer,
554        _residual: &mut Self::Buffer,
555        _n_slots: usize,
556        _hidden: usize,
557    ) -> Result<()> {
558        Err(FerrumError::unsupported(
559            "weighted_sum_residual_stacked not implemented for this backend",
560        ))
561    }
562
563    /// Fused weighted-sum-residual + RMSNorm: combines this layer's
564    /// `weighted_sum_residual_stacked` with the next layer's leading
565    /// `rms_norm` into a single dispatch.
566    ///
567    /// Computes
568    ///   `residual[i] += Σ_s w[s] · slots[s, i]`
569    ///   `normed_out[i] = residual[i] · (1 / sqrt(Σ residual² / hidden + eps)) · next_norm_w[i]`
570    ///
571    /// Caller is responsible for skipping the next layer's standalone
572    /// `rms_norm` — `normed_out` IS that layer's `norm_out` input.
573    /// Default returns Unsupported.
574    #[allow(clippy::too_many_arguments)]
575    fn weighted_sum_residual_norm_stacked(
576        _ctx: &mut Self::Context,
577        _slots: &Self::Buffer,
578        _weights: &Self::Buffer,
579        _residual: &mut Self::Buffer,
580        _next_norm_w: &Self::Buffer,
581        _normed_out: &mut Self::Buffer,
582        _n_slots: usize,
583        _hidden: usize,
584        _eps: f32,
585    ) -> Result<()> {
586        Err(FerrumError::unsupported(
587            "weighted_sum_residual_norm_stacked not implemented for this backend",
588        ))
589    }
590
591    /// Per-batch weighted sum: `out[b, h] = Σ_k weights[b, k] · slots[b, k, h]`.
592    /// Single dispatch covers the whole batch (prefill version of
593    /// `weighted_sum_stacked` which only handled one token).
594    fn weighted_sum_batched(
595        _ctx: &mut Self::Context,
596        _slots: &Self::Buffer,
597        _weights: &Self::Buffer,
598        _out: &mut Self::Buffer,
599        _batch: usize,
600        _top_k: usize,
601        _hidden: usize,
602    ) -> Result<()> {
603        Err(FerrumError::unsupported(
604            "weighted_sum_batched not implemented for this backend",
605        ))
606    }
607
608    /// Offset-aware variant of [`Self::weighted_sum_batched`] —
609    /// `weights` reads from `weights_offset` (in elements, points at
610    /// the start of `[batch, top_k]`), `out` writes from `out_offset`
611    /// (in elements, points at start of `[batch, hidden]`). Used by
612    /// the per-item batched-decode path to skip `copy_slice` round-trips.
613    /// Default falls back to the non-offset variant via two copies.
614    #[allow(clippy::too_many_arguments)]
615    fn weighted_sum_batched_offset(
616        ctx: &mut Self::Context,
617        slots: &Self::Buffer,
618        weights: &Self::Buffer,
619        weights_offset: usize,
620        out: &mut Self::Buffer,
621        out_offset: usize,
622        batch: usize,
623        top_k: usize,
624        hidden: usize,
625    ) -> Result<()> {
626        // Default: stage through scratch — backends override for zero-copy.
627        let _ = (
628            ctx,
629            slots,
630            weights,
631            weights_offset,
632            out,
633            out_offset,
634            batch,
635            top_k,
636            hidden,
637        );
638        Err(FerrumError::unsupported(
639            "weighted_sum_batched_offset not implemented for this backend",
640        ))
641    }
642
643    /// MoE indirect-dispatch GEMV: `out[i, :] = a[i, :] @ dequant(weight[ids[i], :])^T`
644    /// for each `i ∈ [0, n_selected)`. Single backend dispatch covers
645    /// all selected (token, expert) pairs.
646    ///
647    /// `weight` must be a stacked-experts variant produced by
648    /// [`Self::load_quant_experts`]. `ids` is a backend-side buffer of
649    /// `n_selected` i32 expert IDs. `out` is sized `[n_selected, n_rows]`.
650    /// `src1_stride` is the per-slot activation stride in **elements**:
651    /// `0` ⇒ every slot reads the same activation row (broadcast — for
652    /// `gate` / `up` projections); `n_cols` ⇒ each slot reads its own
653    /// activation row (for `down` projections, where each expert
654    /// consumes its own silu(gate)·up output).
655    fn gemv_quant_moe_id(
656        _ctx: &mut Self::Context,
657        _a: &Self::Buffer,
658        _weight: &Self::QuantStore,
659        _ids: &Self::Buffer,
660        _out: &mut Self::Buffer,
661        _n_selected: usize,
662        _src1_stride: usize,
663    ) -> Result<()> {
664        Err(FerrumError::unsupported(
665            "gemv_quant_moe_id not implemented for this backend",
666        ))
667    }
668
669    /// Offset-aware variant of [`Self::gemv_quant_moe_id`] — reads `a`
670    /// from `a_offset` (in elements; meaningful only when src1_stride=0
671    /// for the broadcast case, or as the start of an `n_selected × K`
672    /// strided read when src1_stride≥K), reads `ids` from `ids_offset`
673    /// (the i-th `top_k` block in a stacked-batch `[M, top_k]` ids
674    /// buffer), and writes `out` from offset 0 (output stays per-iter
675    /// scratch). Used by the per-item batched-decode path so the M=N
676    /// concurrent decodes can read directly from the M-batch
677    /// `selected_ids_buf` / `norm_out` without materialising
678    /// per-iteration copies.
679    #[allow(clippy::too_many_arguments)]
680    fn gemv_quant_moe_id_offset(
681        ctx: &mut Self::Context,
682        a: &Self::Buffer,
683        a_offset: usize,
684        weight: &Self::QuantStore,
685        ids: &Self::Buffer,
686        ids_offset: usize,
687        out: &mut Self::Buffer,
688        n_selected: usize,
689        src1_stride: usize,
690    ) -> Result<()> {
691        let _ = (
692            ctx,
693            a,
694            a_offset,
695            weight,
696            ids,
697            ids_offset,
698            out,
699            n_selected,
700            src1_stride,
701        );
702        Err(FerrumError::unsupported(
703            "gemv_quant_moe_id_offset not implemented for this backend",
704        ))
705    }
706
707    /// Allocate a backend buffer of i32-typed values for kernels that
708    /// need integer indices (MoE expert IDs, scatter indices, etc.).
709    ///
710    /// Default impl bit-casts the i32s to f32s and uploads via
711    /// `from_slice` — useful on backends where the buffer type is type-
712    /// erased (CPU's `Vec<f32>`, Metal's untyped MTLBuffer). Backends
713    /// that use a strongly-typed buffer override.
714    fn from_slice_i32(data: &[i32]) -> Self::Buffer {
715        let f: Vec<f32> = data.iter().map(|&i| f32::from_bits(i as u32)).collect();
716        Self::from_slice(&f)
717    }
718
719    /// Overwrite an existing i32 buffer's contents in place. Used on
720    /// the MoE decode hot path: per-layer expert-id updates do an
721    /// in-place memcpy instead of allocating a fresh device buffer
722    /// (48 layers × 128 tokens = 6144 fresh allocations per decode
723    /// run otherwise — allocator pressure dominates the secondary cost).
724    ///
725    /// Default impl falls back to `from_slice_i32` + drop. Backends
726    /// with shared CPU↔GPU memory (Metal `StorageModeShared`, CPU's
727    /// `Vec<f32>`) override with a direct write.
728    fn write_i32_into(buf: &mut Self::Buffer, data: &[i32]) {
729        *buf = Self::from_slice_i32(data);
730    }
731
732    /// Overwrite an existing f32 buffer's contents in place. Counterpart
733    /// to `write_i32_into` for f32 data — used to update the per-token
734    /// MoE combine weights into a pre-allocated scratch buffer instead
735    /// of allocating a fresh `from_slice` buffer 6144 times per decode
736    /// run.
737    fn write_f32_into(buf: &mut Self::Buffer, data: &[f32]) {
738        *buf = Self::from_slice(data);
739    }
740
741    /// Stacked SiLU·gate over `[n_slots, ffn]` rows.
742    ///
743    /// Computes `out[s, i] = silu(gate[s, i]) * up[s, i]` for each slot
744    /// `s`, element `i`. Single dispatch covers all slots — cuts the
745    /// MoE decode silu staging from `top_k * (3 copy_slice + 1 silu)`
746    /// = 32 dispatches per layer to 1.
747    fn silu_mul_stacked(
748        _ctx: &mut Self::Context,
749        _gate: &Self::Buffer,
750        _up: &Self::Buffer,
751        _out: &mut Self::Buffer,
752        _n_slots: usize,
753        _ffn: usize,
754    ) -> Result<()> {
755        Err(FerrumError::unsupported(
756            "silu_mul_stacked not implemented for this backend",
757        ))
758    }
759
760    /// Fused gate+up MoE GEMV with in-register `SiLU(gate) * up`.
761    ///
762    /// Folds the three back-to-back dispatches that the stacked MoE
763    /// FFN decode path emitted per layer:
764    ///   1. `gemv_quant_moe_id` (gate) → gate_out_stacked
765    ///   2. `gemv_quant_moe_id` (up)   → up_out_stacked
766    ///   3. `silu_mul_stacked`         → silu_stacked
767    /// into a single dispatch that writes `silu_stacked` directly.
768    /// Saves 2 dispatches per layer plus the entire round-trip through
769    /// the gate_out / up_out scratch buffers (≈4× `[top_k, ffn]` of
770    /// intermediate traffic). The activation read is also halved
771    /// because the inner Q4_K reduction reuses one register-file load
772    /// across both weight matrices.
773    ///
774    /// Both `gate_w` and `up_w` must be `Q4KExperts` stacks with
775    /// matching `(num_experts, n_rows, n_cols)` (true for Qwen3-MoE
776    /// GGUFs). Backends without the fused kernel can fall back to the
777    /// 3-dispatch path; callers should gate via
778    /// [`Self::supports_fused_moe_gate_up_silu`] to avoid the
779    /// `Unsupported` String-allocating error round trip on the decode
780    /// hot path.
781    #[allow(clippy::too_many_arguments)]
782    fn gemv_quant_moe_id_gate_up_silu(
783        _ctx: &mut Self::Context,
784        _a: &Self::Buffer,
785        _gate_w: &Self::QuantStore,
786        _up_w: &Self::QuantStore,
787        _ids: &Self::Buffer,
788        _silu_out: &mut Self::Buffer,
789        _n_selected: usize,
790    ) -> Result<()> {
791        Err(FerrumError::unsupported(
792            "gemv_quant_moe_id_gate_up_silu not implemented for this backend",
793        ))
794    }
795
796    /// Capability probe for [`Self::gemv_quant_moe_id_gate_up_silu`].
797    ///
798    /// `true` ⇒ the fused kernel is wired in and the caller should
799    /// prefer it on the MoE decode hot path. `false` ⇒ caller must use
800    /// the 3-dispatch fallback (gate gemv + up gemv + silu_mul_stacked).
801    /// Lets callers branch without paying the cost of an `Err(Unsupported)`
802    /// allocation per (layer, step).
803    fn supports_fused_moe_gate_up_silu() -> bool {
804        false
805    }
806
807    /// Batched MoE indirect-dispatch GEMV — one Metal launch covers
808    /// **all** `m * top_k` (token, expert) pairs at once.
809    ///
810    /// This is the symmetric counterpart of
811    /// [`Self::gemv_quant_moe_id`]: same Q4_K decode loop, same
812    /// per-pair output, but the grid Z-axis spans `m * top_k` instead
813    /// of just `top_k`. Eliminates the engine-level per-token outer
814    /// loop that emits ~16× the dispatches llama.cpp emits at c=16
815    /// (their `kernel_mul_mv_id` already handles the M batch in one
816    /// dispatch).
817    ///
818    /// `a`           : activation buffer; pair `p` reads
819    ///                 `(p / top_k) * src1_outer_stride
820    ///                  + (p % top_k) * src1_inner_stride` floats.
821    ///                 gate / up:  src1 = `norm_out [m, K]`,
822    ///                              outer = K, inner = 0
823    ///                              (slots within a token broadcast).
824    ///                 down:       src1 = `silu_stacked [m, top_k, K]`,
825    ///                              outer = top_k * K, inner = K.
826    /// `weight`      : Q4KExperts stacked weights, common across
827    ///                 selected experts.
828    /// `ids`         : flat `[m * top_k]` selected-expert IDs (i32).
829    /// `out`         : `[m * top_k, n_rows]` outputs.
830    /// `m`           : token batch size.
831    /// `top_k`       : selected experts per token.
832    /// `src1_outer_stride`, `src1_inner_stride`: in **floats**.
833    #[allow(clippy::too_many_arguments)]
834    fn gemv_quant_moe_id_batched(
835        _ctx: &mut Self::Context,
836        _a: &Self::Buffer,
837        _weight: &Self::QuantStore,
838        _ids: &Self::Buffer,
839        _out: &mut Self::Buffer,
840        _m: usize,
841        _top_k: usize,
842        _src1_outer_stride: usize,
843        _src1_inner_stride: usize,
844    ) -> Result<()> {
845        Err(FerrumError::unsupported(
846            "gemv_quant_moe_id_batched not implemented for this backend",
847        ))
848    }
849
850    /// Capability probe for [`Self::gemv_quant_moe_id_batched`].
851    fn supports_batched_moe_gemv() -> bool {
852        false
853    }
854
855    /// Whether this backend has a paged-KV decode path
856    /// (`paged_decode_attention` etc.). Currently true for Metal, false
857    /// for CPU. Used to decide the default of `FERRUM_METAL_PAGED_KV` —
858    /// the `serve` path should opt in automatically when supported so
859    /// users get the bench-quality concurrent-decode numbers without
860    /// having to learn the flag.
861    fn supports_paged_kv() -> bool {
862        false
863    }
864
865    /// Batched fused gate+up MoE GEMV with in-register `SiLU(gate) * up`.
866    ///
867    /// Counterpart of [`Self::gemv_quant_moe_id_gate_up_silu`] for the
868    /// batched-decode path: same in-register fusion, but the grid Z
869    /// dimension covers all `m * top_k` (token, expert) pairs in one
870    /// dispatch. Folds the three batched MoE FFN dispatches per layer
871    /// (gate gemv + up gemv + silu_mul_batched) into one — the missing
872    /// fusion that left the m≥2 batched-decode path slower than the
873    /// per-token loop (which already had this fusion at m=1).
874    ///
875    /// Both `gate_w` and `up_w` must be `Q4KExperts` stacks with
876    /// matching `(num_experts, n_rows, n_cols)`.
877    #[allow(clippy::too_many_arguments)]
878    fn gemv_quant_moe_id_gate_up_silu_batched(
879        _ctx: &mut Self::Context,
880        _a: &Self::Buffer,
881        _gate_w: &Self::QuantStore,
882        _up_w: &Self::QuantStore,
883        _ids: &Self::Buffer,
884        _silu_out: &mut Self::Buffer,
885        _m: usize,
886        _top_k: usize,
887        _src1_outer_stride: usize,
888        _src1_inner_stride: usize,
889    ) -> Result<()> {
890        Err(FerrumError::unsupported(
891            "gemv_quant_moe_id_gate_up_silu_batched not implemented for this backend",
892        ))
893    }
894
895    /// Capability probe for [`Self::gemv_quant_moe_id_gate_up_silu_batched`].
896    fn supports_batched_moe_gate_up_silu() -> bool {
897        false
898    }
899
900    /// Weighted sum across `n_slots` rows of `[hidden]`.
901    ///
902    /// Computes `out[i] = Σ_s weights[s] * slots[s, i]`. Single
903    /// dispatch replaces the per-slot `(copy_slice + scaled_add)`
904    /// loop in the MoE decode path (16 dispatches per layer → 1).
905    fn weighted_sum_stacked(
906        _ctx: &mut Self::Context,
907        _slots: &Self::Buffer,
908        _weights: &Self::Buffer,
909        _out: &mut Self::Buffer,
910        _n_slots: usize,
911        _hidden: usize,
912    ) -> Result<()> {
913        Err(FerrumError::unsupported(
914            "weighted_sum_stacked not implemented for this backend",
915        ))
916    }
917
918    // ── GEMM ────────────────────────────────────────────────────────────
919
    /// Dense matrix multiply of `a` and `b`, written into `out`.
    ///
    /// NOTE(review): the operand layout is not stated anywhere in this
    /// trait. Presumably `a` is `[m, k]` and `out` is `[m, n]`, and `b`
    /// holds the `k × n` operand, possibly stored transposed as `[n, k]`
    /// for weight matrices. Confirm against the backend implementations
    /// before relying on it.
    fn gemm(
        ctx: &mut Self::Context,
        a: &Self::Buffer,
        b: &Self::Buffer,
        out: &mut Self::Buffer,
        m: usize,
        n: usize,
        k: usize,
    );
929
930    // ── Norms ───────────────────────────────────────────────────────────
931
    /// RMS normalisation of `tokens` rows of width `dim`, scaled by `w`.
    ///
    /// This is the Llama-family decoder norm (contrast `layer_norm`, which
    /// also subtracts the mean and adds a bias). Presumably
    /// `out[t, i] = x[t, i] / sqrt(mean(x[t, :]^2) + eps) * w[i]`.
    /// TODO confirm the exact eps placement against the backend kernels.
    fn rms_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        w: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );
941
942    fn fused_add_rms_norm(
943        ctx: &mut Self::Context,
944        residual: &mut Self::Buffer,
945        x: &Self::Buffer,
946        w: &Self::Buffer,
947        eps: f32,
948        out: &mut Self::Buffer,
949        tokens: usize,
950        dim: usize,
951    );
952
953    // ── Attention ───────────────────────────────────────────────────────
954
955    fn flash_attention(
956        ctx: &mut Self::Context,
957        q: &Self::Buffer,
958        k: &Self::Buffer,
959        v: &Self::Buffer,
960        out: &mut Self::Buffer,
961        batch: usize,
962        q_len: usize,
963        kv_len: usize,
964        pos_offset: usize,
965        cfg: &AttnConfig,
966    );
967
968    /// Multi-Head Latent Attention — DeepSeek V2 / V3's compressed-KV
969    /// attention variant. Extension point only; no backend implements it
970    /// yet. DeepSeek V3 landing in Phase D/E will fill this in.
971    ///
972    /// `q`: full Q `[batch, num_heads, q_len, head_dim]`
973    /// `kv_compressed`: latent KV `[batch, kv_len, kv_lora_rank]`
974    /// `kv_rope`: per-position rope-applied key heads `[batch, kv_len, qk_rope_head_dim]`
975    /// `out`: `[batch, num_heads, q_len, head_dim]`
976    #[allow(clippy::too_many_arguments)]
977    fn mla_attention(
978        _ctx: &mut Self::Context,
979        _q: &Self::Buffer,
980        _kv_compressed: &Self::Buffer,
981        _kv_rope: &Self::Buffer,
982        _out: &mut Self::Buffer,
983        _batch: usize,
984        _q_len: usize,
985        _kv_len: usize,
986        _pos_offset: usize,
987        _cfg: &AttnConfig,
988        _kv_lora_rank: usize,
989        _qk_rope_head_dim: usize,
990    ) -> Result<()> {
991        Err(FerrumError::unsupported(
992            "mla_attention not implemented for this backend; required by \
993             DeepSeek V2/V3 (Phase D/E)",
994        ))
995    }
996
997    // ── Element-wise ────────────────────────────────────────────────────
998    //
999    // Models use `add_inplace` for residual updates and `copy_slice` for the
1000    // row-extraction step in prefill. Offset-free copy / non-inplace add are
1001    // not needed by the current Model-as-Code path; they can return later if
1002    // a model actually requires them.
1003
    /// Copy `len` floats from `src[src_offset..]` to `dst[dst_offset..]`.
    ///
    /// `Qwen3Model::prefill` needs this to pluck the last token's hidden
    /// state out of `residual[seq_len, h]` without round-tripping through
    /// host RAM. Offsets and `len` are counted in floats, not bytes.
    ///
    /// Per the element-wise section note above, the offset-free copy
    /// variant was dropped. For a plain whole-prefix copy, pass `0` for
    /// both offsets.
    fn copy_slice(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        src_offset: usize,
        dst: &mut Self::Buffer,
        dst_offset: usize,
        len: usize,
    );
1018
1019    // ── Embedding ───────────────────────────────────────────────────────
1020
    /// Gather the embedding row of each token id in `ids` into `out`.
    ///
    /// `ids` is a host-side slice rather than a device buffer.
    /// NOTE(review): presumably `table` is `[vocab, dim]` and `out` receives
    /// `[ids.len(), dim]` in token order. Out-of-range id handling is
    /// backend-defined. TODO confirm.
    fn embedding_lookup(
        ctx: &mut Self::Context,
        table: &Self::Buffer,
        ids: &[u32],
        out: &mut Self::Buffer,
        dim: usize,
    );
1028
1029    // ── Transformer-specific fused ops ─────────────────────────────────
1030    // These avoid CPU round-trips for data layout transformations.
1031
1032    /// Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers.
1033    /// Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]
1034    fn split_qkv(
1035        ctx: &mut Self::Context,
1036        qkv: &Self::Buffer,
1037        q: &mut Self::Buffer,
1038        k: &mut Self::Buffer,
1039        v: &mut Self::Buffer,
1040        tokens: usize,
1041        q_dim: usize,
1042        kv_dim: usize,
1043    );
1044
    /// Split fused gate_up `[tokens, 2*im]` into gate `[tokens, im]` and
    /// up `[tokens, im]`, then compute `SiLU(gate) * up` → out `[tokens, im]`.
    ///
    /// NOTE(review): presumably gate occupies the first `im` columns of each
    /// row and up the second `im`. Confirm the half ordering against the
    /// weight loader.
    fn fused_silu_mul_split(
        ctx: &mut Self::Context,
        gate_up: &Self::Buffer,
        out: &mut Self::Buffer,
        tokens: usize,
        im: usize,
    );
1054
    /// Fused QK-norm + RoPE + transpose-to-head-major.
    ///
    /// `mode` selects the operation:
    ///   0 = transpose only (typical for V, which needs no norm and no RoPE)
    ///   1 = per-head RMS norm + RoPE + transpose  (Q/K with QK-norm, Qwen3)
    ///   2 = RoPE + transpose                       (Q/K without QK-norm, Llama/Mistral)
    ///
    /// input:   `[tokens, heads, head_dim]`  (token-major, output of split_qkv)
    /// output:  `[heads, tokens, head_dim]`  (head-major, ready for flash_attn / kv_cache_append)
    ///
    /// `pos_offset` is the position of token 0. Decode passes the current
    /// sequence length and prefill passes 0. Token `i` of the batch is
    /// placed at position `pos_offset + i`.
    ///
    /// NOTE(review): given the mode table, `norm_w` and `eps` presumably only
    /// matter for mode 1, and `cos`/`sin` only for modes 1 and 2. Confirm
    /// whether backends still read them (so they must be valid buffers) in
    /// the other modes.
    ///
    /// This is the primary op for preparing attention inputs. Backends that
    /// have a fused kernel (Metal's `qk_norm_rope_transpose_f32`) will be
    /// dramatically faster than composing norm + rope + transpose
    /// separately. The CPU fallback lowers to the individual ops.
    #[allow(clippy::too_many_arguments)]
    fn qk_norm_rope(
        ctx: &mut Self::Context,
        input: &Self::Buffer,
        norm_w: &Self::Buffer,
        cos: &Self::Buffer,
        sin: &Self::Buffer,
        output: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        head_dim: usize,
        pos_offset: usize,
        eps: f32,
        mode: i32,
    );
1087
1088    /// Fused split-QKV + QK-norm + RoPE + head-major transpose.
1089    ///
1090    /// Single-dispatch replacement for the (`split_qkv` → 3× `qk_norm_rope`)
1091    /// chain on the decode-attention prelude. Reads the linear-layer
1092    /// fused-QKV output once and writes head-major Q/K/V directly into
1093    /// attention scratch.
1094    ///
1095    /// `qkv` layout: `[tokens, q_heads*hd + 2*kv_heads*hd]`.
1096    /// `q_out`: `[q_heads, tokens, hd]`. `k_out`/`v_out`: `[kv_heads, tokens, hd]`.
1097    /// `qk_mode`: 1 = norm + RoPE for Q/K (Qwen3 with QK-norm),
1098    ///            2 = RoPE only for Q/K (no QK-norm; Llama-style).
1099    /// V always falls through to transpose-only.
1100    ///
1101    /// Default returns Unsupported. Backends that implement it are
1102    /// expected to be dramatically faster than the four-dispatch chain.
1103    #[allow(clippy::too_many_arguments)]
1104    fn split_qkv_norm_rope(
1105        _ctx: &mut Self::Context,
1106        _qkv: &Self::Buffer,
1107        _q_norm_w: &Self::Buffer,
1108        _k_norm_w: &Self::Buffer,
1109        _cos: &Self::Buffer,
1110        _sin: &Self::Buffer,
1111        _q_out: &mut Self::Buffer,
1112        _k_out: &mut Self::Buffer,
1113        _v_out: &mut Self::Buffer,
1114        _tokens: usize,
1115        _q_heads: usize,
1116        _kv_heads: usize,
1117        _head_dim: usize,
1118        _pos_offset: usize,
1119        _eps: f32,
1120        _qk_mode: i32,
1121    ) -> Result<()> {
1122        Err(FerrumError::unsupported(
1123            "split_qkv_norm_rope not implemented for this backend",
1124        ))
1125    }
1126
1127    /// Variant of [`Backend::split_qkv_norm_rope`] that writes the new
1128    /// K and V directly into pre-allocated head-major KV cache buffers
1129    /// at slot `[kv_heads, cache_len .. cache_len + tokens, hd]`.
1130    /// Eliminates the trailing `kv_cache_append_head_major` dispatch on
1131    /// the decode hot path. Q still lands in per-token head-major
1132    /// scratch (flash-attention reads it as the query).
1133    ///
1134    /// Default returns Unsupported. Backends without the fused kernel
1135    /// can keep using `split_qkv_norm_rope` + `kv_cache_append_head_major`.
1136    #[allow(clippy::too_many_arguments)]
1137    fn split_qkv_norm_rope_into_cache(
1138        _ctx: &mut Self::Context,
1139        _qkv: &Self::Buffer,
1140        _q_norm_w: &Self::Buffer,
1141        _k_norm_w: &Self::Buffer,
1142        _cos: &Self::Buffer,
1143        _sin: &Self::Buffer,
1144        _q_out: &mut Self::Buffer,
1145        _cache_k: &mut Self::Buffer,
1146        _cache_v: &mut Self::Buffer,
1147        _tokens: usize,
1148        _q_heads: usize,
1149        _kv_heads: usize,
1150        _head_dim: usize,
1151        _pos_offset: usize,
1152        _eps: f32,
1153        _qk_mode: i32,
1154        _cache_len: usize,
1155        _cache_capacity: usize,
1156    ) -> Result<()> {
1157        Err(FerrumError::unsupported(
1158            "split_qkv_norm_rope_into_cache not implemented for this backend",
1159        ))
1160    }
1161
1162    /// Paged-KV variant of [`Self::split_qkv_norm_rope_into_cache`].
1163    ///
1164    /// Same fused split + qk-norm + RoPE, but K/V are written into a
1165    /// paged pool `[num_blocks, kv_heads, block_size, head_dim]`
1166    /// indexed via `block_table[logical_block]` → physical_block.
1167    /// Q still goes to head-major scratch.
1168    ///
1169    /// Default returns Unsupported. Backends that lack a paged kernel
1170    /// keep using the contiguous variant.
1171    /// `qkv_byte_offset` / `q_out_byte_offset` let the caller pass a
1172    /// slice of a larger batched buffer (used by the multi-seq paged
1173    /// path in `decode_batch_internal`). For single-seq dispatch they
1174    /// should be 0.
1175    #[allow(clippy::too_many_arguments)]
1176    fn split_qkv_norm_rope_into_paged_cache(
1177        _ctx: &mut Self::Context,
1178        _qkv: &Self::Buffer,
1179        _qkv_byte_offset: u64,
1180        _q_norm_w: &Self::Buffer,
1181        _k_norm_w: &Self::Buffer,
1182        _cos: &Self::Buffer,
1183        _sin: &Self::Buffer,
1184        _q_out: &mut Self::Buffer,
1185        _q_out_byte_offset: u64,
1186        _cache_k: &mut Self::Buffer,
1187        _cache_v: &mut Self::Buffer,
1188        _block_table: &Self::Buffer,
1189        _tokens: usize,
1190        _q_heads: usize,
1191        _kv_heads: usize,
1192        _head_dim: usize,
1193        _pos_offset: usize,
1194        _eps: f32,
1195        _qk_mode: i32,
1196        _cache_len: usize,
1197        _block_size: usize,
1198        _max_num_blocks_per_seq: usize,
1199    ) -> Result<()> {
1200        Err(FerrumError::unsupported(
1201            "split_qkv_norm_rope_into_paged_cache not implemented for this backend",
1202        ))
1203    }
1204
1205    /// Paged-KV variant of [`Self::flash_attention`].
1206    ///
1207    /// Decode (`q_len == 1`):
1208    ///   `q`/`out`: `[num_seqs, num_heads, head_dim]` (token-major)
1209    ///
1210    /// Causal prefill (`q_len > 1`, single seq):
1211    ///   `q`/`out`: `[num_heads, q_len, head_dim]` (head-major — the
1212    ///              layout produced by `split_qkv_norm_rope_into_paged_cache`)
1213    ///   The kernel applies a per-q-token causal mask using
1214    ///   `context_lens[seq]` as the FINAL kv_len (= `pos_offset + q_len`):
1215    ///   token i sees positions `[0, context_lens - q_len + 1 + i)`.
1216    ///
1217    /// Common to both:
1218    ///   `k_pool`/`v_pool`: `[num_blocks, num_kv_heads, block_size, head_dim]`
1219    ///   `block_tables`: `[num_seqs, max_num_blocks_per_seq]` u32
1220    ///   `context_lens`: `[num_seqs]` u32
1221    ///
1222    /// Backends without a paged kernel return Unsupported; callers are
1223    /// expected to fall back to contiguous KV.
1224    #[allow(clippy::too_many_arguments)]
1225    fn paged_decode_attention(
1226        _ctx: &mut Self::Context,
1227        _q: &Self::Buffer,
1228        _k_pool: &Self::Buffer,
1229        _v_pool: &Self::Buffer,
1230        _out: &mut Self::Buffer,
1231        _block_tables: &Self::Buffer,
1232        _context_lens: &Self::Buffer,
1233        _num_seqs: usize,
1234        _num_heads: usize,
1235        _num_kv_heads: usize,
1236        _head_dim: usize,
1237        _block_size: usize,
1238        _max_num_blocks_per_seq: usize,
1239        _q_len: usize,
1240    ) -> Result<()> {
1241        Err(FerrumError::unsupported(
1242            "paged_decode_attention not implemented for this backend",
1243        ))
1244    }
1245
1246    /// Allocate a u32 buffer of length `n` for paged-KV bookkeeping
1247    /// (block tables, context lens). Default uses the existing
1248    /// `from_slice_i32` route then bit-casts; backends with a faster
1249    /// path can override.
1250    fn alloc_u32(n: usize) -> Self::Buffer {
1251        // Reinterpret as i32 — same 4-byte word; the kernel reads
1252        // bytes via `device const uint32_t *`.
1253        Self::from_slice_i32(&vec![0i32; n])
1254    }
1255
1256    /// Write a u32 slice into a buffer previously allocated via
1257    /// [`Self::alloc_u32`]. Used for live block_tables / context_lens
1258    /// updates between decode steps.
1259    ///
1260    /// Default: reads back, mutates host-side, writes back. Metal
1261    /// backend overrides with a direct memcpy on the StorageModeShared
1262    /// buffer.
1263    fn write_u32(_ctx: &mut Self::Context, _dst: &mut Self::Buffer, _data: &[u32]) {
1264        // No-op default — most backends won't exercise this path until
1265        // they implement paged_decode_attention.
1266    }
1267
    /// Append new K/V into a pre-allocated head-major cache buffer.
    ///
    /// `cache_k` / `cache_v`: `[nkv, capacity, hd]` (head-major, pre-allocated)
    /// `new_k_head_major` / `new_v_head_major`: `[nkv, new_tokens, hd]`
    ///   — produced directly by `qk_norm_rope`, no extra transpose needed.
    ///
    /// In-place append at slot `[nkv, cache_len..cache_len+new_tokens, hd]`.
    /// The caller owns the `cache_len` bookkeeping.
    ///
    /// NOTE(review): `cache_capacity` is presumably the per-head token
    /// stride (the `capacity` dim above). Callers should keep
    /// `cache_len + new_tokens <= cache_capacity`. Whether backends check
    /// this is not visible here. TODO confirm.
    #[allow(clippy::too_many_arguments)]
    fn kv_cache_append_head_major(
        ctx: &mut Self::Context,
        cache_k: &mut Self::Buffer,
        cache_v: &mut Self::Buffer,
        cache_len: usize,
        cache_capacity: usize,
        new_k_head_major: &Self::Buffer,
        new_v_head_major: &Self::Buffer,
        new_tokens: usize,
        nkv: usize,
        hd: usize,
    );
1289
    /// Transpose `[heads, tokens, dim]` → `[tokens, heads, dim]`.
    ///
    /// Called after `flash_attention` to restore the token-major layout the
    /// O-projection expects. `src` and `dst` are distinct buffers (shared
    /// vs exclusive borrow), so the transpose is out of place.
    fn transpose_head_to_token(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        dst: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        dim: usize,
    );
1300
    /// `residual[i] += x[i]` for `i < len`, in place.
    ///
    /// Models use this for residual updates (see the element-wise section
    /// note above).
    fn add_inplace(
        ctx: &mut Self::Context,
        residual: &mut Self::Buffer,
        x: &Self::Buffer,
        len: usize,
    );
1308
1309    /// `dst[i] += scale * src[i]` — scalar-broadcast scaled add, in place.
1310    ///
1311    /// MoE per-token combine writes `out[b] += weight_k * expert_k(x[b])`
1312    /// for each top-K expert; this primitive is the per-call accumulate.
1313    /// Backends without a dedicated kernel can fall back to the default
1314    /// implementation, which round-trips through host memory — correct,
1315    /// but slow on a hot path. Override on any backend you actually
1316    /// dispatch MoE on.
1317    fn scaled_add_inplace(
1318        _ctx: &mut Self::Context,
1319        dst: &mut Self::Buffer,
1320        src: &Self::Buffer,
1321        scale: f32,
1322        len: usize,
1323    ) {
1324        let mut dst_v = Self::to_vec(dst, len);
1325        let src_v = Self::to_vec(src, len);
1326        for i in 0..len {
1327            dst_v[i] += scale * src_v[i];
1328        }
1329        // Move the new buffer into the slot pointed to by `dst`. Safe
1330        // because `Self::Buffer: Send + Sync` and the old buffer is
1331        // dropped here when overwritten.
1332        *dst = Self::from_slice(&dst_v);
1333    }
1334
    /// Broadcast bias add: `data[r, c] += bias[c]` for every row `r < rows`.
    ///
    /// Needed by Bert / Clip / Whisper, whose linear projections carry a
    /// bias. `data` is `[rows, cols]` and `bias` holds `cols` elements.
    fn add_bias(
        ctx: &mut Self::Context,
        data: &mut Self::Buffer,
        bias: &Self::Buffer,
        rows: usize,
        cols: usize,
    );
1344
    /// Full LayerNorm (mean + variance normalisation + affine). This is
    /// distinct from the `rms_norm` used by Llama-family decoders:
    ///   `out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]`
    ///
    /// `mean` and `var` are reduced over the last dim (cols = `dim`),
    /// independently for each of the `tokens` rows.
    #[allow(clippy::too_many_arguments)]
    fn layer_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        gamma: &Self::Buffer,
        beta: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );
1360
    /// Element-wise GELU over the first `len` elements of `x`, written to
    /// `out`. Uses the erf-based form, which matches the PyTorch default,
    /// not the tanh approximation.
    fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize);
1363
1364    // ── Buffer management (context-free) ────────────────────────────────
1365
    /// Allocate a buffer of `len` f32 elements.
    /// NOTE(review): whether the contents are zero-initialised is
    /// backend-defined. Confirm before relying on it.
    fn alloc(len: usize) -> Self::Buffer;
    /// Read `len` f32 elements of `buf` back into host memory.
    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;
    /// Create a new buffer holding a copy of `data`.
    fn from_slice(data: &[f32]) -> Self::Buffer;
1369
1370    /// Load a weight tensor straight from its on-disk byte representation,
1371    /// letting the backend pick its preferred storage dtype.
1372    ///
1373    /// Default impl upcasts bf16/f16 to f32 via an intermediate Vec, matching
1374    /// pre-existing loader behaviour. Backends override this to go straight
1375    /// from raw bytes into a native half-precision buffer (e.g. Metal with
1376    /// `FERRUM_METAL_DTYPE=f16`), avoiding the transient 2× RAM spike.
1377    fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer {
1378        let data = src_dtype.to_f32_vec(raw);
1379        Self::from_slice(&data)
1380    }
1381
1382    // (The Phase A3 unified `gemm_quant(QuantWeights, QuantKind)` stub
1383    // that used to live here is superseded by the `load_quant` /
1384    // `gemm_quant(QuantStore)` pair earlier in this trait — same idea,
1385    // but the store hides the per-kind buffer layout so callers don't
1386    // have to construct a per-kind `QuantWeights<'_, Self>` packet.)
1387
1388    // ── TP collective ops (Phase A3 stubs) ──────────────────────────────
1389    //
1390    // Default impl is single-rank no-op: `world_size = 1`, `rank = 0`, and
1391    // the collective ops are identity. Multi-GPU backends (future
1392    // CudaBackend + NCCL) override these. Model code can call
1393    // `B::all_reduce_sum(...)` unconditionally; single-GPU paths pay zero.
1394
1395    fn world_size(_ctx: &Self::Context) -> usize {
1396        1
1397    }
1398    fn rank(_ctx: &Self::Context) -> usize {
1399        0
1400    }
1401    fn all_reduce(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _op: ReduceOp) {
1402        // single-rank: no-op
1403    }
1404    fn all_gather(
1405        _ctx: &mut Self::Context,
1406        _local: &Self::Buffer,
1407        _global: &mut Self::Buffer,
1408        _local_len: usize,
1409    ) {
1410        // single-rank: no-op (caller is expected to handle the degenerate
1411        // case or arrange for `local == global`)
1412    }
1413    fn broadcast(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _src_rank: usize) {
1414        // single-rank: no-op
1415    }
1416}