Skip to main content

MarlinExpertStack

Trait MarlinExpertStack 

Source
pub trait MarlinExpertStack<B: Backend>: Send + Sync {
    // Required methods
    fn n_per_expert(&self) -> usize;
    fn k(&self) -> usize;
    fn num_experts(&self) -> usize;
    fn as_any(&self) -> &dyn Any;
    fn zero_workspace(&self, ctx: &mut B::Context) -> Result<()>;
    fn gemm_phase_batched(
        &self,
        ctx: &mut B::Context,
        input: &B::Buffer,
        dispatches: &[(usize, usize, usize, usize)],
        output: &mut B::Buffer,
        k: usize,
    ) -> Result<()>;
    fn make_expert_linear(
        self: Arc<Self>,
        expert_offset: usize,
        expert_n: usize,
        bias_host: Option<&[f32]>,
    ) -> Result<Box<dyn Linear<B> + Send + Sync>>;

    // Provided method
    fn gemm_phase_vllm(
        &self,
        _ctx: &mut B::Context,
        _input: &B::Buffer,
        _sorted_token_ids: &B::Buffer,
        _expert_ids: &B::Buffer,
        _num_tokens_past_padded: &B::Buffer,
        _output: &mut B::Buffer,
        _prob_m: usize,
        _moe_block_size: usize,
        _top_k: usize,
    ) -> Result<()> { ... }
}
Expand description

MoE-stacked Marlin INT4 expert tile: holds N experts’ weights for one matmul role (gate_up / down) in one contiguous repacked Marlin buffer, dispatches per-expert column-slice GEMMs in a single fused launch (vLLM marlin_moe_wna16) or as a bucketed batched call.

Required Methods§

Source

fn n_per_expert(&self) -> usize

Per-expert output width (N tile cols).

Source

fn k(&self) -> usize

Input width (K), common across experts.

Source

fn num_experts(&self) -> usize

Number of experts packed into the tile.

Source

fn as_any(&self) -> &dyn Any

Downcast hook — used at FFN dispatch boundaries where the caller needs to reach into the concrete store to e.g. share workspace memory across phases. Standard dyn Any pattern.

Source

fn zero_workspace(&self, ctx: &mut B::Context) -> Result<()>

Bulk-zero the per-expert Marlin workspace mutex slots. Call ONCE before a batch of bucketed gemm_phase_batched calls — saves the per-call cuMemsetD32Async (one launch each → one launch total). At c=32 with 128 active experts × 2 phases × 48 layers that’s ~12k memset launches/token reduced to ~96.

Source

fn gemm_phase_batched( &self, ctx: &mut B::Context, input: &B::Buffer, dispatches: &[(usize, usize, usize, usize)], output: &mut B::Buffer, k: usize, ) -> Result<()>

Batched per-expert offset GEMM. dispatches[i] = (expert_idx, in_row_offset, out_row_offset, m). Runs each expert’s (m × K) @ tile[expert] = m × n_per_expert slice; CUDA backend overlaps via multi-stream round-robin.

Source

fn make_expert_linear( self: Arc<Self>, expert_offset: usize, expert_n: usize, bias_host: Option<&[f32]>, ) -> Result<Box<dyn Linear<B> + Send + Sync>>

Build a single-expert Linear<B> view onto this stack’s [expert_offset .. expert_offset + expert_n) column slice. Used for per-expert dispatch outside the MoE phase batching (e.g. shared-experts code paths). expert_offset and expert_n MUST be multiples of the backend’s Marlin N tile (64 on CUDA).

Provided Methods§

Source

fn gemm_phase_vllm( &self, _ctx: &mut B::Context, _input: &B::Buffer, _sorted_token_ids: &B::Buffer, _expert_ids: &B::Buffer, _num_tokens_past_padded: &B::Buffer, _output: &mut B::Buffer, _prob_m: usize, _moe_block_size: usize, _top_k: usize, ) -> Result<()>

vLLM marlin_moe_wna16 fused GEMM (single launch, per-block expert routing inside the kernel). Caller responsibilities:

  • output MUST be pre-zeroed (atomic-add path doesn’t self-zero).
  • sorted_token_ids / expert_ids / num_tokens_past_padded come from moe_align_block_size.
  • prob_m is the unique-token count (top_k=1 with pre-gathered rows ⇒ equals total_pairs). Backends without vLLM Marlin return Err(unsupported).

Dyn Compatibility§

This trait is dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety".

Implementors§