Skip to main content

ferrum_kernels/
marlin_expert_stack.rs

//! `MarlinExpertStack<B>` — abstraction for "N MoE experts' Marlin
//! GPTQ-INT4 tiles stored contiguously, dispatched as bucketed batched
//! GEMM or vLLM fused MoE kernel".
//!
//! Phase C sibling to `StackedExpertGgufLinear<B>` (GGUF) and `Linear<B>`
//! (single-tensor). Same goal: drop `type GptqStore` from the `Backend`
//! trait by routing dispatch through a `Box<dyn MarlinExpertStack<B>>`
//! returned by the loader — so future backends only need to implement
//! this trait, not edit the `Backend` supertrait stack.
//!
//! Concrete impls (added in Phase C step 2):
//!   - `quant_linear::cuda_marlin_stack::CudaMarlinExpertStack` wraps
//!     `Arc<GptqStoreCuda>` and dispatches to `marlin_gemm_with_offset_strided`
//!     (bucketed) or `marlin_moe_wna16` (vLLM fused).
//!   - CPU dequant path stays per-Linear (no batched MoE Marlin kernel).
//!
//! The trait surface is intentionally small — two GEMM methods + a
//! workspace zero + an expert-view constructor. Each maps 1:1 to an
//! existing `Backend::moe_gemm_phase_*` method that Phase C step 3
//! will delete from the trait.
21
22use crate::backend::Backend;
23use crate::Linear;
24use ferrum_types::Result;
25use std::sync::Arc;
26
27/// MoE-stacked Marlin INT4 expert tile: holds N experts' weights for one
28/// matmul role (gate_up / down) in one contiguous repacked Marlin buffer,
29/// dispatches per-expert column-slice GEMMs in a single fused launch
30/// (vLLM marlin_moe_wna16) or as a bucketed batched call.
31pub trait MarlinExpertStack<B: Backend>: Send + Sync {
32    /// Per-expert output width (N tile cols).
33    fn n_per_expert(&self) -> usize;
34    /// Input width (K), common across experts.
35    fn k(&self) -> usize;
36    /// Number of experts packed into the tile.
37    fn num_experts(&self) -> usize;
38
39    /// Downcast hook — used at FFN dispatch boundaries where the
40    /// caller needs to reach into the concrete store to e.g. share
41    /// workspace memory across phases. Standard `dyn Any` pattern.
42    fn as_any(&self) -> &dyn std::any::Any;
43
44    /// Bulk-zero the per-expert Marlin workspace mutex slots. Call ONCE
45    /// before a batch of bucketed `gemm_phase_batched` calls — saves
46    /// the per-call cuMemsetD32Async (one launch each → one launch
47    /// total). At c=32 with 128 active experts × 2 phases × 48 layers
48    /// that's ~12k memset launches/token reduced to ~96.
49    fn zero_workspace(&self, ctx: &mut B::Context) -> Result<()>;
50
51    /// Batched per-expert offset GEMM. `dispatches[i] =
52    /// (expert_idx, in_row_offset, out_row_offset, m)`. Runs each
53    /// expert's `(m × K) @ tile[expert] = m × n_per_expert` slice;
54    /// CUDA backend overlaps via multi-stream round-robin.
55    #[allow(clippy::too_many_arguments)]
56    fn gemm_phase_batched(
57        &self,
58        ctx: &mut B::Context,
59        input: &B::Buffer,
60        dispatches: &[(usize, usize, usize, usize)],
61        output: &mut B::Buffer,
62        k: usize,
63    ) -> Result<()>;
64
65    /// vLLM `marlin_moe_wna16` fused GEMM (single launch, per-block
66    /// expert routing inside the kernel). Caller responsibilities:
67    /// - `output` MUST be pre-zeroed (atomic-add path doesn't self-zero).
68    /// - `sorted_token_ids` / `expert_ids` / `num_tokens_past_padded`
69    ///   come from `moe_align_block_size`.
70    /// - `prob_m` is the unique-token count (top_k=1 with pre-gathered
71    ///   rows ⇒ equals `total_pairs`).
72    /// Backends without vLLM Marlin return `Err(unsupported)`.
73    #[allow(clippy::too_many_arguments)]
74    fn gemm_phase_vllm(
75        &self,
76        _ctx: &mut B::Context,
77        _input: &B::Buffer,
78        _sorted_token_ids: &B::Buffer,
79        _expert_ids: &B::Buffer,
80        _num_tokens_past_padded: &B::Buffer,
81        _output: &mut B::Buffer,
82        _prob_m: usize,
83        _moe_block_size: usize,
84        _top_k: usize,
85    ) -> Result<()> {
86        Err(ferrum_types::FerrumError::unsupported(
87            "MarlinExpertStack::gemm_phase_vllm not implemented for this backend",
88        ))
89    }
90
91    /// Build a single-expert `Linear<B>` view onto this stack's
92    /// `[expert_offset .. expert_offset + expert_n)` column slice.
93    /// Used for per-expert dispatch outside the MoE phase batching
94    /// (e.g. shared-experts code paths). `expert_offset` and `expert_n`
95    /// MUST be multiples of the backend's Marlin N tile (64 on CUDA).
96    fn make_expert_linear(
97        self: Arc<Self>,
98        expert_offset: usize,
99        expert_n: usize,
100        bias_host: Option<&[f32]>,
101    ) -> Result<Box<dyn Linear<B> + Send + Sync>>;
102}