ferrum_kernels/marlin_expert_stack.rs
//! `MarlinExpertStack<B>` — abstraction for "N MoE experts' Marlin
//! GPTQ-INT4 tiles stored contiguously, dispatched as bucketed batched
//! GEMM or vLLM fused MoE kernel".
//!
//! Phase C sibling to `StackedExpertGgufLinear<B>` (GGUF) and `Linear<B>`
//! (single-tensor). Same goal: drop `type GptqStore` from the `Backend`
//! trait by routing dispatch through a `Box<dyn MarlinExpertStack<B>>`
//! returned by the loader — so future backends only need to implement
//! this trait, not edit the `Backend` supertrait stack.
//!
//! Concrete impls (added in Phase C step 2):
//! - `quant_linear::cuda_marlin_stack::CudaMarlinExpertStack` wraps
//!   `Arc<GptqStoreCuda>` and dispatches to `marlin_gemm_with_offset_strided`
//!   (bucketed) or `marlin_moe_wna16` (vLLM fused).
//! - CPU dequant path stays per-Linear (no batched MoE Marlin kernel).
//!
//! The trait surface is intentionally small — two GEMM methods
//! (`gemm_phase_batched`, `gemm_phase_vllm`) + a workspace zero + an
//! expert-view constructor. Each maps 1:1 to an existing
//! `Backend::moe_gemm_phase_*` method that Phase C step 3 will delete
//! from the trait.
21
22use crate::backend::Backend;
23use crate::Linear;
24use ferrum_types::Result;
25use std::sync::Arc;
26
27/// MoE-stacked Marlin INT4 expert tile: holds N experts' weights for one
28/// matmul role (gate_up / down) in one contiguous repacked Marlin buffer,
29/// dispatches per-expert column-slice GEMMs in a single fused launch
30/// (vLLM marlin_moe_wna16) or as a bucketed batched call.
31pub trait MarlinExpertStack<B: Backend>: Send + Sync {
32 /// Per-expert output width (N tile cols).
33 fn n_per_expert(&self) -> usize;
34 /// Input width (K), common across experts.
35 fn k(&self) -> usize;
36 /// Number of experts packed into the tile.
37 fn num_experts(&self) -> usize;
38
39 /// Downcast hook — used at FFN dispatch boundaries where the
40 /// caller needs to reach into the concrete store to e.g. share
41 /// workspace memory across phases. Standard `dyn Any` pattern.
42 fn as_any(&self) -> &dyn std::any::Any;
43
44 /// Bulk-zero the per-expert Marlin workspace mutex slots. Call ONCE
45 /// before a batch of bucketed `gemm_phase_batched` calls — saves
46 /// the per-call cuMemsetD32Async (one launch each → one launch
47 /// total). At c=32 with 128 active experts × 2 phases × 48 layers
48 /// that's ~12k memset launches/token reduced to ~96.
49 fn zero_workspace(&self, ctx: &mut B::Context) -> Result<()>;
50
51 /// Batched per-expert offset GEMM. `dispatches[i] =
52 /// (expert_idx, in_row_offset, out_row_offset, m)`. Runs each
53 /// expert's `(m × K) @ tile[expert] = m × n_per_expert` slice;
54 /// CUDA backend overlaps via multi-stream round-robin.
55 #[allow(clippy::too_many_arguments)]
56 fn gemm_phase_batched(
57 &self,
58 ctx: &mut B::Context,
59 input: &B::Buffer,
60 dispatches: &[(usize, usize, usize, usize)],
61 output: &mut B::Buffer,
62 k: usize,
63 ) -> Result<()>;
64
65 /// vLLM `marlin_moe_wna16` fused GEMM (single launch, per-block
66 /// expert routing inside the kernel). Caller responsibilities:
67 /// - `output` MUST be pre-zeroed (atomic-add path doesn't self-zero).
68 /// - `sorted_token_ids` / `expert_ids` / `num_tokens_past_padded`
69 /// come from `moe_align_block_size`.
70 /// - `prob_m` is the unique-token count (top_k=1 with pre-gathered
71 /// rows ⇒ equals `total_pairs`).
72 /// Backends without vLLM Marlin return `Err(unsupported)`.
73 #[allow(clippy::too_many_arguments)]
74 fn gemm_phase_vllm(
75 &self,
76 _ctx: &mut B::Context,
77 _input: &B::Buffer,
78 _sorted_token_ids: &B::Buffer,
79 _expert_ids: &B::Buffer,
80 _num_tokens_past_padded: &B::Buffer,
81 _output: &mut B::Buffer,
82 _prob_m: usize,
83 _moe_block_size: usize,
84 _top_k: usize,
85 ) -> Result<()> {
86 Err(ferrum_types::FerrumError::unsupported(
87 "MarlinExpertStack::gemm_phase_vllm not implemented for this backend",
88 ))
89 }
90
91 /// Build a single-expert `Linear<B>` view onto this stack's
92 /// `[expert_offset .. expert_offset + expert_n)` column slice.
93 /// Used for per-expert dispatch outside the MoE phase batching
94 /// (e.g. shared-experts code paths). `expert_offset` and `expert_n`
95 /// MUST be multiples of the backend's Marlin N tile (64 on CUDA).
96 fn make_expert_linear(
97 self: Arc<Self>,
98 expert_offset: usize,
99 expert_n: usize,
100 bias_host: Option<&[f32]>,
101 ) -> Result<Box<dyn Linear<B> + Send + Sync>>;
102}