pub struct ExpertStack<B: QuantLlmBackend + BackendMoeFused> {
    pub gate_up: Vec<Box<dyn Linear<B>>>,
    pub down: Vec<Box<dyn Linear<B>>>,
    pub gate_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
    pub up_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
    pub down_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
    pub gate_up_marlin_stack: Option<Arc<dyn MarlinExpertStack<B>>>,
    pub down_marlin_stack: Option<Arc<dyn MarlinExpertStack<B>>>,
}
Per-layer expert weights, materialised as [num_experts]-long vectors
of Box<dyn Linear<B>>. Each entry runs the corresponding expert’s
fused [gate; up] projection or its down projection.
B::Buffer is hidden behind Linear<B>, so this struct stays generic
over the backend. In production, Qwen3MoeModel::forward dispatches through
the generic moe_forward<B> (defined in the same source file, near line 960)
and moe_forward_bucketed<B>; the CPU-only moe_forward_cpu is the
reference path used by the parity tests and by Qwen3MoeLayer::forward_cpu.
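The per-expert layout can be illustrated with a CPU-only sketch in the spirit of the reference path. It assumes a simplified, hypothetical `Linear` trait that maps plain `f32` slices (the real `Linear<B>` hides `B::Buffer` and takes a backend context), a hypothetical `DenseLinear`, and the SwiGLU activation `silu(gate) * up`; routing (top-k ids and weights) is taken as given.

```rust
/// Hypothetical stand-in for `Linear<B>`: maps one input row to one output row.
pub trait Linear {
    fn forward(&self, x: &[f32]) -> Vec<f32>;
}

/// Row-major dense weight of shape [out_dim, in_dim].
pub struct DenseLinear {
    pub w: Vec<f32>,
    pub out_dim: usize,
    pub in_dim: usize,
}

impl Linear for DenseLinear {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        assert_eq!(x.len(), self.in_dim);
        (0..self.out_dim)
            .map(|r| {
                let row = &self.w[r * self.in_dim..(r + 1) * self.in_dim];
                row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>()
            })
            .collect()
    }
}

fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

/// Per-layer experts as trait objects, mirroring the `gate_up` / `down` fields.
pub struct Experts {
    pub gate_up: Vec<Box<dyn Linear>>,
    pub down: Vec<Box<dyn Linear>>,
}

impl Experts {
    /// MoE output for one token: sum over routed experts of
    /// weight * down(silu(gate) * up). Assumes SwiGLU activation.
    pub fn forward_token(&self, x: &[f32], routes: &[(usize, f32)], inter: usize) -> Vec<f32> {
        let mut y = vec![0.0f32; x.len()];
        for &(e, weight) in routes {
            let gu = self.gate_up[e].forward(x); // [2 * inter]: gate half, then up half
            let (gate, up) = gu.split_at(inter);
            let act: Vec<f32> = gate.iter().zip(up).map(|(g, u)| silu(*g) * u).collect();
            let out = self.down[e].forward(&act); // [hidden]
            for (acc, o) in y.iter_mut().zip(&out) {
                *acc += weight * o;
            }
        }
        y
    }
}
```

Dispatch goes through `Box<dyn Linear>` per (token, expert), so the combining loop never needs to know how each expert's weights are stored.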
Fields

gate_up: Vec<Box<dyn Linear<B>>>
Fused [gate; up] projection per expert. Output shape per token:
[2 * expert_intermediate]; the lower half (first expert_intermediate values) is gate, the upper half is up.
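A minimal sketch of how a caller might split one token's fused output row, assuming the layout is exactly as described (gate in the first `expert_intermediate` values, up in the rest); `split_gate_up` is a hypothetical helper, not part of this API.

```rust
/// Split one token's fused [2 * expert_intermediate] row into (gate, up).
/// Panics if the row length does not match the assumed layout.
pub fn split_gate_up(row: &[f32], expert_intermediate: usize) -> (&[f32], &[f32]) {
    assert_eq!(row.len(), 2 * expert_intermediate, "fused row must be [2 * inter]");
    row.split_at(expert_intermediate)
}
```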
down: Vec<Box<dyn Linear<B>>>
Down projection per expert. Output shape per token: [hidden_size].
gate_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>
Stacked-experts representation for backends that have a batched
MoE indirect-dispatch kernel (Metal gemv_q4kw_moe_id_f32 /
gemv_q6kw_moe_id_f32). It holds all experts for one matmul
role behind a StackedExpertGgufLinear<B> (typically backed by a
single GPU buffer with a byte stride between expert slabs), so a
single dispatch can cover all selected (token, expert) pairs at
decode (m = 1).
None on backends without the kernel (CPU, CUDA without an MoE kernel)
and on quant flavours that don't have a stacked path yet; in those
cases callers fall back to the per-expert gate_up / down Linears.
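To make "a single GPU buffer with a byte stride between expert slabs" concrete, the arithmetic below computes per-expert slab sizes for the two quant flavours named above. It assumes slabs are packed back to back with no padding and uses the GGUF K-quant super-block geometry (256 weights per block; 144 bytes for Q4_K, 210 bytes for Q6_K). `KQuant`, `expert_slab_bytes`, and `expert_slab_range` are hypothetical names.

```rust
use std::ops::Range;

/// GGUF K-quant super-block geometry (256 weights per block).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum KQuant {
    Q4K,
    Q6K,
}

impl KQuant {
    pub const WEIGHTS_PER_BLOCK: usize = 256;

    pub fn bytes_per_block(self) -> usize {
        match self {
            KQuant::Q4K => 144, // 2 x f16 scales + 12 B sub-scales + 128 B of 4-bit quants
            KQuant::Q6K => 210, // 128 B low bits + 64 B high bits + 16 B scales + f16 scale
        }
    }
}

/// Bytes for one expert's [rows, cols] weight; each row is cols / 256 blocks.
pub fn expert_slab_bytes(q: KQuant, rows: usize, cols: usize) -> usize {
    assert_eq!(cols % KQuant::WEIGHTS_PER_BLOCK, 0, "cols must be a multiple of 256");
    rows * (cols / KQuant::WEIGHTS_PER_BLOCK) * q.bytes_per_block()
}

/// Byte range of expert `e`, assuming tightly packed slabs (stride = slab size).
pub fn expert_slab_range(q: KQuant, rows: usize, cols: usize, e: usize) -> Range<usize> {
    let stride = expert_slab_bytes(q, rows, cols);
    e * stride..(e + 1) * stride
}
```

At 144 bytes per 256 weights, Q4_K stores 4.5 bits per weight, which is roughly where the ~7× fp32 inflation mentioned for the dense fallback comes from (32 / 4.5 ≈ 7.1).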
up_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>
down_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>
gate_up_marlin_stack: Option<Arc<dyn MarlinExpertStack<B>>>
Stacked Marlin GPTQ expert tiles for the bucketed CUDA path.
When both gate_up_marlin_stack and down_marlin_stack are Some,
moe_forward_bucketed dispatches the expert GEMMs through trait-object
methods (store.gemm_phase_* / store.zero_workspace). None on CPU and
Metal backends and for GGUF weights.
Phase C step 3: replaces Option<Arc<B::GptqStore>> with an
Arc<dyn MarlinExpertStack<B>> trait object, which stops the
GptqStore associated type from leaking through the model layer.
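The Option-based fallbacks described for the stacked GGUF and Marlin fields can be summarised as a selection function. This is a sketch under assumptions: `ExpertPath`, `Presence`, and `choose_path` are hypothetical, booleans stand in for the `Option` fields, the requirement that all three stacked roles be present is assumed, and the priority order is illustrative rather than taken from moe_forward / moe_forward_bucketed.

```rust
/// Which expert-matmul path a caller would take (hypothetical names).
#[derive(Debug, PartialEq)]
pub enum ExpertPath {
    MarlinBucketed,
    StackedGguf,
    PerExpert,
}

/// Stand-in for ExpertStack's Option fields; only presence matters here.
#[derive(Clone, Copy)]
pub struct Presence {
    pub gate_up_marlin: bool,
    pub down_marlin: bool,
    pub gate_stacked: bool,
    pub up_stacked: bool,
    pub down_stacked: bool,
}

pub fn choose_path(p: &Presence) -> ExpertPath {
    if p.gate_up_marlin && p.down_marlin {
        // Both Marlin stacks are Some: bucketed CUDA path.
        ExpertPath::MarlinBucketed
    } else if p.gate_stacked && p.up_stacked && p.down_stacked {
        // Assumption: every stacked role must be present to use the fast path.
        ExpertPath::StackedGguf
    } else {
        // Fallback: per-expert gate_up / down Linears.
        ExpertPath::PerExpert
    }
}
```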
down_marlin_stack: Option<Arc<dyn MarlinExpertStack<B>>>

Implementations
impl<B: QuantLlmBackend + BackendMoeFused> ExpertStack<B>
pub fn gate_up_stacked_store(
    &self,
    _expert_idx: usize,
) -> Option<&Arc<dyn MarlinExpertStack<B>>>
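Returns the shared stacked Marlin expert tiles for gate_up if the
layer was loaded via the bucketed/Marlin path, otherwise None. Used by
moe_forward_bucketed. _expert_idx is unused (hence the leading
underscore), since one shared stack covers every expert.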
pub fn down_stacked_store(
    &self,
    _expert_idx: usize,
) -> Option<&Arc<dyn MarlinExpertStack<B>>>
Same as gate_up_stacked_store, for the down projection.
pub fn gemv_gate(
    &self,
    ctx: &mut B::Context,
    input: &B::Buffer,
    ids: &B::Buffer,
    out: &mut B::Buffer,
    top_k: usize,
) -> Result<()>
Gate projection: out[k] = gate_weight[ids[k]] · input for each of the
top_k slots; the single input row is broadcast across all slots.
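A CPU reference for the indexing described above. It assumes dense row-major fp32 expert matrices stacked as [num_experts, rows, cols] and `u32` expert ids, whereas the real kernel reads quantised blocks from backend buffers; `gemv_moe_id_broadcast` is hypothetical.

```rust
/// out[k * rows + r] = dot(w[ids[k]][r, :], input) for k in 0..ids.len().
/// `w` is [num_experts, rows, cols] row-major; `input` is one [cols] row
/// shared (broadcast) by every top-k slot.
pub fn gemv_moe_id_broadcast(
    w: &[f32],
    rows: usize,
    cols: usize,
    input: &[f32],
    ids: &[u32],
    out: &mut [f32],
) {
    assert_eq!(input.len(), cols);
    assert_eq!(out.len(), ids.len() * rows);
    for (k, &e) in ids.iter().enumerate() {
        let e = e as usize;
        let expert = &w[e * rows * cols..(e + 1) * rows * cols];
        for r in 0..rows {
            let row = &expert[r * cols..(r + 1) * cols];
            out[k * rows + r] = row.iter().zip(input).map(|(a, b)| a * b).sum::<f32>();
        }
    }
}
```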
pub fn gemv_up(
    &self,
    ctx: &mut B::Context,
    input: &B::Buffer,
    ids: &B::Buffer,
    out: &mut B::Buffer,
    top_k: usize,
) -> Result<()>
Up projection: same shape contract as gemv_gate, with the input broadcast across slots.
pub fn gemv_down(
    &self,
    ctx: &mut B::Context,
    input: &B::Buffer,
    ids: &B::Buffer,
    out: &mut B::Buffer,
    top_k: usize,
    expert_intermediate: usize,
) -> Result<()>
Down projection: each slot reads its own input row, using
in_stride = expert_intermediate. The caller's input is the stacked
SiLU-mul output (top_k × expert_intermediate floats).
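The per-slot read differs from the gate/up case only in where each slot's input row starts. A CPU sketch, assuming dense row-major fp32 experts, `u32` ids, and an element (not byte) stride; `gemv_moe_id_strided` is hypothetical. With `in_stride = expert_intermediate` each slot reads its own SiLU-mul row, while `in_stride = 0` would reduce to the broadcast gate/up case.

```rust
/// out[k * rows + r] = dot(w[ids[k]][r, :], input[k * in_stride..][..cols]).
/// `w` is [num_experts, rows, cols] row-major (for down: rows = hidden,
/// cols = expert_intermediate).
pub fn gemv_moe_id_strided(
    w: &[f32],
    rows: usize,
    cols: usize,
    input: &[f32],
    in_stride: usize,
    ids: &[u32],
    out: &mut [f32],
) {
    assert_eq!(out.len(), ids.len() * rows);
    for (k, &e) in ids.iter().enumerate() {
        let e = e as usize;
        let expert = &w[e * rows * cols..(e + 1) * rows * cols];
        let x = &input[k * in_stride..k * in_stride + cols]; // this slot's input row
        for r in 0..rows {
            let row = &expert[r * cols..(r + 1) * cols];
            out[k * rows + r] = row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>();
        }
    }
}
```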
pub fn gemv_gate_up_silu_fused(
    &self,
    ctx: &mut B::Context,
    input: &B::Buffer,
    ids: &B::Buffer,
    out_silu_stacked: &mut B::Buffer,
    top_k: usize,
) -> Result<()>
Fused gate + up + SiLU(gate) · up: replaces three separate dispatches with one.
The backend must support the fused path
(B::supports_fused_moe_gate_up_silu()); the caller checks this first.
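The fused kernel saves the round trip of writing gate and up to memory and reading them back for the activation. The CPU sketch below checks that a single fused pass matches the three-step version, assuming the activation is SiLU(gate) · up and dense row-major fp32 weights; `unfused` and `fused` are hypothetical names.

```rust
fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Three steps: gate = Wg·x, up = Wu·x, then silu(gate) * up.
pub fn unfused(wg: &[f32], wu: &[f32], x: &[f32], inter: usize) -> Vec<f32> {
    let cols = x.len();
    let gate: Vec<f32> = (0..inter).map(|r| dot(&wg[r * cols..(r + 1) * cols], x)).collect();
    let up: Vec<f32> = (0..inter).map(|r| dot(&wu[r * cols..(r + 1) * cols], x)).collect();
    gate.iter().zip(&up).map(|(g, u)| silu(*g) * u).collect()
}

/// One pass per output row; no intermediate gate / up buffers.
pub fn fused(wg: &[f32], wu: &[f32], x: &[f32], inter: usize) -> Vec<f32> {
    let cols = x.len();
    (0..inter)
        .map(|r| {
            let row = r * cols..(r + 1) * cols;
            silu(dot(&wg[row.clone()], x)) * dot(&wu[row], x)
        })
        .collect()
}
```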
pub fn gemm_gate(
    &self,
    ctx: &mut B::Context,
    src1: &B::Buffer,
    ids: &B::Buffer,
    tpe: &B::Buffer,
    dst: &mut B::Buffer,
    args_buf: Option<&B::Buffer>,
    top_k: usize,
    max_per_expert: usize,
    tokens: usize,
) -> Result<()>
Gate prefill GEMM. dst shape: [batch, top_k, expert_inter].
Passing args_buf = Some(..) switches to indirect-grid dispatch.
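The page does not spell out `tpe` or `max_per_expert`. Under the assumption that `tpe` holds per-expert token counts (the name suggests "tokens per expert") and `max_per_expert` is the largest such count, a CPU sketch of deriving them from the `[tokens, top_k]` routing ids could look like this; both the interpretation and `bucket_by_expert` are hypothetical.

```rust
/// Group (token, slot) pairs by expert. Returns per-expert counts (assumed
/// meaning of `tpe`), the largest count (assumed meaning of `max_per_expert`),
/// and the (token, slot) pairs routed to each expert.
pub fn bucket_by_expert(
    ids: &[u32],
    tokens: usize,
    top_k: usize,
    num_experts: usize,
) -> (Vec<u32>, usize, Vec<Vec<(usize, usize)>>) {
    assert_eq!(ids.len(), tokens * top_k);
    let mut buckets: Vec<Vec<(usize, usize)>> = vec![Vec::new(); num_experts];
    for t in 0..tokens {
        for k in 0..top_k {
            let e = ids[t * top_k + k] as usize;
            buckets[e].push((t, k));
        }
    }
    let counts: Vec<u32> = buckets.iter().map(|b| b.len() as u32).collect();
    let max_per_expert = buckets.iter().map(Vec::len).max().unwrap_or(0);
    (counts, max_per_expert, buckets)
}
```

A grid sized by `num_experts × max_per_expert` would then cover every routed pair, at the cost of idle work for experts with fewer tokens.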
pub fn gemm_up(
    &self,
    ctx: &mut B::Context,
    src1: &B::Buffer,
    ids: &B::Buffer,
    tpe: &B::Buffer,
    dst: &mut B::Buffer,
    args_buf: Option<&B::Buffer>,
    top_k: usize,
    max_per_expert: usize,
    tokens: usize,
) -> Result<()>
Up prefill GEMM. Same shape contract as Self::gemm_gate.
pub fn gemm_down(
    &self,
    ctx: &mut B::Context,
    src1: &B::Buffer,
    ids: &B::Buffer,
    tpe: &B::Buffer,
    dst: &mut B::Buffer,
    args_buf: Option<&B::Buffer>,
    top_k: usize,
    max_per_expert: usize,
    tokens: usize,
) -> Result<()>
Down prefill GEMM. dst shape: [batch, top_k, hidden].
ne11 = top_k: each slot reads its own src1 row from silu_stacked[batch, top_k, inter].
pub fn gemv_gate_batched(
    &self,
    ctx: &mut B::Context,
    input: &B::Buffer,
    ids: &B::Buffer,
    dst: &mut B::Buffer,
    m: usize,
    top_k: usize,
    src1_outer_stride: usize,
    src1_inner_stride: usize,
) -> Result<()>
Gate batched gemv: dst[m * top_k] with broadcast input
(slots within a token share the activation row).
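The two stride parameters suggest that slot (t, k) reads its activation row starting at `t * src1_outer_stride + k * src1_inner_stride`. Assuming that addressing (counted in elements here; the real kernels may use bytes), the sketch shows the broadcast configuration for gate/up and the per-slot configuration that the silu_stacked[m, top_k, inter] layout read by gemv_down_batched would need. All three functions are hypothetical.

```rust
/// Start offset of the input row read by top-k slot `k` of token `t`,
/// assuming offset = t * outer + k * inner (units: elements).
pub fn src1_row_offset(t: usize, k: usize, outer: usize, inner: usize) -> usize {
    t * outer + k * inner
}

/// Gate / up: every slot of a token reads that token's hidden-state row.
pub fn broadcast_strides(hidden: usize) -> (usize, usize) {
    (hidden, 0)
}

/// Down: each slot reads its own row of silu_stacked[m, top_k, inter].
pub fn per_slot_strides(top_k: usize, inter: usize) -> (usize, usize) {
    (top_k * inter, inter)
}
```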
pub fn gemv_up_batched(
    &self,
    ctx: &mut B::Context,
    input: &B::Buffer,
    ids: &B::Buffer,
    dst: &mut B::Buffer,
    m: usize,
    top_k: usize,
    src1_outer_stride: usize,
    src1_inner_stride: usize,
) -> Result<()>
Up batched gemv: same shape as Self::gemv_gate_batched.
pub fn gemv_down_batched(
    &self,
    ctx: &mut B::Context,
    input: &B::Buffer,
    ids: &B::Buffer,
    dst: &mut B::Buffer,
    m: usize,
    top_k: usize,
    src1_outer_stride: usize,
    src1_inner_stride: usize,
) -> Result<()>
Down batched gemv: each slot reads its own src1 row from silu_stacked[m, top_k, inter].
pub fn gemv_gate_up_silu_batched_fused(
    &self,
    ctx: &mut B::Context,
    input: &B::Buffer,
    ids: &B::Buffer,
    silu_out: &mut B::Buffer,
    m: usize,
    top_k: usize,
    src1_outer_stride: usize,
    src1_inner_stride: usize,
) -> Result<()>
Fused batched gate + up + SiLU(gate) · up. A single dispatch covers all m * top_k
pairs. The caller checks B::supports_batched_moe_gate_up_silu() first.
pub fn gemv_gate_offset(
    &self,
    ctx: &mut B::Context,
    src1: &B::Buffer,
    src1_offset: usize,
    ids: &B::Buffer,
    ids_offset: usize,
    dst: &mut B::Buffer,
    top_k: usize,
    src1_stride: usize,
) -> Result<()>
Gate gemv with explicit offsets into src1 and ids. src1_stride = 0 broadcasts the src1 row to every slot.
impl<B: QuantLlmBackend + BackendMoeFused> ExpertStack<B>
pub fn from_dense_stacks(
    gate_stack: &[f32],
    up_stack: &[f32],
    down_stack: &[f32],
    num_experts: usize,
    hidden_size: usize,
    expert_intermediate: usize,
) -> Result<Self>
Build from raw fp32 stacked tensors (test helper). The caller has
already dequantised and laid out the data:
- gate_stack: [num_experts * expert_inter * hidden]
- up_stack: [num_experts * expert_inter * hidden]
- down_stack: [num_experts * hidden * expert_inter]
Each per-expert slice is row-major in the natural Linear shape
([out, in]: [expert_inter, hidden] for gate and up, [hidden, expert_inter] for down).
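Given that layout, expert `e`'s matrix is a contiguous sub-slice of the stack. A minimal sketch; `expert_slice` is a hypothetical helper, not part of this API.

```rust
/// Borrow expert `e`'s [rows, cols] row-major matrix from a stacked
/// [num_experts * rows * cols] buffer. For gate/up: rows = expert_inter,
/// cols = hidden; for down: rows = hidden, cols = expert_inter.
pub fn expert_slice(stack: &[f32], e: usize, rows: usize, cols: usize) -> &[f32] {
    let n = rows * cols;
    assert!(n > 0, "empty expert shape");
    assert!(stack.len() % n == 0 && e < stack.len() / n, "expert index out of range");
    &stack[e * n..(e + 1) * n]
}
```

Element (r, c) of that expert is then at index `r * cols + c` within the returned slice.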
pub fn load_from_gguf(
    gguf: &GgufFile,
    layer_idx: usize,
    num_experts: usize,
    hidden_size: usize,
    expert_intermediate: usize,
) -> Result<Self>
Load all experts for one MoE layer from a GGUF file. Tensor names follow
the GGUF convention: blk.{layer_idx}.ffn_{gate,up,down}_exps.weight.
The loader picks one of two strategies based on the on-disk dtype of the expert tensors:
- Quantised path (Q4_K / Q6_K only): each expert's gate || up becomes a single QuantLinear<B> (a fused QuantStore in which gate and up share n_cols = hidden), and down is a plain QuantLinear<B>. Block bytes stay compressed in backend memory; dequantisation happens per call inside gemm_quant.
- Dense fallback (everything else, e.g. F32 / F16, and Q5_K until a kernel ships): eagerly dequantise to fp32 and wrap in DenseLinear<B>. Memory inflates roughly 7× relative to Q4_K_M, but the algorithm is correctness-equivalent, and this is the path the synthetic-MoE test fixtures need.
The runtime dispatcher (moe_forward<B>) doesn’t see which path
was taken — it just calls Linear::forward per (token, expert).
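The naming and strategy choice above fit in a few lines. This is a sketch, not the loader itself: `GgufDtype` (a subset of dtypes), `LoadStrategy`, `strategy_for`, and `expert_tensor_names` are hypothetical, and the code encodes only what the description states (Q4_K and Q6_K stay quantised; everything else takes the dense fallback).

```rust
/// Subset of on-disk dtypes relevant here (hypothetical enum).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum GgufDtype {
    F32,
    F16,
    Q4K,
    Q5K,
    Q6K,
}

#[derive(Debug, PartialEq)]
pub enum LoadStrategy {
    /// Keep K-quant blocks compressed in backend memory.
    Quantised,
    /// Eagerly dequantise to fp32 and wrap in a dense linear.
    DenseFallback,
}

pub fn strategy_for(dtype: GgufDtype) -> LoadStrategy {
    match dtype {
        GgufDtype::Q4K | GgufDtype::Q6K => LoadStrategy::Quantised,
        _ => LoadStrategy::DenseFallback,
    }
}

/// GGUF names of the three stacked expert tensors for one layer.
pub fn expert_tensor_names(layer_idx: usize) -> [String; 3] {
    ["gate", "up", "down"].map(|role| format!("blk.{layer_idx}.ffn_{role}_exps.weight"))
}
```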
pub fn open_and_load(
    path: impl AsRef<Path>,
    layer_idx: usize,
    num_experts: usize,
    hidden_size: usize,
    expert_intermediate: usize,
) -> Result<Self>
Convenience: open a GGUF file and load layer layer_idx. The file is
held open only for the duration of this call, so for multi-layer loads,
open it once and call Self::load_from_gguf with the shared GgufFile.
pub fn num_experts(&self) -> usize
Number of experts in the layer (a consistency-check helper).
Returns the length of the per-expert Vecs or, when the stacked-only path is in effect (the Metal MoE fast path, where the per-expert Vecs are empty), falls back to a count stored in the stacked variants. Because the stacked-only case has no Vec to count, this method is mostly used by tests on the per-expert path.
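The fallback can be sketched as follows, assuming the stacked representation records its own expert count. `ExpertCounts` and `stacked_num_experts` are hypothetical stand-ins for ExpertStack and whatever the stacked variants actually store.

```rust
/// Minimal stand-in: a per-expert Vec plus an optional count recorded
/// by a stacked representation (hypothetical field).
pub struct ExpertCounts<L> {
    pub gate_up: Vec<L>,
    pub stacked_num_experts: Option<usize>,
}

impl<L> ExpertCounts<L> {
    pub fn num_experts(&self) -> usize {
        if !self.gate_up.is_empty() {
            // Per-expert path: count the Vec.
            self.gate_up.len()
        } else {
            // Stacked-only path: no Vec to count, use the stored count.
            self.stacked_num_experts.unwrap_or(0)
        }
    }
}
```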
Auto Trait Implementations
impl<B> Freeze for ExpertStack<B>
impl<B> !RefUnwindSafe for ExpertStack<B>
impl<B> Send for ExpertStack<B>
impl<B> Sync for ExpertStack<B>
impl<B> Unpin for ExpertStack<B>
impl<B> UnsafeUnpin for ExpertStack<B>
impl<B> !UnwindSafe for ExpertStack<B>
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.