Struct ExpertStack 

Source
pub struct ExpertStack<B: QuantLlmBackend + BackendMoeFused> {
    pub gate_up: Vec<Box<dyn Linear<B>>>,
    pub down: Vec<Box<dyn Linear<B>>>,
    pub gate_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
    pub up_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
    pub down_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>,
    pub gate_up_marlin_stack: Option<Arc<dyn MarlinExpertStack<B>>>,
    pub down_marlin_stack: Option<Arc<dyn MarlinExpertStack<B>>>,
}

Per-layer expert weights, materialised as [num_experts]-long vectors of Box<dyn Linear<B>>. Each entry runs the corresponding expert’s fused [gate; up] projection or its down projection.

B::Buffer is hidden behind Linear<B>, so this struct is generic over the backend. The production path (Qwen3MoeModel::forward) dispatches through the generic moe_forward<B> (this file, line ~960) and moe_forward_bucketed<B>; the CPU-only moe_forward_cpu is the reference path used by parity tests and by Qwen3MoeLayer::forward_cpu.

Fields§

§gate_up: Vec<Box<dyn Linear<B>>>

Fused [gate; up] projection per expert. Output shape per token: [2 * expert_intermediate] — the lower half is gate, upper is up.
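As an illustration of the layout described above, the following is a minimal CPU sketch that splits one token's fused output into its gate and up halves and combines them. It assumes the usual SwiGLU form silu(gate) * up; the helper names are hypothetical and not part of this crate.

```rust
/// SiLU activation: x * sigmoid(x).
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// Splits one token's fused output `[2 * expert_intermediate]` into gate
/// (lower half) and up (upper half), then returns `silu(gate) * up`.
/// Hypothetical helper; the SwiGLU combination is an assumption.
fn silu_mul_from_fused(fused: &[f32], expert_intermediate: usize) -> Vec<f32> {
    assert_eq!(fused.len(), 2 * expert_intermediate);
    let (gate, up) = fused.split_at(expert_intermediate);
    gate.iter().zip(up).map(|(&g, &u)| silu(g) * u).collect()
}
```

The result has expert_intermediate elements, which is the input width the down projection expects.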

§down: Vec<Box<dyn Linear<B>>>

down projection per expert. Output shape per token: [hidden_size].

§gate_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>

Stacked-experts representation for backends that have a batched MoE indirect-dispatch kernel (Metal gemv_q4kw_moe_id_f32 / gemv_q6kw_moe_id_f32). Holds all experts for one matmul role behind a StackedExpertGgufLinear<B> (typically backed by a single GPU buffer with byte stride between expert slabs), so a single dispatch can cover all selected (token, expert) pairs at decode m=1.

None on backends without the kernel (CPU, CUDA-without-MoE-kernel) and on quant flavours that don’t have a stacked path yet — callers fall back to the per-expert gate_up / down Linears in those cases.
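To make the "single buffer with byte stride between expert slabs" idea concrete, here is a small sketch of how an expert's byte range could be located in such a buffer. The 144-byte, 256-weight Q4_K block size is the standard GGUF layout; the helper names and the tightly packed stride in the usage below are assumptions, not this crate's API.

```rust
use std::ops::Range;

/// Bytes in one Q4_K super-block and the number of weights it encodes.
const Q4K_BLOCK_BYTES: usize = 144;
const Q4K_BLOCK_WEIGHTS: usize = 256;

/// Bytes occupied by one expert's [rows, cols] Q4_K matrix.
/// Assumes `cols` is a multiple of the block size.
fn q4k_slab_bytes(rows: usize, cols: usize) -> usize {
    assert_eq!(cols % Q4K_BLOCK_WEIGHTS, 0);
    rows * (cols / Q4K_BLOCK_WEIGHTS) * Q4K_BLOCK_BYTES
}

/// Byte range of expert `expert` inside a stacked buffer whose slabs start
/// `stride_bytes` apart. Hypothetical helper for illustration only.
fn expert_slab_range(expert: usize, stride_bytes: usize, slab_bytes: usize) -> Range<usize> {
    let start = expert * stride_bytes;
    start..start + slab_bytes
}
```

With tightly packed slabs the stride equals the slab size, so expert 2 of a [4, 256] Q4_K stack would start at byte 1152.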

§up_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>

§down_stacked: Option<Box<dyn StackedExpertGgufLinear<B>>>

§gate_up_marlin_stack: Option<Arc<dyn MarlinExpertStack<B>>>

Stacked Marlin GPTQ expert tiles for the bucketed CUDA path. When both gate_up_marlin_stack and down_marlin_stack are Some, moe_forward_bucketed dispatches expert GEMMs through trait-object methods (store.gemm_phase_* / store.zero_workspace). None on CPU / Metal / GGUF.

Phase C step 3: replaces Option<Arc<B::GptqStore>> with an Option<Arc<dyn MarlinExpertStack<B>>> trait object, which stops the GptqStore associated type from leaking through the model layer.
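The "both stacks must be Some" condition for the bucketed path can be sketched as follows. This is an illustrative helper under that single assumption; marlin_pair is a hypothetical name and the real check inside moe_forward_bucketed may look different.

```rust
use std::sync::Arc;

/// Returns both stacks only when both are loaded. `None` means the caller
/// should take a non-bucketed path. Hypothetical helper, generic over any
/// (possibly unsized) stack type so it also accepts trait objects.
fn marlin_pair<'a, T: ?Sized>(
    gate_up: &'a Option<Arc<T>>,
    down: &'a Option<Arc<T>>,
) -> Option<(&'a Arc<T>, &'a Arc<T>)> {
    gate_up.as_ref().zip(down.as_ref())
}
```

Using Option::zip keeps the two fields in lockstep: a layer with only one of the two stacks loaded is treated the same as a layer with none.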

§down_marlin_stack: Option<Arc<dyn MarlinExpertStack<B>>>

Implementations§

Source§

impl<B: QuantLlmBackend + BackendMoeFused> ExpertStack<B>

Source

pub fn gate_up_stacked_store( &self, _expert_idx: usize, ) -> Option<&Arc<dyn MarlinExpertStack<B>>>

Returns the shared stacked Marlin expert tile for gate_up if loaded via the bucketed/Marlin path. Used by moe_forward_bucketed.

Source

pub fn down_stacked_store( &self, _expert_idx: usize, ) -> Option<&Arc<dyn MarlinExpertStack<B>>>

Same as gate_up_stacked_store, but for the down projection.

Source

pub fn gemv_gate( &self, ctx: &mut B::Context, input: &B::Buffer, ids: &B::Buffer, out: &mut B::Buffer, top_k: usize, ) -> Result<()>

Gate projection: out_stacked[k] = gate_weight[expert_id[k]] · input, with input broadcast across all top_k slots.
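The formula above can be written out as a plain CPU reference. This is only a sketch with dense fp32 weights and a hypothetical function name; the real method operates on backend buffers and quantised weights.

```rust
/// CPU reference for an indirect ("by expert id") matrix-vector product.
/// `weights` holds `num_experts` row-major [rows, cols] matrices back to
/// back. Every slot k applies expert `ids[k]` to the same `input` row.
fn gemv_by_id_broadcast(
    weights: &[f32],
    rows: usize,
    cols: usize,
    input: &[f32],
    ids: &[u32],
) -> Vec<f32> {
    assert_eq!(input.len(), cols);
    let expert_len = rows * cols;
    let mut out = vec![0.0f32; ids.len() * rows];
    for (k, &id) in ids.iter().enumerate() {
        let base = id as usize * expert_len;
        let w = &weights[base..base + expert_len];
        for r in 0..rows {
            let row = &w[r * cols..(r + 1) * cols];
            out[k * rows + r] = row.iter().zip(input).map(|(a, b)| a * b).sum();
        }
    }
    out
}
```

The output is stacked as [top_k, rows], one block of rows values per selected expert.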

Source

pub fn gemv_up( &self, ctx: &mut B::Context, input: &B::Buffer, ids: &B::Buffer, out: &mut B::Buffer, top_k: usize, ) -> Result<()>

Up projection: same shape contract as Self::gemv_gate, with input broadcast across all top_k slots.

Source

pub fn gemv_down( &self, ctx: &mut B::Context, input: &B::Buffer, ids: &B::Buffer, out: &mut B::Buffer, top_k: usize, expert_intermediate: usize, ) -> Result<()>

Down projection: each slot reads its own input row, with in_stride = expert_intermediate. The caller’s input is the stacked SiLU-mul output (top_k × expert_intermediate floats).
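In contrast to the broadcast case, each slot here reads a different input row. A dense fp32 CPU sketch of that per-slot read is shown below; the function name and the [num_experts, hidden, inter] weight layout are assumptions made for illustration.

```rust
/// CPU reference for the down projection over stacked slots.
/// `down_weights`: [num_experts, hidden, inter], row-major per expert.
/// `silu_stacked`: [top_k, inter]; slot k starts at k * in_stride.
fn gemv_down_by_id(
    down_weights: &[f32],
    hidden: usize,
    inter: usize,
    silu_stacked: &[f32],
    ids: &[u32],
) -> Vec<f32> {
    let in_stride = inter; // per-slot input stride = expert_intermediate
    let expert_len = hidden * inter;
    let mut out = vec![0.0f32; ids.len() * hidden];
    for (k, &id) in ids.iter().enumerate() {
        let x = &silu_stacked[k * in_stride..k * in_stride + inter];
        let base = id as usize * expert_len;
        let w = &down_weights[base..base + expert_len];
        for h in 0..hidden {
            let row = &w[h * inter..(h + 1) * inter];
            out[k * hidden + h] = row.iter().zip(x).map(|(a, b)| a * b).sum();
        }
    }
    out
}
```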

Source

pub fn gemv_gate_up_silu_fused( &self, ctx: &mut B::Context, input: &B::Buffer, ids: &B::Buffer, out_silu_stacked: &mut B::Buffer, top_k: usize, ) -> Result<()>

Fused gate + up + SiLU·gate: replaces 3 separate dispatches with 1. The backend must support the fused path (B::supports_fused_moe_gate_up_silu()); the caller is expected to check this before calling.
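The caller-side check might look roughly like the sketch below. The MoeCaps trait and both backend types are stand-ins invented for this example, and "silu_mul" is a placeholder name for the third unfused dispatch; only the capability method name comes from the text above.

```rust
/// Stand-in for the capability query on the backend trait (hypothetical).
trait MoeCaps {
    fn supports_fused_moe_gate_up_silu() -> bool;
}

struct FusedBackend; // toy backend that has the fused kernel
struct PlainBackend; // toy backend that does not

impl MoeCaps for FusedBackend {
    fn supports_fused_moe_gate_up_silu() -> bool { true }
}
impl MoeCaps for PlainBackend {
    fn supports_fused_moe_gate_up_silu() -> bool { false }
}

/// Lists the dispatches a caller would issue for gate + up + SiLU·gate.
fn plan_gate_up_silu<B: MoeCaps>() -> Vec<&'static str> {
    if B::supports_fused_moe_gate_up_silu() {
        vec!["gemv_gate_up_silu_fused"]
    } else {
        vec!["gemv_gate", "gemv_up", "silu_mul"]
    }
}
```

Keeping the check in the caller means the fused method itself never has to fall back silently.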

Source

pub fn gemm_gate( &self, ctx: &mut B::Context, src1: &B::Buffer, ids: &B::Buffer, tpe: &B::Buffer, dst: &mut B::Buffer, args_buf: Option<&B::Buffer>, top_k: usize, max_per_expert: usize, tokens: usize, ) -> Result<()>

Gate prefill GEMM. dst shape: [batch, top_k, expert_inter]. args_buf=Some triggers indirect-grid dispatch.
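For orientation, the sketch below shows how a (token, slot) pair maps into a flat [batch, top_k, expert_inter] buffer, and how per-expert token counts could be derived from the routing ids. Reading tpe as "tokens per expert" and max_per_expert as the largest such count is an assumption; both helpers are hypothetical.

```rust
/// Flat element offset of (token, slot) in a [batch, top_k, expert_inter] buffer.
fn dst_offset(token: usize, slot: usize, top_k: usize, expert_inter: usize) -> usize {
    (token * top_k + slot) * expert_inter
}

/// Counts how many (token, slot) pairs were routed to each expert.
/// `ids` is the flattened [tokens, top_k] routing table.
fn tokens_per_expert(ids: &[u32], num_experts: usize) -> Vec<u32> {
    let mut counts = vec![0u32; num_experts];
    for &id in ids {
        counts[id as usize] += 1;
    }
    counts
}
```

For three tokens with top_k = 2 and ids [0, 2, 2, 1, 2, 0], the counts are [2, 1, 3, 0], so the busiest expert handles three pairs.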

Source

pub fn gemm_up( &self, ctx: &mut B::Context, src1: &B::Buffer, ids: &B::Buffer, tpe: &B::Buffer, dst: &mut B::Buffer, args_buf: Option<&B::Buffer>, top_k: usize, max_per_expert: usize, tokens: usize, ) -> Result<()>

Up prefill GEMM. Same shape contract as Self::gemm_gate.

Source

pub fn gemm_down( &self, ctx: &mut B::Context, src1: &B::Buffer, ids: &B::Buffer, tpe: &B::Buffer, dst: &mut B::Buffer, args_buf: Option<&B::Buffer>, top_k: usize, max_per_expert: usize, tokens: usize, ) -> Result<()>

Down prefill GEMM. dst shape: [batch, top_k, hidden]. ne11 = top_k, so each slot reads its own src1 row from silu_stacked[batch, top_k, inter].

Source

pub fn gemv_gate_batched( &self, ctx: &mut B::Context, input: &B::Buffer, ids: &B::Buffer, dst: &mut B::Buffer, m: usize, top_k: usize, src1_outer_stride: usize, src1_inner_stride: usize, ) -> Result<()>

Gate batched gemv: dst[m * top_k] with broadcast input (slots within a token share the activation row).

Source

pub fn gemv_up_batched( &self, ctx: &mut B::Context, input: &B::Buffer, ids: &B::Buffer, dst: &mut B::Buffer, m: usize, top_k: usize, src1_outer_stride: usize, src1_inner_stride: usize, ) -> Result<()>

Up batched gemv: same shape as Self::gemv_gate_batched.

Source

pub fn gemv_down_batched( &self, ctx: &mut B::Context, input: &B::Buffer, ids: &B::Buffer, dst: &mut B::Buffer, m: usize, top_k: usize, src1_outer_stride: usize, src1_inner_stride: usize, ) -> Result<()>

Down batched gemv: src1 = silu_stacked[m, top_k, inter] per-slot read.
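One way to read the two stride parameters of the batched gemv methods is that row (token, slot) starts at token * src1_outer_stride + slot * src1_inner_stride. That interpretation, the element (not byte) units, and the helper name are all assumptions; the sketch only shows how the broadcast and per-slot cases would differ under it.

```rust
/// Start offsets (in elements) of the src1 row read by each (token, slot)
/// pair, in token-major order. Hypothetical helper for illustration.
fn src1_row_starts(m: usize, top_k: usize, outer: usize, inner: usize) -> Vec<usize> {
    let mut starts = Vec::with_capacity(m * top_k);
    for token in 0..m {
        for slot in 0..top_k {
            starts.push(token * outer + slot * inner);
        }
    }
    starts
}
```

With hidden = 4, the gate and up case would use outer = 4 and inner = 0, so both slots of a token share one row. With inter = 3, the down case would use outer = 6 and inner = 3, so every slot reads a distinct row of silu_stacked.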

Source

pub fn gemv_gate_up_silu_batched_fused( &self, ctx: &mut B::Context, input: &B::Buffer, ids: &B::Buffer, silu_out: &mut B::Buffer, m: usize, top_k: usize, src1_outer_stride: usize, src1_inner_stride: usize, ) -> Result<()>

Fused batched gate + up + SiLU·gate. Single dispatch over m * top_k pairs. Caller gates on B::supports_batched_moe_gate_up_silu() first.

Source

pub fn gemv_gate_offset( &self, ctx: &mut B::Context, src1: &B::Buffer, src1_offset: usize, ids: &B::Buffer, ids_offset: usize, dst: &mut B::Buffer, top_k: usize, src1_stride: usize, ) -> Result<()>

Gate offset gemv: reads src1 and ids starting at src1_offset and ids_offset. src1_stride = 0 broadcasts the same src1 row to every slot.

Source

pub fn gemv_up_offset( &self, ctx: &mut B::Context, src1: &B::Buffer, src1_offset: usize, ids: &B::Buffer, ids_offset: usize, dst: &mut B::Buffer, top_k: usize, src1_stride: usize, ) -> Result<()>

Up offset gemv.

Source

pub fn gemv_down_offset( &self, ctx: &mut B::Context, src1: &B::Buffer, src1_offset: usize, ids: &B::Buffer, ids_offset: usize, dst: &mut B::Buffer, top_k: usize, src1_stride: usize, ) -> Result<()>

Down offset gemv.

Source§

impl<B: QuantLlmBackend + BackendMoeFused> ExpertStack<B>

Source

pub fn from_dense_stacks( gate_stack: &[f32], up_stack: &[f32], down_stack: &[f32], num_experts: usize, hidden_size: usize, expert_intermediate: usize, ) -> Result<Self>

Build from raw fp32 stacked tensors (test helper). Caller has already dequantised and laid out the data:

  • gate_stack: [num_experts * expert_inter * hidden]
  • up_stack: [num_experts * expert_inter * hidden]
  • down_stack: [num_experts * hidden * expert_inter]

Each per-expert slice is row-major in the natural Linear shape.
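The stacked layout can be sliced per expert as in the sketch below. It assumes the "natural Linear shape" means [out_features, in_features], so gate and up slices are [expert_inter, hidden] and down slices are [hidden, expert_inter]; both helpers are hypothetical.

```rust
/// Borrow expert `expert`'s row-major [rows, cols] matrix from a stack.
fn expert_slice(stack: &[f32], expert: usize, rows: usize, cols: usize) -> &[f32] {
    let len = rows * cols;
    &stack[expert * len..(expert + 1) * len]
}

/// Checks that the three stacks have the lengths the layout requires.
fn stack_lens_ok(
    gate_stack: &[f32],
    up_stack: &[f32],
    down_stack: &[f32],
    num_experts: usize,
    hidden_size: usize,
    expert_intermediate: usize,
) -> bool {
    let per_expert = hidden_size * expert_intermediate;
    gate_stack.len() == num_experts * per_expert
        && up_stack.len() == num_experts * per_expert
        && down_stack.len() == num_experts * per_expert
}
```

Element (r, c) of a slice is then at index r * cols + c.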

Source

pub fn load_from_gguf( gguf: &GgufFile, layer_idx: usize, num_experts: usize, hidden_size: usize, expert_intermediate: usize, ) -> Result<Self>

Load all experts for one MoE layer from a GGUF file. Names follow the GGUF convention: blk.{layer_idx}.ffn_{gate,up,down}_exps.weight.
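The naming convention expands to three tensor names per layer. A small sketch, with expert_tensor_names as a hypothetical helper:

```rust
/// Builds the gate, up, and down expert tensor names for one layer,
/// following the pattern blk.{layer_idx}.ffn_{role}_exps.weight.
fn expert_tensor_names(layer_idx: usize) -> [String; 3] {
    ["gate", "up", "down"].map(|role| format!("blk.{layer_idx}.ffn_{role}_exps.weight"))
}
```

For layer_idx = 3 this yields blk.3.ffn_gate_exps.weight, blk.3.ffn_up_exps.weight, and blk.3.ffn_down_exps.weight.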

The loader picks between two strategies based on the on-disk dtype of the expert tensors:

  • Quantised path (Q4_K / Q6_K only): each expert’s gate || up becomes a single QuantLinear<B> (Fused QuantStore — gate + up share n_cols = hidden), and down is a plain QuantLinear<B>. Block bytes stay compressed in backend memory; per-call dequant happens inside gemm_quant.
  • Dense fallback (everything else, e.g. F32 / F16 / Q5_K until a kernel ships): eager-dequant to fp32 and wrap DenseLinear<B>. Memory inflates ~7× vs Q4_K_M but the algorithm is correctness-equivalent and this is the path the synthetic-MoE test fixtures need.
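The two strategies above, and the rough size of the memory inflation, can be sketched as follows. The enums and function names are invented for illustration and are not the loader's real types; the Q4_K figure of 144 bytes per 256 weights is the standard GGUF block layout.

```rust
/// Toy stand-in for the on-disk dtype of an expert tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExpertDtype { Q4K, Q6K, Q5K, F16, F32 }

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LoadStrategy { Quantised, DenseFallback }

/// Only Q4_K and Q6_K keep their blocks compressed; everything else is
/// dequantised eagerly to fp32.
fn pick_strategy(dtype: ExpertDtype) -> LoadStrategy {
    match dtype {
        ExpertDtype::Q4K | ExpertDtype::Q6K => LoadStrategy::Quantised,
        _ => LoadStrategy::DenseFallback,
    }
}

/// Size ratio of fp32 storage to a block-quantised format.
fn fp32_inflation(block_bytes: usize, block_weights: usize) -> f64 {
    let fp32_bytes = (block_weights * 4) as f64;
    fp32_bytes / block_bytes as f64
}
```

For Q4_K, fp32_inflation(144, 256) is 1024 / 144, or about 7.1, which is consistent with the ~7× figure quoted for the dense fallback.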

The runtime dispatcher (moe_forward<B>) doesn’t see which path was taken — it just calls Linear::forward per (token, expert).
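The reason the dispatcher does not need to know the storage format is ordinary trait-object dispatch. The toy below mirrors that idea with invented types (ToyLinear, ToyDense, ToyQuant) and a simplified slice-based forward; the crate's real Linear<B> trait works on backend buffers and has a different signature.

```rust
/// Toy analogue of a linear layer behind a trait object.
trait ToyLinear {
    fn forward(&self, x: &[f32]) -> Vec<f32>;
}

/// Dense fp32 weights, row-major [rows, cols].
struct ToyDense { w: Vec<f32>, cols: usize }

/// Toy int8 weights with one shared scale, dequantised on each call.
struct ToyQuant { q: Vec<i8>, scale: f32, cols: usize }

impl ToyLinear for ToyDense {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.w.chunks(self.cols)
            .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum())
            .collect()
    }
}

impl ToyLinear for ToyQuant {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        self.q.chunks(self.cols)
            .map(|row| row.iter().zip(x).map(|(&a, b)| a as f32 * self.scale * b).sum())
            .collect()
    }
}

/// Runs the selected experts on one token without inspecting their storage.
fn run_selected(experts: &[Box<dyn ToyLinear>], x: &[f32], ids: &[usize]) -> Vec<Vec<f32>> {
    ids.iter().map(|&e| experts[e].forward(x)).collect()
}
```

A dense expert and a quantised expert holding the same weights produce the same output here, which is the property the parity tests rely on.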

Source

pub fn open_and_load( path: impl AsRef<Path>, layer_idx: usize, num_experts: usize, hidden_size: usize, expert_intermediate: usize, ) -> Result<Self>

Convenience: open a GGUF and load layer layer_idx. The GGUF stays open inside this call only — for multi-layer loads use Self::load_from_gguf with a shared GgufFile.

Source

pub fn num_experts(&self) -> usize

Returns num_experts for the layer (consistency-check helper).

Returns the per-expert Vec length. When the stacked-only path is in effect (the Metal MoE fast path, where the per-expert Vecs are empty), it falls back to a count stored by the stacked variants. Because the stacked-only case has no Vec to count, this method is mostly used by tests on the per-expert path.

Auto Trait Implementations§

§

impl<B> Freeze for ExpertStack<B>

§

impl<B> !RefUnwindSafe for ExpertStack<B>

§

impl<B> Send for ExpertStack<B>

§

impl<B> Sync for ExpertStack<B>

§

impl<B> Unpin for ExpertStack<B>

§

impl<B> UnsafeUnpin for ExpertStack<B>

§

impl<B> !UnwindSafe for ExpertStack<B>

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self.
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value.
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T> Instrument for T

Source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper.
Source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper.
Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
Source§

impl<F, T> IntoSample<T> for F
where T: FromSample<F>,

Source§

fn into_sample(self) -> T

Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.
Source§

impl<T> PolicyExt for T
where T: ?Sized,

Source§

fn and<P, B, E>(self, other: P) -> And<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow only if self and other return Action::Follow.
Source§

fn or<P, B, E>(self, other: P) -> Or<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow if either self or other returns Action::Follow.
Source§

impl<T> Same for T

Source§

type Output = T

Should always be Self
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V

Source§

impl<T> WithSubscriber for T

Source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.
Source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.
Source§

impl<T> ErasedDestructor for T
where T: 'static,