pub trait MarlinExpertStack<B: Backend>: Send + Sync {
    // Required methods
    fn n_per_expert(&self) -> usize;
    fn k(&self) -> usize;
    fn num_experts(&self) -> usize;
    fn as_any(&self) -> &dyn Any;
    fn zero_workspace(&self, ctx: &mut B::Context) -> Result<()>;
    fn gemm_phase_batched(
        &self,
        ctx: &mut B::Context,
        input: &B::Buffer,
        dispatches: &[(usize, usize, usize, usize)],
        output: &mut B::Buffer,
        k: usize,
    ) -> Result<()>;
    fn make_expert_linear(
        self: Arc<Self>,
        expert_offset: usize,
        expert_n: usize,
        bias_host: Option<&[f32]>,
    ) -> Result<Box<dyn Linear<B> + Send + Sync>>;

    // Provided method
    fn gemm_phase_vllm(
        &self,
        _ctx: &mut B::Context,
        _input: &B::Buffer,
        _sorted_token_ids: &B::Buffer,
        _expert_ids: &B::Buffer,
        _num_tokens_past_padded: &B::Buffer,
        _output: &mut B::Buffer,
        _prob_m: usize,
        _moe_block_size: usize,
        _top_k: usize,
    ) -> Result<()> { ... }
}
MoE-stacked Marlin INT4 expert tile: holds the weights of num_experts experts for one matmul role (gate_up or down) in one contiguous, repacked Marlin buffer, and dispatches per-expert column-slice GEMMs either in a single fused launch (vLLM marlin_moe_wna16) or as a bucketed batched call.
Required Methods
fn n_per_expert(&self) -> usize
Per-expert output width: the number of N columns each expert occupies in the tile.
fn num_experts(&self) -> usize
Number of experts packed into the tile.
fn as_any(&self) -> &dyn Any
Downcast hook, used at FFN dispatch boundaries where the
caller needs to reach into the concrete store, e.g. to share
workspace memory across phases. Standard dyn Any pattern.
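A minimal sketch of the "standard dyn Any pattern" this hook enables, using only std::any::Any. The trait ExpertStore and the struct CudaMarlinStore (with its workspace_bytes field) are hypothetical stand-ins; only the as_any signature mirrors the documented method.

```rust
use std::any::Any;

/// Hypothetical stand-in trait: only `as_any` mirrors the documented hook.
pub trait ExpertStore {
    fn as_any(&self) -> &dyn Any;
}

/// Hypothetical concrete store owning a workspace that phases could share.
pub struct CudaMarlinStore {
    pub workspace_bytes: usize,
}

impl ExpertStore for CudaMarlinStore {
    fn as_any(&self) -> &dyn Any {
        // `&CudaMarlinStore` coerces to `&dyn Any` because the type is 'static.
        self
    }
}

/// Reach through a trait object to the concrete store, if it is one.
/// Returns None when the object is some other implementation.
pub fn shared_workspace_bytes(store: &dyn ExpertStore) -> Option<usize> {
    store
        .as_any()
        .downcast_ref::<CudaMarlinStore>()
        .map(|s| s.workspace_bytes)
}
```

The downcast fails gracefully (None) for other implementations, so callers can fall back to a per-phase workspace when the concrete type does not match.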
fn zero_workspace(&self, ctx: &mut B::Context) -> Result<()>
Bulk-zero the per-expert Marlin workspace mutex slots. Call ONCE
before a batch of bucketed gemm_phase_batched calls: this replaces
the per-call cuMemsetD32Async (one launch per call) with a single
launch for the whole batch. At c=32 with 128 active experts × 2 phases
× 48 layers, that reduces ~12k memset launches per token
(128 × 2 × 48 = 12,288) to ~96 (2 × 48).
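The launch-count figures can be checked with a few lines of arithmetic. This sketch assumes one memset per bucketed call (one call per active expert per phase per layer) without bulk zeroing, and one zero_workspace call per phase per layer with it; the function names are illustrative only.

```rust
/// Memset launches per token when every bucketed GEMM call zeroes its own
/// workspace slot: one launch per active expert, per phase, per layer.
pub fn per_call_memsets(active_experts: u64, phases: u64, layers: u64) -> u64 {
    active_experts * phases * layers
}

/// Memset launches per token when zero_workspace runs once before each
/// batch, assuming one batch per phase per layer.
pub fn bulk_memsets(phases: u64, layers: u64) -> u64 {
    phases * layers
}
```

With 128 active experts, 2 phases, and 48 layers this gives 12,288 versus 96 launches, matching the figures above.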
fn gemm_phase_batched(
    &self,
    ctx: &mut B::Context,
    input: &B::Buffer,
    dispatches: &[(usize, usize, usize, usize)],
    output: &mut B::Buffer,
    k: usize,
) -> Result<()>
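Batched per-expert offset GEMM. Each dispatches[i] = (expert_idx, in_row_offset, out_row_offset, m)
multiplies the m × K input rows starting at in_row_offset by tile[expert_idx] and
writes the resulting m × n_per_expert slice to the output rows starting at out_row_offset.
The CUDA backend overlaps the per-expert GEMMs by assigning them round-robin across multiple streams.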
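A CPU reference sketch of the dispatch semantics described above, useful for checking a backend against. It assumes dense, row-major f32 data (real Marlin tiles are repacked INT4), an input with k columns, an output with n = n_per_expert columns, and a stacked weight buffer of num_experts consecutive k × n blocks. The function name gemm_phase_batched_ref is hypothetical.

```rust
/// Hypothetical CPU reference for gemm_phase_batched semantics.
/// Layouts (all row-major f32, an assumption of this sketch):
///   input:   rows × k
///   weights: num_experts consecutive blocks, each k × n
///   output:  rows × n
pub fn gemm_phase_batched_ref(
    input: &[f32],
    weights: &[f32],
    dispatches: &[(usize, usize, usize, usize)],
    output: &mut [f32],
    k: usize,
    n: usize,
) {
    for &(expert_idx, in_row, out_row, m) in dispatches {
        // This expert's k × n block inside the stacked weight buffer.
        let w = &weights[expert_idx * k * n..(expert_idx + 1) * k * n];
        for r in 0..m {
            let x = &input[(in_row + r) * k..(in_row + r + 1) * k];
            let y = &mut output[(out_row + r) * n..(out_row + r + 1) * n];
            for (c, out) in y.iter_mut().enumerate() {
                // Dot product of the input row with column c of the block.
                *out = (0..k).map(|i| x[i] * w[i * n + c]).sum();
            }
        }
    }
}
```

In this sketch each dispatch writes only its own output rows, so when those row ranges are disjoint the dispatch order does not affect the result; that independence is what makes overlapping them across streams safe.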
fn make_expert_linear(
    self: Arc<Self>,
    expert_offset: usize,
    expert_n: usize,
    bias_host: Option<&[f32]>,
) -> Result<Box<dyn Linear<B> + Send + Sync>>
Build a single-expert Linear<B> view onto this stack’s
[expert_offset .. expert_offset + expert_n) column slice.
Used for per-expert dispatch outside the MoE phase batching
(e.g. shared-experts code paths). expert_offset and expert_n
MUST be multiples of the backend’s Marlin N tile (64 on CUDA).
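Because misaligned slices are a caller error, a pre-flight check before calling make_expert_linear can catch them early. This is a hypothetical helper, not part of the trait; the total_n parameter (the stack's total column count) and the bounds check are assumptions of this sketch, while the multiple-of-64 rule comes from the documentation above.

```rust
/// Marlin N tile width on CUDA, as documented for make_expert_linear.
pub const CUDA_MARLIN_N_TILE: usize = 64;

/// Hypothetical validation of make_expert_linear's slice arguments.
/// `total_n` (columns in the whole stack) is an assumed extra input.
pub fn check_expert_slice(
    expert_offset: usize,
    expert_n: usize,
    n_tile: usize,
    total_n: usize,
) -> Result<(), String> {
    // Both the start and the width must land on N-tile boundaries.
    if expert_offset % n_tile != 0 || expert_n % n_tile != 0 {
        return Err(format!(
            "offset {expert_offset} and width {expert_n} must be multiples of {n_tile}"
        ));
    }
    // The half-open slice [offset, offset + n) must fit inside the stack.
    if expert_offset + expert_n > total_n {
        return Err(format!(
            "slice [{expert_offset}, {}) exceeds {total_n} columns",
            expert_offset + expert_n
        ));
    }
    Ok(())
}
```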
Provided Methods
fn gemm_phase_vllm(
    &self,
    _ctx: &mut B::Context,
    _input: &B::Buffer,
    _sorted_token_ids: &B::Buffer,
    _expert_ids: &B::Buffer,
    _num_tokens_past_padded: &B::Buffer,
    _output: &mut B::Buffer,
    _prob_m: usize,
    _moe_block_size: usize,
    _top_k: usize,
) -> Result<()>
vLLM marlin_moe_wna16 fused GEMM (single launch, per-block
expert routing inside the kernel). Caller responsibilities:
- output MUST be pre-zeroed (the atomic-add path doesn’t self-zero).
- sorted_token_ids / expert_ids / num_tokens_past_padded come from moe_align_block_size.
- prob_m is the unique-token count (top_k=1 with pre-gathered rows ⇒ equals total_pairs).
Backends without vLLM Marlin return Err(unsupported).
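To make the routing inputs concrete, here is a CPU sketch of the metadata that vLLM's moe_align_block_size produces, written from our understanding of its semantics rather than from this crate. Flattened (token, slot) indices are grouped by expert, each expert's group is padded up to a multiple of the block size with the out-of-range sentinel numel, and one expert id is emitted per block. The third return value plays the role of the trait's num_tokens_past_padded. The function name is hypothetical, and vLLM's own kernel may order ids within an expert differently.

```rust
/// Hypothetical CPU sketch of moe_align_block_size-style routing metadata.
/// `topk_ids` holds the flattened (num_tokens * top_k) expert choices.
/// Returns (sorted_token_ids, expert_ids, num_tokens_post_padded).
pub fn moe_align_block_size_ref(
    topk_ids: &[usize],
    num_experts: usize,
    block_size: usize,
) -> (Vec<usize>, Vec<usize>, usize) {
    let numel = topk_ids.len();
    // Group flattened indices by the expert they route to (ascending order here).
    let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); num_experts];
    for (i, &e) in topk_ids.iter().enumerate() {
        buckets[e].push(i);
    }
    let mut sorted_token_ids = Vec::new();
    let mut expert_ids = Vec::new();
    for (e, bucket) in buckets.iter().enumerate() {
        // Round each expert's count up to whole blocks (0 stays 0).
        let padded = (bucket.len() + block_size - 1) / block_size * block_size;
        sorted_token_ids.extend_from_slice(bucket);
        // Padding slots hold `numel`, an index the kernel treats as empty.
        sorted_token_ids.extend(std::iter::repeat(numel).take(padded - bucket.len()));
        // Every block of this expert's segment is routed to expert `e`.
        expert_ids.extend(std::iter::repeat(e).take(padded / block_size));
    }
    let num_tokens_post_padded = sorted_token_ids.len();
    (sorted_token_ids, expert_ids, num_tokens_post_padded)
}
```

Each sorted id i refers to token i / top_k, which is why prob_m counts unique tokens; with top_k = 1 the flattened and token indices coincide.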
Dyn Compatibility
This trait is dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".