// ferrum_kernels/backend/capabilities.rs

//! Optional backend capability traits layered on top of [`Backend`].

use ferrum_types::{FerrumError, Result};

use super::traits::Backend;
use super::types::{GgufQuantType, MoeRouting, ReduceOp};

// ════════════════════════════════════════════════════════════════════════
// BackendGraph capability (CUDA Graph capture/replay)
// ════════════════════════════════════════════════════════════════════════
//
// Decode-loop optimization: eliminate per-kernel launch overhead by
// capturing the full step as a CUDA graph and replaying it. CPU/Metal have
// no equivalent — they `impl BackendGraph for X {}` with empty bodies and
// inherit the unsupported / no-op defaults below.
//
// Flow per decode step:
//   1. Caller: `set_decode_state(ctx, token, step)` — memcpy to dev bufs
//   2. Try `replay_graph(ctx, key)`:
//        - Ok(true):  graph replayed, skip eager forward
//        - Ok(false): no captured graph yet, run eager (this is also what
//                     the trait default returns on backends without graphs)
//        - Err(_):    replay failed or is not supported, run eager
//   3. If running eager and in capture window:
//      - `set_dev_state_mode(ctx, true)` so kernels use _dyn variants
//      - `begin_graph_capture(ctx)`
//      - run forward
//      - `end_graph_capture(ctx, key)` — stores graph on ctx internally
//      - `set_dev_state_mode(ctx, false)` — restore scalar kernels

30/// Capability-trait for backends that can capture and replay execution as
31/// a graph (CUDA Graph). Models that call these methods bound their
32/// generic on `B: BackendGraph`; backends without graph support
33/// (Metal, CPU) impl this trait with an empty body and inherit
34/// no-op / `unsupported` defaults.
35pub trait BackendGraph: Backend {
36    /// Update per-step dynamic state (token id, step/pos). Fast (3x memcpy).
37    fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32) {}
38
39    /// Toggle between scalar-arg kernels (normal) and `_dyn` kernels that
40    /// read their dynamic scalar args from device memory (graph-friendly).
41    fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool) {}
42
43    /// Begin stream capture. Subsequent kernel launches are recorded into
44    /// a pending graph instead of executing eagerly.
45    fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
46        Err(FerrumError::unsupported("graph capture not supported"))
47    }
48
49    /// End stream capture and install the captured graph keyed by
50    /// `_key` (opaque caller-chosen u64; the model uses `m_padded` so
51    /// that different batch shapes don't thrash a single slot).
52    fn end_graph_capture(_ctx: &mut Self::Context, _key: u64) -> Result<()> {
53        Err(FerrumError::unsupported("graph capture not supported"))
54    }
55
56    /// Replay the captured graph for `_key`. Returns `Ok(false)` if no
57    /// graph is cached for that key; caller should run eager.
58    fn replay_graph(_ctx: &mut Self::Context, _key: u64) -> Result<bool> {
59        Ok(false)
60    }
61
62    /// Drop the cached graph for `_key` — required when its kernel-arg
63    /// pointers (KV cache, scratch) might no longer be valid. Use
64    /// `reset_all_graphs` when EVERY cached graph should be evicted
65    /// (hard model reload / scratch realloc).
66    fn reset_graph(_ctx: &mut Self::Context, _key: u64) {}
67
68    /// Drop ALL cached graphs — used by hard reset paths.
69    fn reset_all_graphs(_ctx: &mut Self::Context) {}
70}

// ════════════════════════════════════════════════════════════════════════
// BackendCollective capability (NCCL / RCCL multi-rank ops)
// ════════════════════════════════════════════════════════════════════════
//
// Tensor-parallel multi-GPU collective ops. CUDA wires these to NCCL via
// `crate::nccl_comm::NcclRank`; AMD would wire to RCCL similarly. CPU and
// Metal `impl BackendCollective for X {}` with empty bodies, inheriting
// single-rank no-ops (world_size=1, rank=0, ops are identity).

81/// Capability-trait for backends that support multi-rank collective ops.
82/// Single-GPU backends inherit the no-op defaults: `world_size = 1`,
83/// `rank = 0`, and the collective ops are identity. Multi-rank backends
84/// (CUDA + NCCL today, AMD + RCCL in the future) override these.
85pub trait BackendCollective: Backend {
86    fn world_size(_ctx: &Self::Context) -> usize {
87        1
88    }
89    fn rank(_ctx: &Self::Context) -> usize {
90        0
91    }
92    fn all_reduce(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _op: ReduceOp) {
93        // single-rank: no-op
94    }
95    fn all_gather(
96        _ctx: &mut Self::Context,
97        _local: &Self::Buffer,
98        _global: &mut Self::Buffer,
99        _local_len: usize,
100    ) {
101        // single-rank: no-op (caller is expected to handle the degenerate
102        // case or arrange for `local == global`)
103    }
104    fn broadcast(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _src_rank: usize) {
105        // single-rank: no-op
106    }
107}

// ════════════════════════════════════════════════════════════════════════
// BackendQuantMarlin capability (CUDA Marlin INT4 / GPTQ)
// ════════════════════════════════════════════════════════════════════════
//
// CUDA-specific INT4 GEMM via Marlin tile kernels (Tensor Core required).
// Metal/CPU don't have Marlin; they `impl BackendQuantMarlin for X {}` empty
// and inherit `unsupported` Err defaults. GPTQ models targeting non-CUDA
// backends are loaded via the dequant-fallback path in the Linear impls.

118/// Capability-trait for backends that natively support Marlin INT4 GEMM.
119/// CUDA wires this to the Marlin (or vLLM marlin_moe_wna16) tile kernels;
120/// other backends inherit defaults that error or no-op.
121pub trait BackendQuantMarlin: Backend {
122    /// Repack raw GPTQ tensors into a backend-specific `Linear<Self>` impl.
123    /// Called once per layer at model load time.
124    ///
125    /// Inputs are host-side slices (CPU memory) — the loader reads from
126    /// safetensors and hands them off; each backend uploads + repacks
127    /// per its own strategy. `bits` is typically 4; `group_size` is
128    /// typically 128. `bias_host` is optional `[out_features]` f32 (when
129    /// the model has fused bias, e.g. Qwen2.5 attention projections).
130    ///
131    /// Phase 3e/2: returns `Box<dyn Linear<Self>>` directly (CUDA:
132    /// `CudaMarlinLinear`, CPU: `CpuGptqLinear`). Kernel dispatch lives
133    /// inside the boxed Linear's `forward` — the old `gemm_gptq` trait
134    /// method is gone.
135    #[allow(clippy::too_many_arguments)]
136    fn load_gptq(
137        _qweight: &[i32],
138        _scales: &[f32],
139        _qzeros: &[i32],
140        _g_idx: Option<&[i32]>,
141        _bias_host: Option<&[f32]>,
142        _bits: u32,
143        _group_size: usize,
144        _k: usize,
145        _n: usize,
146    ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
147        Err(FerrumError::unsupported(
148            "load_gptq not implemented for this backend",
149        ))
150    }
151    /// Load num_experts GPTQ weight tiles into ONE stacked store, with
152    /// the property that **each expert's packed bytes are contiguous**
153    /// in the resulting store. This is what the offset GEMM needs to
154    /// dispatch per expert via pointer offset alone.
155    ///
156    /// Why this is a separate API from `load_gptq` + post-hoc concat:
157    /// Marlin's repack permutes data in `[K-tile-row outer, N-tile inner]`
158    /// order. A single repack of `concat(all experts along N)` produces
159    /// a buffer where expert e's bytes are spread across K-tile-rows,
160    /// NOT contiguous. Per-expert repack-then-concat keeps each
161    /// expert's data in one contiguous block.
162    ///
163    /// `qweights[i] / scales[i] / qzeros[i]` are each expert's raw GPTQ
164    /// tensors. All share the same K + group_size + bits + g_idx.
165    ///
166    /// Default returns Err(unsupported); override on backends with a
167    /// per-expert MoE GPTQ path.
168    /// Phase C step 4e: returns the trait-object `MarlinExpertStack`
169    /// directly. Internally, each backend constructs its own opaque
170    /// repacked tile (Marlin: per-expert-then-concat; CPU: dequantized
171    /// f32 weight slab) and wraps it in the concrete
172    /// `{Cuda,Cpu}MarlinExpertStack` impl.
173    ///
174    /// Removing `Self::GptqStore` from the public API kills the type
175    /// leak that previously forced `ExpertStack<B>` to carry
176    /// `Option<Arc<B::GptqStore>>`. Adding a new Marlin backend now
177    /// only requires implementing this method + a fresh
178    /// `MarlinExpertStack<NewBackend>` impl — no Backend trait edits.
179    #[allow(clippy::too_many_arguments)]
180    fn load_gptq_stacked(
181        _qweights: &[&[i32]],
182        _scales: &[&[f32]],
183        _qzeros: &[&[i32]],
184        _g_idx: Option<&[i32]>,
185        _bits: u32,
186        _group_size: usize,
187        _k: usize,
188        _n_per_expert: usize,
189    ) -> Result<std::sync::Arc<dyn crate::MarlinExpertStack<Self>>> {
190        Err(FerrumError::unsupported(
191            "load_gptq_stacked not implemented for this backend",
192        ))
193    }
194    // Phase C step 4a: marlin_zero_stacked_workspace — body inlined into
195    // MarlinExpertStack::zero_workspace.
196    // Phase C step 4b: make_stacked_expert_linear — body inlined into
197    // MarlinExpertStack::make_expert_linear.
198    // Phase C step 4c+4d: moe_gemm_phase_batched + moe_gemm_phase_vllm —
199    // bodies inlined into MarlinExpertStack::gemm_phase_batched /
200    // gemm_phase_vllm (concrete impls in quant_linear/{cuda,cpu}_marlin_stack.rs).
201    // Phase C step 4e: make_marlin_expert_stack subsumed by
202    // load_gptq_stacked (now returns the trait object directly).
203    // gemm_gptq_with_offset_strided — body inlined into CpuMarlinExpertStack
204    // (the only remaining caller).
205    /// Pre-grow any backend-internal scratch slots whose size depends
206    /// on `m_total * intermediate_size` (the largest matmul fan-in
207    /// inside `unified_forward_internal`). Default no-op. CUDA
208    /// implements this to grow the perm-aware Marlin gather scratch
209    /// EAGERLY before the caller enters a CUDA-graph capture region —
210    /// `cuLaunchKernel` after a runtime alloc inside a captured
211    /// stream returns `CUDA_ERROR_INVALID_VALUE`.
212    fn pregrow_marlin_gather_scratch(_ctx: &mut Self::Context, _required: usize) {
213        // default: no scratch to pre-grow
214    }
215    // Phase C step 4e: gemm_gptq_with_offset_strided removed —
216    // body inlined into CpuMarlinExpertStack (the only caller after
217    // step 4c moved the multi-stream pool dispatch into the CUDA
218    // free function).
219}

// ════════════════════════════════════════════════════════════════════════
// BackendQuantGguf capability (Metal GGUF Q4_K / Q6_K / Q8_0)
// ════════════════════════════════════════════════════════════════════════
//
// Metal-specific GGUF k-quant GEMM/GEMV via simdgroup_matmul shaders.
// CUDA/CPU don't ship GGUF kernels; they `impl BackendQuantGguf for X {}`
// empty and inherit unsupported defaults. GGUF models targeting non-Metal
// backends are loaded via dequant-fallback in the Linear impls.

230/// Capability-trait for backends that natively dispatch GGUF k-quant
231/// GEMM / GEMV. Metal wires its q4k/q6k shaders here; CUDA/CPU inherit
232/// defaults that error.
233pub trait BackendQuantGguf: Backend {
234    /// Load GGUF k-quant weights into the backend's preferred format.
235    ///
236    /// `kind` discriminates Q4_K / Q5_K / Q6_K / Q8_0 etc. The CPU path
237    /// typically eager-dequants to fp32; the Metal path keeps raw block
238    /// bytes in MTLBuffer and dequants per matmul into a transient fp16
239    /// buffer. Adding a new k-quant flavour is a matched pair of
240    /// `QuantStore` variant + `match` arm, not a new trait method.
241    ///
242    /// `bytes`: contiguous on-disk payload — `n_blocks × block_size`.
243    /// `n_rows`: out_features. `n_cols`: in_features. The block count
244    /// is derived per-kind from these dims.
245    fn load_quant(
246        _kind: GgufQuantType,
247        _bytes: &[u8],
248        _n_rows: usize,
249        _n_cols: usize,
250    ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
251        Err(FerrumError::unsupported(
252            "load_quant not implemented for this backend",
253        ))
254    }
255    /// Build a fused linear from multiple `(kind, bytes, n_rows)`
256    /// parts that share `n_cols`. Used by `GgufLoader::load_fused` when
257    /// parts have heterogeneous quant kinds (e.g. Qwen3 qkv_proj where
258    /// q+k are Q4_K but v is Q6_K) — byte-concatenation isn't possible,
259    /// so each part stays as its own QuantStore and the gemm dispatches
260    /// one matvec per part with output offsets.
261    ///
262    /// Phase 3e/3: returns `Box<dyn Linear<Self>>` directly (Metal:
263    /// `MetalGgufLinear` over a `Fused` MetalQuantStore variant).
264    fn load_quant_fused(
265        _parts: &[(GgufQuantType, &[u8], usize)],
266        _n_cols: usize,
267    ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
268        Err(FerrumError::unsupported(
269            "load_quant_fused not implemented for this backend",
270        ))
271    }
272    /// Build a stacked-experts MoE linear from a contiguous 3-D weight
273    /// payload `[num_experts, n_rows, n_cols/256]` super-blocks. Used for
274    /// the MoE indirect-dispatch fast path; backends without such a kernel
275    /// return `Err(unsupported)` and the model code falls back to the
276    /// per-expert `Box<dyn Linear<Self>>` loop.
277    ///
278    /// Phase 3e/4: returns `Box<dyn StackedExpertGgufLinear<Self>>` directly
279    /// (Metal: `MetalStackedExpertGgufLinear` over Q4KExperts / Q6KExperts).
280    /// Replaces the old `Result<Self::QuantStore>` API + the 7 sibling
281    /// `*_moe_id*` Backend methods that consumed it.
282    fn load_quant_experts(
283        _kind: GgufQuantType,
284        _bytes: &[u8],
285        _num_experts: usize,
286        _n_rows: usize,
287        _n_cols: usize,
288    ) -> Result<Box<dyn crate::StackedExpertGgufLinear<Self>>> {
289        Err(FerrumError::unsupported(
290            "load_quant_experts not implemented for this backend",
291        ))
292    }
293}

// ════════════════════════════════════════════════════════════════════════
// BackendMoeFused capability (MoE routing + post-ops kernels)
// ════════════════════════════════════════════════════════════════════════
//
// Backend-specific MoE infrastructure: routing index buffers, expert
// dispatch align, weighted sum / silu/mul fused ops, top-k softmax.
// CUDA + Metal both implement (they're the real MoE backends);
// CPU inherits the defaults — `unsupported` errors, except `moe_combine`,
// whose default is a host-side reference implementation.

304/// Capability-trait for backends that natively dispatch MoE post-ops + routing.
305pub trait BackendMoeFused: Backend {
306    /// Routing inputs for `moe_gemm_phase_vllm` — host-built i32 arrays
307    /// uploaded once per layer (or per token, depending on caller cadence).
308    /// Matches the shape contract of `moe_align_block_size` outputs but is
309    /// usable on backends that build the indices on host.
310    ///
311    /// Buffers are typed Self::Buffer (= fp16 on CUDA) for trait-object
312    /// reasons; backends reinterpret as i32. Default returns unsupported.
313    fn upload_moe_routing(
314        _ctx: &mut Self::Context,
315        _sorted_token_ids: &[i32],
316        _expert_ids: &[i32],
317        _num_tokens_past_padded: &[i32],
318    ) -> Result<MoeRouting<Self>> {
319        Err(FerrumError::unsupported(
320            "upload_moe_routing not implemented for this backend",
321        ))
322    }
323    /// GPU-side MoE router: `[batch, num_experts]` logits → `[batch, top_k]`
324    /// expert IDs (i32) + `[batch, top_k]` combine weights (f32).
325    ///
326    /// Replaces the per-layer `B::sync + B::to_vec(router_logits) + host route()`
327    /// round trip. The output buffers stay device-side for downstream
328    /// `gemv_quant_moe_id` / `gemm_quant_moe_id` consumption — no host
329    /// pipeline drain in the inner loop.
330    ///
331    /// `norm_topk_prob`: if true, divide each row's K weights by their
332    /// sum so they total 1.0 (Qwen3-MoE / Mixtral default).
333    #[allow(clippy::too_many_arguments)]
334    fn route_topk_softmax(
335        _ctx: &mut Self::Context,
336        _logits: &Self::Buffer,
337        _out_ids: &mut Self::Buffer,
338        _out_weights: &mut Self::Buffer,
339        _batch: usize,
340        _num_experts: usize,
341        _top_k: usize,
342        _norm_topk_prob: bool,
343    ) -> Result<()> {
344        Err(FerrumError::unsupported(
345            "route_topk_softmax not implemented for this backend",
346        ))
347    }
348    /// GPU-side fast-path for the host route() leg of the bucketed
349    /// MoE forward (`moe_forward_bucketed` in ferrum-models). Replaces
350    /// the `B::sync(ctx) + B::to_vec(logits) + crate::moe::router::
351    /// route_into(...)` triple with a single GPU kernel + small D2H of
352    /// `[batch, top_k]` ids + weights.
353    ///
354    /// The backend allocates / reuses its own device-side scratch for
355    /// the kernel output; the caller only provides the host destination
356    /// vectors (resized + overwritten on each call). Default impl
357    /// returns `Err(unsupported)` so non-CUDA callers stay on the host
358    /// route_into() path with no behavior change.
359    #[allow(clippy::too_many_arguments)]
360    fn try_gpu_route_topk_into_host(
361        _ctx: &mut Self::Context,
362        _logits_dev: &Self::Buffer,
363        _out_ids_host: &mut Vec<u32>,
364        _out_weights_host: &mut Vec<f32>,
365        _batch: usize,
366        _num_experts: usize,
367        _top_k: usize,
368        _norm_topk_prob: bool,
369    ) -> Result<()> {
370        Err(FerrumError::unsupported(
371            "try_gpu_route_topk_into_host not implemented for this backend",
372        ))
373    }
374    /// GPU-side moe_align_block_size — prep for a future fused MoE
375    /// Marlin kernel. Takes per-pair expert assignments (from
376    /// [`Self::route_topk_softmax`]) and produces:
377    ///   - `sorted_token_ids[N_padded]`: flat list of pair indices
378    ///     in [0, batch * top_k), sorted by their assigned expert and
379    ///     padded with sentinel `batch * top_k` inside each expert
380    ///     group up to a `block_size` boundary.
381    ///   - `block_ids[N_padded / block_size]`: which expert each
382    ///     `block_size`-row tile of `sorted_token_ids` belongs to.
383    ///   - `total_tokens_post_pad[1]`: actual padded token count.
384    ///
385    /// Layout matches vLLM's marlin_moe_wna16 kernel input
386    /// expectation. The fused Marlin kernel reads a row from
387    /// `a[sorted_token_ids[i] / top_k]` and weights from
388    /// Build `pairs_by_token` + `packed_token_idx` device-side from
389    /// device-side `expert_ids`. The counting-sort permutation that
390    /// lets `moe_combine` (and the gather step before phase 1 GEMM)
391    /// read routing output without a host round-trip — the prerequisite
392    /// for graph-capturing the MoE bucketed path.
393    ///
394    /// Inputs (device):
395    /// - `expert_ids: I32 [batch * top_k]` — top-K selected expert ids.
396    ///
397    /// Outputs (device):
398    /// - `pairs_by_token: I32 [batch * top_k]` — sorted-by-expert
399    ///   position of each (b, k) pair (the row index into `packed_down`
400    ///   that `moe_combine` reads).
401    /// - `packed_token_idx: I32 [batch * top_k]` — inverse map: for
402    ///   each packed row, the original token b. Used by the gather
403    ///   step (`embedding_lookup` of `x` into `x_packed` before phase 1).
404    /// - `expert_offsets: I32 [num_experts + 1]` — exclusive prefix
405    ///   sum of tokens-per-expert; phase 1/3 dispatchers use it to
406    ///   compute each expert's row slice in the packed buffers.
407    ///
408    /// Default impl returns Err — only CUDA implements this.
409    #[allow(clippy::too_many_arguments)]
410    fn moe_build_pairs_by_token(
411        _ctx: &mut Self::Context,
412        _expert_ids: &Self::Buffer,
413        _pairs_by_token: &mut Self::Buffer,
414        _packed_token_idx: &mut Self::Buffer,
415        _expert_offsets: &mut Self::Buffer,
416        _batch_x_topk: usize,
417        _num_experts: usize,
418        _top_k: usize,
419    ) -> Result<()> {
420        Err(FerrumError::unsupported(
421            "moe_build_pairs_by_token not implemented for this backend",
422        ))
423    }
424
425    /// `b[block_ids[blockIdx.y] * n_per_expert + ...]`.
426    ///
427    /// Default impl returns Err — only CUDA implements this.
428    #[allow(clippy::too_many_arguments)]
429    fn moe_align_block_size(
430        _ctx: &mut Self::Context,
431        _expert_ids_per_pair: &Self::Buffer,
432        _sorted_token_ids: &mut Self::Buffer,
433        _block_ids: &mut Self::Buffer,
434        _total_tokens_post_pad: &mut Self::Buffer,
435        _batch_x_topk: usize,
436        _num_experts: usize,
437        _block_size: usize,
438        _sorted_max_size: usize,
439    ) -> Result<()> {
440        Err(FerrumError::unsupported(
441            "moe_align_block_size not implemented for this backend",
442        ))
443    }
444
445    /// vLLM-native align variant: `sorted_token_ids` stores flattened
446    /// `(token, top_k_slot)` pair ids, not Ferrum's pre-gathered packed rows.
447    /// This lets marlin_moe read gate_up input as `A[pair_id / top_k]`.
448    #[allow(clippy::too_many_arguments)]
449    fn moe_align_block_size_pair_ids(
450        _ctx: &mut Self::Context,
451        _expert_ids_per_pair: &Self::Buffer,
452        _sorted_token_ids: &mut Self::Buffer,
453        _block_ids: &mut Self::Buffer,
454        _total_tokens_post_pad: &mut Self::Buffer,
455        _batch_x_topk: usize,
456        _num_experts: usize,
457        _block_size: usize,
458        _sorted_max_size: usize,
459    ) -> Result<()> {
460        Err(FerrumError::unsupported(
461            "moe_align_block_size_pair_ids not implemented for this backend",
462        ))
463    }
464
465    /// GPU-side bucket sort: turn `[batch, top_k]` selected expert IDs
466    /// (from [`Self::route_topk_softmax`]) into `tpe[num_experts]` /
467    /// `ids[num_experts * row_stride]` arrays consumed by the batched
468    /// MoE GEMM, and emit indirect-dispatch args for the consumer GEMM.
469    ///
470    /// The `ids` buffer's row stride is `batch * top_k` (worst case);
471    /// only the first `tpe[e]` entries of each row are populated. The
472    /// consumer GEMM kernel early-exits at `r1 >= tpe[e]`, so the over-
473    /// strided indices cost nothing in the inner loop. The grid size,
474    /// however, would still be worst-case unless we tighten it — this
475    /// is what the `gate_up_args` / `down_args` outputs do: a 12-byte
476    /// `(grid_x, grid_y, grid_z)` u32 triple per shape, ready for
477    /// `dispatch_thread_groups_indirect`. `grid_x` is shared (depends
478    /// only on `max(tpe[e])`); `grid_y` differs because gate/up has
479    /// `M = m_gate_up` while down has `M = m_down`.
480    ///
481    /// All five output buffers are written in one kernel; no host
482    /// roundtrip and no per-layer pipeline drain.
483    #[allow(clippy::too_many_arguments)]
484    fn compute_ids_tpe_gpu(
485        _ctx: &mut Self::Context,
486        _selected_ids: &Self::Buffer,
487        _tpe: &mut Self::Buffer,
488        _ids: &mut Self::Buffer,
489        _gate_up_args: &mut Self::Buffer,
490        _down_args: &mut Self::Buffer,
491        _batch: usize,
492        _num_experts: usize,
493        _top_k: usize,
494        _m_gate_up: usize,
495        _m_down: usize,
496    ) -> Result<()> {
497        Err(FerrumError::unsupported(
498            "compute_ids_tpe_gpu not implemented for this backend",
499        ))
500    }
501    /// Stacked SiLU·gate over `[batch * top_k, ffn]` rows (prefill version
502    /// of `silu_mul_stacked`).
503    fn silu_mul_batched(
504        _ctx: &mut Self::Context,
505        _gate: &Self::Buffer,
506        _up: &Self::Buffer,
507        _out: &mut Self::Buffer,
508        _total_pairs: usize,
509        _ffn: usize,
510    ) -> Result<()> {
511        Err(FerrumError::unsupported(
512            "silu_mul_batched not implemented for this backend",
513        ))
514    }
515    /// Fused weighted-sum + residual-add: `residual[i] += Σ_k weights[k] · slots[k, i]`.
516    /// Single dispatch replaces the (weighted_sum → moe_out) +
517    /// (add_inplace residual += moe_out) pair on the decode hot path.
518    fn weighted_sum_residual_stacked(
519        _ctx: &mut Self::Context,
520        _slots: &Self::Buffer,
521        _weights: &Self::Buffer,
522        _residual: &mut Self::Buffer,
523        _n_slots: usize,
524        _hidden: usize,
525    ) -> Result<()> {
526        Err(FerrumError::unsupported(
527            "weighted_sum_residual_stacked not implemented for this backend",
528        ))
529    }
530    /// Fused weighted-sum-residual + RMSNorm: combines this layer's
531    /// `weighted_sum_residual_stacked` with the next layer's leading
532    /// `rms_norm` into a single dispatch.
533    ///
534    /// Computes
535    ///   `residual[i] += Σ_s w[s] · slots[s, i]`
536    ///   `normed_out[i] = residual[i] · (1 / sqrt(Σ residual² / hidden + eps)) · next_norm_w[i]`
537    ///
538    /// Caller is responsible for skipping the next layer's standalone
539    /// `rms_norm` — `normed_out` IS that layer's `norm_out` input.
540    /// Default returns Unsupported.
541    #[allow(clippy::too_many_arguments)]
542    fn weighted_sum_residual_norm_stacked(
543        _ctx: &mut Self::Context,
544        _slots: &Self::Buffer,
545        _weights: &Self::Buffer,
546        _residual: &mut Self::Buffer,
547        _next_norm_w: &Self::Buffer,
548        _normed_out: &mut Self::Buffer,
549        _n_slots: usize,
550        _hidden: usize,
551        _eps: f32,
552    ) -> Result<()> {
553        Err(FerrumError::unsupported(
554            "weighted_sum_residual_norm_stacked not implemented for this backend",
555        ))
556    }
557    /// Per-batch weighted sum: `out[b, h] = Σ_k weights[b, k] · slots[b, k, h]`.
558    /// Single dispatch covers the whole batch (prefill version of
559    /// `weighted_sum_stacked` which only handled one token).
560    fn weighted_sum_batched(
561        _ctx: &mut Self::Context,
562        _slots: &Self::Buffer,
563        _weights: &Self::Buffer,
564        _out: &mut Self::Buffer,
565        _batch: usize,
566        _top_k: usize,
567        _hidden: usize,
568    ) -> Result<()> {
569        Err(FerrumError::unsupported(
570            "weighted_sum_batched not implemented for this backend",
571        ))
572    }
573    /// Offset-aware variant of [`Self::weighted_sum_batched`] —
574    /// `weights` reads from `weights_offset` (in elements, points at
575    /// the start of `[batch, top_k]`), `out` writes from `out_offset`
576    /// (in elements, points at start of `[batch, hidden]`). Used by
577    /// the per-item batched-decode path to skip `copy_slice` round-trips.
578    /// Default falls back to the non-offset variant via two copies.
579    #[allow(clippy::too_many_arguments)]
580    fn weighted_sum_batched_offset(
581        ctx: &mut Self::Context,
582        slots: &Self::Buffer,
583        weights: &Self::Buffer,
584        weights_offset: usize,
585        out: &mut Self::Buffer,
586        out_offset: usize,
587        batch: usize,
588        top_k: usize,
589        hidden: usize,
590    ) -> Result<()> {
591        // Default: stage through scratch — backends override for zero-copy.
592        let _ = (
593            ctx,
594            slots,
595            weights,
596            weights_offset,
597            out,
598            out_offset,
599            batch,
600            top_k,
601            hidden,
602        );
603        Err(FerrumError::unsupported(
604            "weighted_sum_batched_offset not implemented for this backend",
605        ))
606    }
607    /// Stacked SiLU·gate over `[n_slots, ffn]` rows.
608    ///
609    /// Computes `out[s, i] = silu(gate[s, i]) * up[s, i]` for each slot
610    /// `s`, element `i`. Single dispatch covers all slots — cuts the
611    /// MoE decode silu staging from `top_k * (3 copy_slice + 1 silu)`
612    /// = 32 dispatches per layer to 1.
613    fn silu_mul_stacked(
614        _ctx: &mut Self::Context,
615        _gate: &Self::Buffer,
616        _up: &Self::Buffer,
617        _out: &mut Self::Buffer,
618        _n_slots: usize,
619        _ffn: usize,
620    ) -> Result<()> {
621        Err(FerrumError::unsupported(
622            "silu_mul_stacked not implemented for this backend",
623        ))
624    }
625    /// Capability probe for [`Self::gemv_quant_moe_id_gate_up_silu`].
626    ///
627    /// `true` ⇒ the fused kernel is wired in and the caller should
628    /// prefer it on the MoE decode hot path. `false` ⇒ caller must use
629    /// the 3-dispatch fallback (gate gemv + up gemv + silu_mul_stacked).
630    /// Lets callers branch without paying the cost of an `Err(Unsupported)`
631    /// allocation per (layer, step).
632    fn supports_fused_moe_gate_up_silu() -> bool {
633        false
634    }
635    /// Capability probe for [`Self::gemv_quant_moe_id_batched`].
636    fn supports_batched_moe_gemv() -> bool {
637        false
638    }
639    /// Capability probe for [`Self::gemv_quant_moe_id_gate_up_silu_batched`].
640    fn supports_batched_moe_gate_up_silu() -> bool {
641        false
642    }
643    /// Weighted sum across `n_slots` rows of `[hidden]`.
644    ///
645    /// Computes `out[i] = Σ_s weights[s] * slots[s, i]`. Single
646    /// dispatch replaces the per-slot `(copy_slice + scaled_add)`
647    /// loop in the MoE decode path (16 dispatches per layer → 1).
648    fn weighted_sum_stacked(
649        _ctx: &mut Self::Context,
650        _slots: &Self::Buffer,
651        _weights: &Self::Buffer,
652        _out: &mut Self::Buffer,
653        _n_slots: usize,
654        _hidden: usize,
655    ) -> Result<()> {
656        Err(FerrumError::unsupported(
657            "weighted_sum_stacked not implemented for this backend",
658        ))
659    }
660    /// MoE combine: per-token weighted sum across `top_k` expert outputs.
661    ///
662    /// After the bucketed dispatch, `packed_down` holds `[total_pairs,
663    /// hidden]` with one row per (token, k_slot) pair in expert-bucketed
664    /// order. `pairs_by_token[b * top_k + k]` is the inverse map: which
665    /// row of `packed_down` carries the (b, k_slot) contribution. A row
666    /// index of `-1` means "skip" (unused slot).
667    ///
668    /// Computes:
669    ///
670    /// ```text
671    /// out[b, h] = sum_k pair_weights[b * top_k + k] *
672    ///                   packed_down[pairs_by_token[b * top_k + k], h]
673    /// ```
674    ///
675    /// Default impl round-trips via host memory — correct but slow.
676    /// CUDA backend launches a single fused kernel.
677    ///
678    /// Phase D follow-up: `pairs_by_token` (I32) and `pair_weights` (F32)
679    /// are now device buffers so callers can build them on-device for
680    /// graph capture (was `&[i32]` / `&[f32]` host slices with internal
681    /// clone_htod, which records stale host pointers under CUDA Graph
682    /// capture replay).
683    #[allow(clippy::too_many_arguments)]
684    fn moe_combine(
685        ctx: &mut Self::Context,
686        packed_down: &Self::Buffer,
687        pairs_by_token: &Self::Buffer,
688        pair_weights: &Self::Buffer,
689        out: &mut Self::Buffer,
690        batch: usize,
691        hidden: usize,
692        top_k: usize,
693        total_pairs: usize,
694    ) {
695        // Reference default: D2H pairs/weights, run the host loop, H2D out.
696        // CUDA backend overrides with a single device kernel.
697        let _ = ctx;
698        let packed = Self::to_vec(packed_down, total_pairs * hidden);
699        let pairs_host_f32 = Self::to_vec(pairs_by_token, batch * top_k);
700        let weights_host = Self::to_vec(pair_weights, batch * top_k);
701        let mut out_h = vec![0.0f32; batch * hidden];
702        for b in 0..batch {
703            for k in 0..top_k {
704                // `to_vec` returns f32; the device-side I32 buffer is
705                // bit-cast to f32 by the trait's f16-default to_vec path,
706                // so we re-extract via raw transmute. Backends override
707                // this default with a typed kernel that doesn't go
708                // through f16; on the default path callers are CPU
709                // parity tests where the byte pattern is preserved.
710                let pair_row = pairs_host_f32[b * top_k + k].to_bits() as i32;
711                if pair_row < 0 {
712                    continue;
713                }
714                let w = weights_host[b * top_k + k];
715                let src = &packed[(pair_row as usize) * hidden..(pair_row as usize + 1) * hidden];
716                let dst = &mut out_h[b * hidden..(b + 1) * hidden];
717                for h in 0..hidden {
718                    dst[h] += w * src[h];
719                }
720            }
721        }
722        *out = Self::from_slice(&out_h);
723    }
724}