ferrum_kernels/backend/capabilities.rs
1//! Optional backend capability traits layered on top of [`Backend`].
2
3use ferrum_types::{FerrumError, Result};
4
5use super::traits::Backend;
6use super::types::{GgufQuantType, MoeRouting, ReduceOp};
7
8// ════════════════════════════════════════════════════════════════════════
9// BackendGraph capability (CUDA Graph capture/replay)
10// ════════════════════════════════════════════════════════════════════════
11//
12// Decode-loop optimization: eliminate per-kernel launch overhead by
13// capturing the full step as a CUDA graph and replaying. CPU/Metal have
14// no equivalent — they `impl BackendGraph for X {}` with empty bodies and
15// inherit the unsupported / no-op defaults below.
16//
17// Flow per decode step:
18// 1. Caller: `set_decode_state(ctx, token, step)` — memcpy to dev bufs
19// 2. Try `replay_graph(ctx, key)`:
20// - Ok(true): graph replayed, skip eager forward
21// - Ok(false): no captured graph yet, run eager
22// - Err(_): not supported, run eager
23// 3. If running eager and in capture window:
24// - `set_dev_state_mode(ctx, true)` so kernels use _dyn variants
25// - `begin_graph_capture(ctx)`
26// - run forward
27// - `end_graph_capture(ctx, key)` — stores graph on ctx internally
28// - `set_dev_state_mode(ctx, false)` — restore scalar kernels
29
30/// Capability-trait for backends that can capture and replay execution as
31/// a graph (CUDA Graph). Models that call these methods bound their
32/// generic on `B: BackendGraph`; backends without graph support
33/// (Metal, CPU) impl this trait with an empty body and inherit
34/// no-op / `unsupported` defaults.
35pub trait BackendGraph: Backend {
36 /// Update per-step dynamic state (token id, step/pos). Fast (3x memcpy).
37 fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32) {}
38
39 /// Toggle between scalar-arg kernels (normal) and `_dyn` kernels that
40 /// read their dynamic scalar args from device memory (graph-friendly).
41 fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool) {}
42
43 /// Begin stream capture. Subsequent kernel launches are recorded into
44 /// a pending graph instead of executing eagerly.
45 fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
46 Err(FerrumError::unsupported("graph capture not supported"))
47 }
48
49 /// End stream capture and install the captured graph keyed by
50 /// `_key` (opaque caller-chosen u64; the model uses `m_padded` so
51 /// that different batch shapes don't thrash a single slot).
52 fn end_graph_capture(_ctx: &mut Self::Context, _key: u64) -> Result<()> {
53 Err(FerrumError::unsupported("graph capture not supported"))
54 }
55
56 /// Replay the captured graph for `_key`. Returns `Ok(false)` if no
57 /// graph is cached for that key; caller should run eager.
58 fn replay_graph(_ctx: &mut Self::Context, _key: u64) -> Result<bool> {
59 Ok(false)
60 }
61
62 /// Drop the cached graph for `_key` — required when its kernel-arg
63 /// pointers (KV cache, scratch) might no longer be valid. Use
64 /// `reset_all_graphs` when EVERY cached graph should be evicted
65 /// (hard model reload / scratch realloc).
66 fn reset_graph(_ctx: &mut Self::Context, _key: u64) {}
67
68 /// Drop ALL cached graphs — used by hard reset paths.
69 fn reset_all_graphs(_ctx: &mut Self::Context) {}
70}
71
72// ════════════════════════════════════════════════════════════════════════
73// BackendCollective capability (NCCL / RCCL multi-rank ops)
74// ════════════════════════════════════════════════════════════════════════
75//
76// Tensor-parallel multi-GPU collective ops. CUDA wires these to NCCL via
77// `crate::nccl_comm::NcclRank`; AMD would wire to RCCL similarly. CPU and
78// Metal `impl BackendCollective for X {}` with empty bodies, inheriting
79// single-rank no-ops (world_size=1, rank=0, ops are identity).
80
81/// Capability-trait for backends that support multi-rank collective ops.
82/// Single-GPU backends inherit the no-op defaults: `world_size = 1`,
83/// `rank = 0`, and the collective ops are identity. Multi-rank backends
84/// (CUDA + NCCL today, AMD + RCCL in the future) override these.
85pub trait BackendCollective: Backend {
86 fn world_size(_ctx: &Self::Context) -> usize {
87 1
88 }
89 fn rank(_ctx: &Self::Context) -> usize {
90 0
91 }
92 fn all_reduce(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _op: ReduceOp) {
93 // single-rank: no-op
94 }
95 fn all_gather(
96 _ctx: &mut Self::Context,
97 _local: &Self::Buffer,
98 _global: &mut Self::Buffer,
99 _local_len: usize,
100 ) {
101 // single-rank: no-op (caller is expected to handle the degenerate
102 // case or arrange for `local == global`)
103 }
104 fn broadcast(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _src_rank: usize) {
105 // single-rank: no-op
106 }
107}
108
109// ════════════════════════════════════════════════════════════════════════
110// BackendQuantMarlin capability (CUDA Marlin INT4 / GPTQ)
111// ════════════════════════════════════════════════════════════════════════
112//
113// CUDA-specific INT4 GEMM via Marlin tile kernels (Tensor Core required).
114// Metal/CPU don't have Marlin; they `impl BackendQuantMarlin for X {}` empty
115// and inherit `unsupported` Err defaults. GPTQ models targeting non-CUDA
116// backends are loaded via the dequant-fallback path in the Linear impls.
117
118/// Capability-trait for backends that natively support Marlin INT4 GEMM.
119/// CUDA wires this to the Marlin (or vLLM marlin_moe_wna16) tile kernels;
120/// other backends inherit defaults that error or no-op.
121pub trait BackendQuantMarlin: Backend {
122 /// Repack raw GPTQ tensors into a backend-specific `Linear<Self>` impl.
123 /// Called once per layer at model load time.
124 ///
125 /// Inputs are host-side slices (CPU memory) — the loader reads from
126 /// safetensors and hands them off; each backend uploads + repacks
127 /// per its own strategy. `bits` is typically 4; `group_size` is
128 /// typically 128. `bias_host` is optional `[out_features]` f32 (when
129 /// the model has fused bias, e.g. Qwen2.5 attention projections).
130 ///
131 /// Phase 3e/2: returns `Box<dyn Linear<Self>>` directly (CUDA:
132 /// `CudaMarlinLinear`, CPU: `CpuGptqLinear`). Kernel dispatch lives
133 /// inside the boxed Linear's `forward` — the old `gemm_gptq` trait
134 /// method is gone.
135 #[allow(clippy::too_many_arguments)]
136 fn load_gptq(
137 _qweight: &[i32],
138 _scales: &[f32],
139 _qzeros: &[i32],
140 _g_idx: Option<&[i32]>,
141 _bias_host: Option<&[f32]>,
142 _bits: u32,
143 _group_size: usize,
144 _k: usize,
145 _n: usize,
146 ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
147 Err(FerrumError::unsupported(
148 "load_gptq not implemented for this backend",
149 ))
150 }
151 /// Load num_experts GPTQ weight tiles into ONE stacked store, with
152 /// the property that **each expert's packed bytes are contiguous**
153 /// in the resulting store. This is what the offset GEMM needs to
154 /// dispatch per expert via pointer offset alone.
155 ///
156 /// Why this is a separate API from `load_gptq` + post-hoc concat:
157 /// Marlin's repack permutes data in `[K-tile-row outer, N-tile inner]`
158 /// order. A single repack of `concat(all experts along N)` produces
159 /// a buffer where expert e's bytes are spread across K-tile-rows,
160 /// NOT contiguous. Per-expert repack-then-concat keeps each
161 /// expert's data in one contiguous block.
162 ///
163 /// `qweights[i] / scales[i] / qzeros[i]` are each expert's raw GPTQ
164 /// tensors. All share the same K + group_size + bits + g_idx.
165 ///
166 /// Default returns Err(unsupported); override on backends with a
167 /// per-expert MoE GPTQ path.
168 /// Phase C step 4e: returns the trait-object `MarlinExpertStack`
169 /// directly. Internally, each backend constructs its own opaque
170 /// repacked tile (Marlin: per-expert-then-concat; CPU: dequantized
171 /// f32 weight slab) and wraps it in the concrete
172 /// `{Cuda,Cpu}MarlinExpertStack` impl.
173 ///
174 /// Removing `Self::GptqStore` from the public API kills the type
175 /// leak that previously forced `ExpertStack<B>` to carry
176 /// `Option<Arc<B::GptqStore>>`. Adding a new Marlin backend now
177 /// only requires implementing this method + a fresh
178 /// `MarlinExpertStack<NewBackend>` impl — no Backend trait edits.
179 #[allow(clippy::too_many_arguments)]
180 fn load_gptq_stacked(
181 _qweights: &[&[i32]],
182 _scales: &[&[f32]],
183 _qzeros: &[&[i32]],
184 _g_idx: Option<&[i32]>,
185 _bits: u32,
186 _group_size: usize,
187 _k: usize,
188 _n_per_expert: usize,
189 ) -> Result<std::sync::Arc<dyn crate::MarlinExpertStack<Self>>> {
190 Err(FerrumError::unsupported(
191 "load_gptq_stacked not implemented for this backend",
192 ))
193 }
194 // Phase C step 4a: marlin_zero_stacked_workspace — body inlined into
195 // MarlinExpertStack::zero_workspace.
196 // Phase C step 4b: make_stacked_expert_linear — body inlined into
197 // MarlinExpertStack::make_expert_linear.
198 // Phase C step 4c+4d: moe_gemm_phase_batched + moe_gemm_phase_vllm —
199 // bodies inlined into MarlinExpertStack::gemm_phase_batched /
200 // gemm_phase_vllm (concrete impls in quant_linear/{cuda,cpu}_marlin_stack.rs).
201 // Phase C step 4e: make_marlin_expert_stack subsumed by
202 // load_gptq_stacked (now returns the trait object directly).
203 // gemm_gptq_with_offset_strided — body inlined into CpuMarlinExpertStack
204 // (the only remaining caller).
205 /// Pre-grow any backend-internal scratch slots whose size depends
206 /// on `m_total * intermediate_size` (the largest matmul fan-in
207 /// inside `unified_forward_internal`). Default no-op. CUDA
208 /// implements this to grow the perm-aware Marlin gather scratch
209 /// EAGERLY before the caller enters a CUDA-graph capture region —
210 /// `cuLaunchKernel` after a runtime alloc inside a captured
211 /// stream returns `CUDA_ERROR_INVALID_VALUE`.
212 fn pregrow_marlin_gather_scratch(_ctx: &mut Self::Context, _required: usize) {
213 // default: no scratch to pre-grow
214 }
215 // Phase C step 4e: gemm_gptq_with_offset_strided removed —
216 // body inlined into CpuMarlinExpertStack (the only caller after
217 // step 4c moved the multi-stream pool dispatch into the CUDA
218 // free function).
219}
220
221// ════════════════════════════════════════════════════════════════════════
222// BackendQuantGguf capability (Metal GGUF Q4_K / Q6_K / Q8_0)
223// ════════════════════════════════════════════════════════════════════════
224//
225// Metal-specific GGUF k-quant GEMM/GEMV via simdgroup_matmul shaders.
226// CUDA/CPU don't ship GGUF kernels; they `impl BackendQuantGguf for X {}`
227// empty and inherit unsupported defaults. GGUF models targeting non-Metal
228// backends are loaded via dequant-fallback in the Linear impls.
229
230/// Capability-trait for backends that natively dispatch GGUF k-quant
231/// GEMM / GEMV. Metal wires its q4k/q6k shaders here; CUDA/CPU inherit
232/// defaults that error.
233pub trait BackendQuantGguf: Backend {
234 /// Load GGUF k-quant weights into the backend's preferred format.
235 ///
236 /// `kind` discriminates Q4_K / Q5_K / Q6_K / Q8_0 etc. The CPU path
237 /// typically eager-dequants to fp32; the Metal path keeps raw block
238 /// bytes in MTLBuffer and dequants per matmul into a transient fp16
239 /// buffer. Adding a new k-quant flavour is a matched pair of
240 /// `QuantStore` variant + `match` arm, not a new trait method.
241 ///
242 /// `bytes`: contiguous on-disk payload — `n_blocks × block_size`.
243 /// `n_rows`: out_features. `n_cols`: in_features. The block count
244 /// is derived per-kind from these dims.
245 fn load_quant(
246 _kind: GgufQuantType,
247 _bytes: &[u8],
248 _n_rows: usize,
249 _n_cols: usize,
250 ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
251 Err(FerrumError::unsupported(
252 "load_quant not implemented for this backend",
253 ))
254 }
255 /// Build a fused linear from multiple `(kind, bytes, n_rows)`
256 /// parts that share `n_cols`. Used by `GgufLoader::load_fused` when
257 /// parts have heterogeneous quant kinds (e.g. Qwen3 qkv_proj where
258 /// q+k are Q4_K but v is Q6_K) — byte-concatenation isn't possible,
259 /// so each part stays as its own QuantStore and the gemm dispatches
260 /// one matvec per part with output offsets.
261 ///
262 /// Phase 3e/3: returns `Box<dyn Linear<Self>>` directly (Metal:
263 /// `MetalGgufLinear` over a `Fused` MetalQuantStore variant).
264 fn load_quant_fused(
265 _parts: &[(GgufQuantType, &[u8], usize)],
266 _n_cols: usize,
267 ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
268 Err(FerrumError::unsupported(
269 "load_quant_fused not implemented for this backend",
270 ))
271 }
272 /// Build a stacked-experts MoE linear from a contiguous 3-D weight
273 /// payload `[num_experts, n_rows, n_cols/256]` super-blocks. Used for
274 /// the MoE indirect-dispatch fast path; backends without such a kernel
275 /// return `Err(unsupported)` and the model code falls back to the
276 /// per-expert `Box<dyn Linear<Self>>` loop.
277 ///
278 /// Phase 3e/4: returns `Box<dyn StackedExpertGgufLinear<Self>>` directly
279 /// (Metal: `MetalStackedExpertGgufLinear` over Q4KExperts / Q6KExperts).
280 /// Replaces the old `Result<Self::QuantStore>` API + the 7 sibling
281 /// `*_moe_id*` Backend methods that consumed it.
282 fn load_quant_experts(
283 _kind: GgufQuantType,
284 _bytes: &[u8],
285 _num_experts: usize,
286 _n_rows: usize,
287 _n_cols: usize,
288 ) -> Result<Box<dyn crate::StackedExpertGgufLinear<Self>>> {
289 Err(FerrumError::unsupported(
290 "load_quant_experts not implemented for this backend",
291 ))
292 }
293}
294
295// ════════════════════════════════════════════════════════════════════════
296// BackendMoeFused capability (MoE routing + post-ops kernels)
297// ════════════════════════════════════════════════════════════════════════
298//
299// Backend-specific MoE infrastructure: routing index buffers, expert
300// dispatch align, weighted sum / silu/mul fused ops, top-k softmax.
301// CUDA + Metal both implement (they're the real MoE backends);
302// CPU inherits unsupported defaults.
303
304/// Capability-trait for backends that natively dispatch MoE post-ops + routing.
305pub trait BackendMoeFused: Backend {
306 /// Routing inputs for `moe_gemm_phase_vllm` — host-built i32 arrays
307 /// uploaded once per layer (or per token, depending on caller cadence).
308 /// Matches the shape contract of `moe_align_block_size` outputs but is
309 /// usable on backends that build the indices on host.
310 ///
311 /// Buffers are typed Self::Buffer (= fp16 on CUDA) for trait-object
312 /// reasons; backends reinterpret as i32. Default returns unsupported.
313 fn upload_moe_routing(
314 _ctx: &mut Self::Context,
315 _sorted_token_ids: &[i32],
316 _expert_ids: &[i32],
317 _num_tokens_past_padded: &[i32],
318 ) -> Result<MoeRouting<Self>> {
319 Err(FerrumError::unsupported(
320 "upload_moe_routing not implemented for this backend",
321 ))
322 }
323 /// GPU-side MoE router: `[batch, num_experts]` logits → `[batch, top_k]`
324 /// expert IDs (i32) + `[batch, top_k]` combine weights (f32).
325 ///
326 /// Replaces the per-layer `B::sync + B::to_vec(router_logits) + host route()`
327 /// round trip. The output buffers stay device-side for downstream
328 /// `gemv_quant_moe_id` / `gemm_quant_moe_id` consumption — no host
329 /// pipeline drain in the inner loop.
330 ///
331 /// `norm_topk_prob`: if true, divide each row's K weights by their
332 /// sum so they total 1.0 (Qwen3-MoE / Mixtral default).
333 #[allow(clippy::too_many_arguments)]
334 fn route_topk_softmax(
335 _ctx: &mut Self::Context,
336 _logits: &Self::Buffer,
337 _out_ids: &mut Self::Buffer,
338 _out_weights: &mut Self::Buffer,
339 _batch: usize,
340 _num_experts: usize,
341 _top_k: usize,
342 _norm_topk_prob: bool,
343 ) -> Result<()> {
344 Err(FerrumError::unsupported(
345 "route_topk_softmax not implemented for this backend",
346 ))
347 }
348 /// GPU-side fast-path for the host route() leg of the bucketed
349 /// MoE forward (`moe_forward_bucketed` in ferrum-models). Replaces
350 /// the `B::sync(ctx) + B::to_vec(logits) + crate::moe::router::
351 /// route_into(...)` triple with a single GPU kernel + small D2H of
352 /// `[batch, top_k]` ids + weights.
353 ///
354 /// The backend allocates / reuses its own device-side scratch for
355 /// the kernel output; the caller only provides the host destination
356 /// vectors (resized + overwritten on each call). Default impl
357 /// returns `Err(unsupported)` so non-CUDA callers stay on the host
358 /// route_into() path with no behavior change.
359 #[allow(clippy::too_many_arguments)]
360 fn try_gpu_route_topk_into_host(
361 _ctx: &mut Self::Context,
362 _logits_dev: &Self::Buffer,
363 _out_ids_host: &mut Vec<u32>,
364 _out_weights_host: &mut Vec<f32>,
365 _batch: usize,
366 _num_experts: usize,
367 _top_k: usize,
368 _norm_topk_prob: bool,
369 ) -> Result<()> {
370 Err(FerrumError::unsupported(
371 "try_gpu_route_topk_into_host not implemented for this backend",
372 ))
373 }
374 /// GPU-side moe_align_block_size — prep for a future fused MoE
375 /// Marlin kernel. Takes per-pair expert assignments (from
376 /// [`Self::route_topk_softmax`]) and produces:
377 /// - `sorted_token_ids[N_padded]`: flat list of pair indices
378 /// in [0, batch * top_k), sorted by their assigned expert and
379 /// padded with sentinel `batch * top_k` inside each expert
380 /// group up to a `block_size` boundary.
381 /// - `block_ids[N_padded / block_size]`: which expert each
382 /// `block_size`-row tile of `sorted_token_ids` belongs to.
383 /// - `total_tokens_post_pad[1]`: actual padded token count.
384 ///
385 /// Layout matches vLLM's marlin_moe_wna16 kernel input
386 /// expectation. The fused Marlin kernel reads a row from
387 /// `a[sorted_token_ids[i] / top_k]` and weights from
388 /// Build `pairs_by_token` + `packed_token_idx` device-side from
389 /// device-side `expert_ids`. The counting-sort permutation that
390 /// lets `moe_combine` (and the gather step before phase 1 GEMM)
391 /// read routing output without a host round-trip — the prerequisite
392 /// for graph-capturing the MoE bucketed path.
393 ///
394 /// Inputs (device):
395 /// - `expert_ids: I32 [batch * top_k]` — top-K selected expert ids.
396 ///
397 /// Outputs (device):
398 /// - `pairs_by_token: I32 [batch * top_k]` — sorted-by-expert
399 /// position of each (b, k) pair (the row index into `packed_down`
400 /// that `moe_combine` reads).
401 /// - `packed_token_idx: I32 [batch * top_k]` — inverse map: for
402 /// each packed row, the original token b. Used by the gather
403 /// step (`embedding_lookup` of `x` into `x_packed` before phase 1).
404 /// - `expert_offsets: I32 [num_experts + 1]` — exclusive prefix
405 /// sum of tokens-per-expert; phase 1/3 dispatchers use it to
406 /// compute each expert's row slice in the packed buffers.
407 ///
408 /// Default impl returns Err — only CUDA implements this.
409 #[allow(clippy::too_many_arguments)]
410 fn moe_build_pairs_by_token(
411 _ctx: &mut Self::Context,
412 _expert_ids: &Self::Buffer,
413 _pairs_by_token: &mut Self::Buffer,
414 _packed_token_idx: &mut Self::Buffer,
415 _expert_offsets: &mut Self::Buffer,
416 _batch_x_topk: usize,
417 _num_experts: usize,
418 _top_k: usize,
419 ) -> Result<()> {
420 Err(FerrumError::unsupported(
421 "moe_build_pairs_by_token not implemented for this backend",
422 ))
423 }
424
425 /// `b[block_ids[blockIdx.y] * n_per_expert + ...]`.
426 ///
427 /// Default impl returns Err — only CUDA implements this.
428 #[allow(clippy::too_many_arguments)]
429 fn moe_align_block_size(
430 _ctx: &mut Self::Context,
431 _expert_ids_per_pair: &Self::Buffer,
432 _sorted_token_ids: &mut Self::Buffer,
433 _block_ids: &mut Self::Buffer,
434 _total_tokens_post_pad: &mut Self::Buffer,
435 _batch_x_topk: usize,
436 _num_experts: usize,
437 _block_size: usize,
438 _sorted_max_size: usize,
439 ) -> Result<()> {
440 Err(FerrumError::unsupported(
441 "moe_align_block_size not implemented for this backend",
442 ))
443 }
444
445 /// vLLM-native align variant: `sorted_token_ids` stores flattened
446 /// `(token, top_k_slot)` pair ids, not Ferrum's pre-gathered packed rows.
447 /// This lets marlin_moe read gate_up input as `A[pair_id / top_k]`.
448 #[allow(clippy::too_many_arguments)]
449 fn moe_align_block_size_pair_ids(
450 _ctx: &mut Self::Context,
451 _expert_ids_per_pair: &Self::Buffer,
452 _sorted_token_ids: &mut Self::Buffer,
453 _block_ids: &mut Self::Buffer,
454 _total_tokens_post_pad: &mut Self::Buffer,
455 _batch_x_topk: usize,
456 _num_experts: usize,
457 _block_size: usize,
458 _sorted_max_size: usize,
459 ) -> Result<()> {
460 Err(FerrumError::unsupported(
461 "moe_align_block_size_pair_ids not implemented for this backend",
462 ))
463 }
464
465 /// GPU-side bucket sort: turn `[batch, top_k]` selected expert IDs
466 /// (from [`Self::route_topk_softmax`]) into `tpe[num_experts]` /
467 /// `ids[num_experts * row_stride]` arrays consumed by the batched
468 /// MoE GEMM, and emit indirect-dispatch args for the consumer GEMM.
469 ///
470 /// The `ids` buffer's row stride is `batch * top_k` (worst case);
471 /// only the first `tpe[e]` entries of each row are populated. The
472 /// consumer GEMM kernel early-exits at `r1 >= tpe[e]`, so the over-
473 /// strided indices cost nothing in the inner loop. The grid size,
474 /// however, would still be worst-case unless we tighten it — this
475 /// is what the `gate_up_args` / `down_args` outputs do: a 12-byte
476 /// `(grid_x, grid_y, grid_z)` u32 triple per shape, ready for
477 /// `dispatch_thread_groups_indirect`. `grid_x` is shared (depends
478 /// only on `max(tpe[e])`); `grid_y` differs because gate/up has
479 /// `M = m_gate_up` while down has `M = m_down`.
480 ///
481 /// All five output buffers are written in one kernel; no host
482 /// roundtrip and no per-layer pipeline drain.
483 #[allow(clippy::too_many_arguments)]
484 fn compute_ids_tpe_gpu(
485 _ctx: &mut Self::Context,
486 _selected_ids: &Self::Buffer,
487 _tpe: &mut Self::Buffer,
488 _ids: &mut Self::Buffer,
489 _gate_up_args: &mut Self::Buffer,
490 _down_args: &mut Self::Buffer,
491 _batch: usize,
492 _num_experts: usize,
493 _top_k: usize,
494 _m_gate_up: usize,
495 _m_down: usize,
496 ) -> Result<()> {
497 Err(FerrumError::unsupported(
498 "compute_ids_tpe_gpu not implemented for this backend",
499 ))
500 }
501 /// Stacked SiLU·gate over `[batch * top_k, ffn]` rows (prefill version
502 /// of `silu_mul_stacked`).
503 fn silu_mul_batched(
504 _ctx: &mut Self::Context,
505 _gate: &Self::Buffer,
506 _up: &Self::Buffer,
507 _out: &mut Self::Buffer,
508 _total_pairs: usize,
509 _ffn: usize,
510 ) -> Result<()> {
511 Err(FerrumError::unsupported(
512 "silu_mul_batched not implemented for this backend",
513 ))
514 }
515 /// Fused weighted-sum + residual-add: `residual[i] += Σ_k weights[k] · slots[k, i]`.
516 /// Single dispatch replaces the (weighted_sum → moe_out) +
517 /// (add_inplace residual += moe_out) pair on the decode hot path.
518 fn weighted_sum_residual_stacked(
519 _ctx: &mut Self::Context,
520 _slots: &Self::Buffer,
521 _weights: &Self::Buffer,
522 _residual: &mut Self::Buffer,
523 _n_slots: usize,
524 _hidden: usize,
525 ) -> Result<()> {
526 Err(FerrumError::unsupported(
527 "weighted_sum_residual_stacked not implemented for this backend",
528 ))
529 }
530 /// Fused weighted-sum-residual + RMSNorm: combines this layer's
531 /// `weighted_sum_residual_stacked` with the next layer's leading
532 /// `rms_norm` into a single dispatch.
533 ///
534 /// Computes
535 /// `residual[i] += Σ_s w[s] · slots[s, i]`
536 /// `normed_out[i] = residual[i] · (1 / sqrt(Σ residual² / hidden + eps)) · next_norm_w[i]`
537 ///
538 /// Caller is responsible for skipping the next layer's standalone
539 /// `rms_norm` — `normed_out` IS that layer's `norm_out` input.
540 /// Default returns Unsupported.
541 #[allow(clippy::too_many_arguments)]
542 fn weighted_sum_residual_norm_stacked(
543 _ctx: &mut Self::Context,
544 _slots: &Self::Buffer,
545 _weights: &Self::Buffer,
546 _residual: &mut Self::Buffer,
547 _next_norm_w: &Self::Buffer,
548 _normed_out: &mut Self::Buffer,
549 _n_slots: usize,
550 _hidden: usize,
551 _eps: f32,
552 ) -> Result<()> {
553 Err(FerrumError::unsupported(
554 "weighted_sum_residual_norm_stacked not implemented for this backend",
555 ))
556 }
557 /// Per-batch weighted sum: `out[b, h] = Σ_k weights[b, k] · slots[b, k, h]`.
558 /// Single dispatch covers the whole batch (prefill version of
559 /// `weighted_sum_stacked` which only handled one token).
560 fn weighted_sum_batched(
561 _ctx: &mut Self::Context,
562 _slots: &Self::Buffer,
563 _weights: &Self::Buffer,
564 _out: &mut Self::Buffer,
565 _batch: usize,
566 _top_k: usize,
567 _hidden: usize,
568 ) -> Result<()> {
569 Err(FerrumError::unsupported(
570 "weighted_sum_batched not implemented for this backend",
571 ))
572 }
573 /// Offset-aware variant of [`Self::weighted_sum_batched`] —
574 /// `weights` reads from `weights_offset` (in elements, points at
575 /// the start of `[batch, top_k]`), `out` writes from `out_offset`
576 /// (in elements, points at start of `[batch, hidden]`). Used by
577 /// the per-item batched-decode path to skip `copy_slice` round-trips.
578 /// Default falls back to the non-offset variant via two copies.
579 #[allow(clippy::too_many_arguments)]
580 fn weighted_sum_batched_offset(
581 ctx: &mut Self::Context,
582 slots: &Self::Buffer,
583 weights: &Self::Buffer,
584 weights_offset: usize,
585 out: &mut Self::Buffer,
586 out_offset: usize,
587 batch: usize,
588 top_k: usize,
589 hidden: usize,
590 ) -> Result<()> {
591 // Default: stage through scratch — backends override for zero-copy.
592 let _ = (
593 ctx,
594 slots,
595 weights,
596 weights_offset,
597 out,
598 out_offset,
599 batch,
600 top_k,
601 hidden,
602 );
603 Err(FerrumError::unsupported(
604 "weighted_sum_batched_offset not implemented for this backend",
605 ))
606 }
607 /// Stacked SiLU·gate over `[n_slots, ffn]` rows.
608 ///
609 /// Computes `out[s, i] = silu(gate[s, i]) * up[s, i]` for each slot
610 /// `s`, element `i`. Single dispatch covers all slots — cuts the
611 /// MoE decode silu staging from `top_k * (3 copy_slice + 1 silu)`
612 /// = 32 dispatches per layer to 1.
613 fn silu_mul_stacked(
614 _ctx: &mut Self::Context,
615 _gate: &Self::Buffer,
616 _up: &Self::Buffer,
617 _out: &mut Self::Buffer,
618 _n_slots: usize,
619 _ffn: usize,
620 ) -> Result<()> {
621 Err(FerrumError::unsupported(
622 "silu_mul_stacked not implemented for this backend",
623 ))
624 }
625 /// Capability probe for [`Self::gemv_quant_moe_id_gate_up_silu`].
626 ///
627 /// `true` ⇒ the fused kernel is wired in and the caller should
628 /// prefer it on the MoE decode hot path. `false` ⇒ caller must use
629 /// the 3-dispatch fallback (gate gemv + up gemv + silu_mul_stacked).
630 /// Lets callers branch without paying the cost of an `Err(Unsupported)`
631 /// allocation per (layer, step).
632 fn supports_fused_moe_gate_up_silu() -> bool {
633 false
634 }
635 /// Capability probe for [`Self::gemv_quant_moe_id_batched`].
636 fn supports_batched_moe_gemv() -> bool {
637 false
638 }
639 /// Capability probe for [`Self::gemv_quant_moe_id_gate_up_silu_batched`].
640 fn supports_batched_moe_gate_up_silu() -> bool {
641 false
642 }
643 /// Weighted sum across `n_slots` rows of `[hidden]`.
644 ///
645 /// Computes `out[i] = Σ_s weights[s] * slots[s, i]`. Single
646 /// dispatch replaces the per-slot `(copy_slice + scaled_add)`
647 /// loop in the MoE decode path (16 dispatches per layer → 1).
648 fn weighted_sum_stacked(
649 _ctx: &mut Self::Context,
650 _slots: &Self::Buffer,
651 _weights: &Self::Buffer,
652 _out: &mut Self::Buffer,
653 _n_slots: usize,
654 _hidden: usize,
655 ) -> Result<()> {
656 Err(FerrumError::unsupported(
657 "weighted_sum_stacked not implemented for this backend",
658 ))
659 }
660 /// MoE combine: per-token weighted sum across `top_k` expert outputs.
661 ///
662 /// After the bucketed dispatch, `packed_down` holds `[total_pairs,
663 /// hidden]` with one row per (token, k_slot) pair in expert-bucketed
664 /// order. `pairs_by_token[b * top_k + k]` is the inverse map: which
665 /// row of `packed_down` carries the (b, k_slot) contribution. A row
666 /// index of `-1` means "skip" (unused slot).
667 ///
668 /// Computes:
669 ///
670 /// ```text
671 /// out[b, h] = sum_k pair_weights[b * top_k + k] *
672 /// packed_down[pairs_by_token[b * top_k + k], h]
673 /// ```
674 ///
675 /// Default impl round-trips via host memory — correct but slow.
676 /// CUDA backend launches a single fused kernel.
677 ///
678 /// Phase D follow-up: `pairs_by_token` (I32) and `pair_weights` (F32)
679 /// are now device buffers so callers can build them on-device for
680 /// graph capture (was `&[i32]` / `&[f32]` host slices with internal
681 /// clone_htod, which records stale host pointers under CUDA Graph
682 /// capture replay).
683 #[allow(clippy::too_many_arguments)]
684 fn moe_combine(
685 ctx: &mut Self::Context,
686 packed_down: &Self::Buffer,
687 pairs_by_token: &Self::Buffer,
688 pair_weights: &Self::Buffer,
689 out: &mut Self::Buffer,
690 batch: usize,
691 hidden: usize,
692 top_k: usize,
693 total_pairs: usize,
694 ) {
695 // Reference default: D2H pairs/weights, run the host loop, H2D out.
696 // CUDA backend overrides with a single device kernel.
697 let _ = ctx;
698 let packed = Self::to_vec(packed_down, total_pairs * hidden);
699 let pairs_host_f32 = Self::to_vec(pairs_by_token, batch * top_k);
700 let weights_host = Self::to_vec(pair_weights, batch * top_k);
701 let mut out_h = vec![0.0f32; batch * hidden];
702 for b in 0..batch {
703 for k in 0..top_k {
704 // `to_vec` returns f32; the device-side I32 buffer is
705 // bit-cast to f32 by the trait's f16-default to_vec path,
706 // so we re-extract via raw transmute. Backends override
707 // this default with a typed kernel that doesn't go
708 // through f16; on the default path callers are CPU
709 // parity tests where the byte pattern is preserved.
710 let pair_row = pairs_host_f32[b * top_k + k].to_bits() as i32;
711 if pair_row < 0 {
712 continue;
713 }
714 let w = weights_host[b * top_k + k];
715 let src = &packed[(pair_row as usize) * hidden..(pair_row as usize + 1) * hidden];
716 let dst = &mut out_h[b * hidden..(b + 1) * hidden];
717 for h in 0..hidden {
718 dst[h] += w * src[h];
719 }
720 }
721 }
722 *out = Self::from_slice(&out_h);
723 }
724}