pub struct MoeForwardParams<'a, B: QuantLlmBackend + BackendMoeFused> {
pub ctx: &'a mut B::Context,
pub x: &'a B::Buffer,
pub router_logits: &'a B::Buffer,
pub out: &'a mut B::Buffer,
pub batch: usize,
pub hidden_size: usize,
pub expert_intermediate: usize,
pub num_experts: usize,
pub top_k: usize,
pub norm_topk_prob: bool,
pub experts: &'a ExpertStack<B>,
pub x_single: &'a mut B::Buffer,
pub acc_buf: &'a mut B::Buffer,
pub gate_up_buf: &'a mut B::Buffer,
pub silu_buf: &'a mut B::Buffer,
pub down_buf: &'a mut B::Buffer,
pub zero_hidden: &'a B::Buffer,
}
Backend-generic MoE forward.
Equivalent of moe_forward_cpu but parameterised on B: Backend
so Metal / CUDA paths can dispatch the same per-(token, expert) loop
using their own kernels for the gemv + silu + scaled-add primitives.
The caller pre-supplies all scratch buffers, so this function does no
allocation. That matters because it is invoked from inside the
transformer's forward_layer, where allocation during CUDA graph capture
would corrupt the captured graph.
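A CPU sketch of the per-(token, expert) loop can make the "same loop, different primitives" idea and the no-allocation rule concrete. Everything here is hypothetical: Expert stores dense row-major f32 weights (the real ExpertStack<B> is quantized and backend-owned), gemv is a naive reference, routes is assumed to be precomputed host-side, and the silu step is assumed to be SwiGLU-style (silu(gate) * up), which is what the [2 * expert_inter] to [expert_inter] buffer shapes suggest. All scratch arrives from the caller, so the loop body never allocates.

```rust
/// Hypothetical per-expert weights: dense row-major f32 (the real ExpertStack is quantized).
struct Expert {
    gate_up: Vec<f32>, // [2 * inter, hidden]: gate rows first, then up rows
    down: Vec<f32>,    // [hidden, inter]
}

/// Reference gemv: y = W x for a row-major W of shape [y.len(), x.len()].
fn gemv(w: &[f32], x: &[f32], y: &mut [f32]) {
    let cols = x.len();
    for (r, yr) in y.iter_mut().enumerate() {
        *yr = w[r * cols..(r + 1) * cols]
            .iter()
            .zip(x)
            .map(|(a, b)| a * b)
            .sum();
    }
}

/// Per-(token, expert) MoE loop over caller-supplied scratch buffers; no allocation inside.
/// `routes[b]` holds the (expert index, weight) pairs selected for token b.
#[allow(clippy::too_many_arguments)]
fn moe_loop(
    x: &[f32],
    out: &mut [f32],
    routes: &[Vec<(usize, f32)>],
    experts: &[Expert],
    hidden: usize,
    inter: usize,
    x_single: &mut [f32],
    acc_buf: &mut [f32],
    gate_up_buf: &mut [f32],
    silu_buf: &mut [f32],
    down_buf: &mut [f32],
) {
    for (b, pairs) in routes.iter().enumerate() {
        let tok = b * hidden..(b + 1) * hidden;
        // Load this token's activations once; every expert below reads x_single unchanged.
        x_single.copy_from_slice(&x[tok.clone()]);
        // Reset the per-token accumulator (the real backend's mechanism is not shown here).
        acc_buf.fill(0.0);
        for &(e, weight) in pairs {
            let ex = &experts[e];
            // gate_up gemv: [2 * inter] output, gate half then up half.
            gemv(&ex.gate_up, x_single, gate_up_buf);
            let (gate, up) = gate_up_buf.split_at(inter);
            // Assumed SwiGLU-style activation: silu(gate) * up.
            for i in 0..inter {
                let g = gate[i];
                silu_buf[i] = g / (1.0 + (-g).exp()) * up[i];
            }
            // down gemv back to [hidden].
            gemv(&ex.down, silu_buf, down_buf);
            // scaled_add: acc_buf += weight * down_buf.
            for (a, d) in acc_buf.iter_mut().zip(down_buf.iter()) {
                *a += weight * d;
            }
        }
        // Write back. `out` is zeroed by the caller, so adding here equals copying.
        for (o, a) in out[tok].iter_mut().zip(acc_buf.iter()) {
            *o += *a;
        }
    }
}
```

Keeping acc_buf separate from x_single is what lets the inner loop reuse x_single for every selected expert without restoring it between pairs.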
Buffer contract (lengths, sized at scratch alloc time):
- x: [batch * hidden] post-RMSNorm activations
- router_logits: [batch * num_experts] raw router output
- out: [batch * hidden]; the caller is responsible for zeroing this before the call (we accumulate, not assign)
- x_single: [hidden] per-token input slice
- acc_buf: [hidden] per-token output accumulator (kept separate from x_single so the gate_up gemv can consume x_single repeatedly across the top_k loop without an inter-pair restore)
- gate_up_buf: [2 * expert_inter] per-(token, expert) gemv output
- silu_buf: [expert_inter]
- down_buf: [hidden] per-(token, expert) accumulate source
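The buffer contract can be restated as a small size calculator, which is handy for a debug assertion at scratch-allocation time. ScratchSizes and scratch_sizes are hypothetical helpers, not part of this crate; field names mirror the struct's fields and all sizes are element counts, not bytes.

```rust
/// Expected element counts from the buffer contract (hypothetical helper).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ScratchSizes {
    x: usize,
    router_logits: usize,
    out: usize,
    x_single: usize,
    acc_buf: usize,
    gate_up_buf: usize,
    silu_buf: usize,
    down_buf: usize,
}

fn scratch_sizes(batch: usize, hidden: usize, num_experts: usize, expert_inter: usize) -> ScratchSizes {
    ScratchSizes {
        x: batch * hidden,
        router_logits: batch * num_experts,
        out: batch * hidden,
        x_single: hidden,
        acc_buf: hidden,
        gate_up_buf: 2 * expert_inter,
        silu_buf: expert_inter,
        down_buf: hidden,
    }
}

impl ScratchSizes {
    /// Per-token scratch only (excludes the batch-sized x, router_logits and out).
    fn per_token_scratch(&self) -> usize {
        self.x_single + self.acc_buf + self.gate_up_buf + self.silu_buf + self.down_buf
    }
}
```

Note that the per-token scratch is 3 * hidden + 3 * expert_inter regardless of batch, because the loop processes one token at a time.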
Routing (softmax + top-K + optional renorm) runs host-side using
B::to_vec(router_logits, …). The routing input is small
(batch * num_experts floats) and the top-K is a sort, so copying the
logits to the host and routing there costs little.
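Host-side routing might look roughly like the sketch below. It assumes the logits arrive as a flat row-major [batch * num_experts] f32 vector (as B::to_vec would presumably return); route is a hypothetical name, and it allocates freely for clarity. The softmax subtracts the row maximum for numerical stability, top-K is a stable descending sort followed by a truncate, and norm_topk_prob rescales the kept weights to sum to 1.

```rust
/// Host-side routing sketch: softmax, top-k by sort, optional renormalisation.
/// Returns, per token, (expert_index, weight) pairs in descending weight order.
fn route(logits: &[f32], num_experts: usize, top_k: usize, norm_topk_prob: bool) -> Vec<Vec<(usize, f32)>> {
    logits
        .chunks_exact(num_experts)
        .map(|row| {
            // Numerically stable softmax: subtract the row max before exp.
            let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = row.iter().map(|&l| (l - max).exp()).collect();
            let sum: f32 = exps.iter().sum();
            let mut probs: Vec<(usize, f32)> = exps.iter().map(|&e| e / sum).enumerate().collect();
            // Top-k via a stable descending sort (ties keep the lower expert index first).
            probs.sort_by(|a, b| b.1.total_cmp(&a.1));
            probs.truncate(top_k);
            // Optional renorm so the selected weights sum to 1.
            if norm_topk_prob {
                let s: f32 = probs.iter().map(|p| p.1).sum();
                for p in &mut probs {
                    p.1 /= s;
                }
            }
            probs
        })
        .collect()
}
```

A full sort is O(num_experts log num_experts) per token; select_nth_unstable_by would also work, but at this size the difference is unlikely to matter.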
Per-pair dispatch budget (m=1, Metal):
the fused gate_up gemv (2 parts) + silu + down gemv + scaled_add
= 5 dispatches per pair, plus 2 copy_slice per token (load x_single,
write acc_buf back to out[b]). With top_k=8 and 48 layers, that's
8×5 + 2 = 42 dispatches/layer × 48 ≈ 2k per token (vs. ~3.5k in the
pre-PR scheme, which round-tripped through out per pair).
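The dispatch arithmetic above can be written out as a tiny calculator; the constants restate the per-pair and per-token counts from the text, and the function names are illustrative only.

```rust
/// gate_up gemv (2 parts) + silu + down gemv + scaled_add.
const DISPATCHES_PER_PAIR: usize = 5;
/// Load x_single; write acc_buf back to out[b].
const COPIES_PER_TOKEN: usize = 2;

/// Dispatches for one token through one MoE layer.
fn dispatches_per_layer(top_k: usize) -> usize {
    top_k * DISPATCHES_PER_PAIR + COPIES_PER_TOKEN
}

/// Dispatches for one token through the whole model.
fn dispatches_per_token(top_k: usize, num_layers: usize) -> usize {
    dispatches_per_layer(top_k) * num_layers
}
```

With top_k = 8 and 48 layers this gives 42 per layer and 2016 per token, the "≈ 2k" quoted above.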
Auto Trait Implementations
impl<'a, B> Freeze for MoeForwardParams<'a, B>
impl<'a, B> !RefUnwindSafe for MoeForwardParams<'a, B>
impl<'a, B> Send for MoeForwardParams<'a, B>
impl<'a, B> Sync for MoeForwardParams<'a, B>
impl<'a, B> Unpin for MoeForwardParams<'a, B>
impl<'a, B> UnsafeUnpin for MoeForwardParams<'a, B>
impl<'a, B> !UnwindSafe for MoeForwardParams<'a, B>