pub struct MoeLayer<R: Runtime> { /* private fields */ }
Mixture of Experts layer.
Routes tokens to top-k experts, computes expert outputs, and returns the weighted combination.
All computation stays on-device: no tensors are transferred between GPU and CPU.
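The routing described above can be illustrated with a minimal sketch for a single token. It uses plain `Vec<f32>` values on the host rather than the crate's on-device tensors, and the function names (`softmax`, `top_k_route`, `combine`) are hypothetical, not part of this crate. Renormalizing the selected weights so they sum to 1 is an assumption of the sketch; the actual `MoeRouter` may weight experts differently.

```rust
/// Numerically stable softmax over one token's router logits.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Returns (expert_index, weight) pairs for the k most probable experts.
/// Weights are renormalized to sum to 1 (an assumption of this sketch).
fn top_k_route(logits: &[f32], k: usize) -> Vec<(usize, f32)> {
    let probs = softmax(logits);
    let mut order: Vec<usize> = (0..probs.len()).collect();
    // Sort expert indices by descending probability, then keep the first k.
    order.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));
    order.truncate(k);
    let total: f32 = order.iter().map(|&i| probs[i]).sum();
    order.into_iter().map(|i| (i, probs[i] / total)).collect()
}

/// Weighted combination of the selected experts' outputs for one token.
/// `expert_outputs[e]` is the (hypothetical) output vector of expert `e`.
fn combine(routes: &[(usize, f32)], expert_outputs: &[Vec<f32>]) -> Vec<f32> {
    let hidden = expert_outputs.first().map_or(0, |v| v.len());
    let mut out = vec![0.0f32; hidden];
    for &(e, w) in routes {
        for (o, v) in out.iter_mut().zip(&expert_outputs[e]) {
            *o += w * v;
        }
    }
    out
}
```

With logits `[0.0, 2.0, 1.0, -1.0]` and `k = 2`, `top_k_route` selects experts 1 and 2, in that order, and `combine` mixes only their outputs.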
Implementations
impl<R: Runtime> MoeLayer<R>
pub fn new( router: MoeRouter<R>, experts: Vec<Expert<R>>, shared_expert: Option<Expert<R>>, ) -> Self
pub fn router(&self) -> &MoeRouter<R>
pub fn experts(&self) -> &[Expert<R>]
pub fn forward<C>(&self, client: &C, x: &Var<R>) -> Result<MoeOutput<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + ReduceOps<R> + ShapeOps<R> + ActivationOps<R> + SortingOps<R> + IndexingOps<R> + CompareOps<R>,
R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + ActivationOps<R> + ReduceOps<R> + ShapeOps<R>,
Forward pass with auxiliary loss.
Input: [num_tokens, hidden_size]
Returns: MoeOutput with output tensor and aux_loss
Strategy: iterate over experts (not tokens). For each expert, compute output for all tokens, then mask-and-weight by routing decisions. All ops stay on-device.
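The expert-iteration strategy can be sketched as follows. This is an illustration under stated assumptions, not the crate's implementation: nested `Vec<f32>` rows stand in for on-device tensors, and `ExpertFn`, `moe_forward_dense`, `double`, and `negate` are hypothetical names. The `routing` matrix is assumed to hold each token's combine weight per expert, with 0.0 for experts the token was not routed to, so multiplying by it both masks and weights.

```rust
/// Hypothetical stand-in for an expert: maps [num_tokens, hidden_size]
/// rows to rows of the same shape.
type ExpertFn = fn(&[Vec<f32>]) -> Vec<Vec<f32>>;

/// Iterates over experts, not tokens. `routing[t][e]` is the combine weight
/// of expert `e` for token `t`, or 0.0 if token `t` was not routed to it.
fn moe_forward_dense(
    x: &[Vec<f32>],
    experts: &[ExpertFn],
    routing: &[Vec<f32>],
) -> Vec<Vec<f32>> {
    let hidden = x.first().map_or(0, |row| row.len());
    let mut out = vec![vec![0.0f32; hidden]; x.len()];
    for (e, expert) in experts.iter().enumerate() {
        // Run the expert on every token, whether or not it was routed there.
        let y = expert(x);
        for (t, row) in y.iter().enumerate() {
            // A zero weight masks out tokens not assigned to this expert.
            let w = routing[t][e];
            for (o, v) in out[t].iter_mut().zip(row) {
                *o += w * v;
            }
        }
    }
    out
}

/// Toy expert that doubles every element.
fn double(x: &[Vec<f32>]) -> Vec<Vec<f32>> {
    x.iter().map(|r| r.iter().map(|v| v * 2.0).collect()).collect()
}

/// Toy expert that negates every element.
fn negate(x: &[Vec<f32>]) -> Vec<Vec<f32>> {
    x.iter().map(|r| r.iter().map(|v| -v).collect()).collect()
}
```

The trade-off in this style is extra compute (every expert sees every token) in exchange for fixed tensor shapes and no data-dependent gather on the host, which is consistent with the stated goal of keeping all ops on-device.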
Auto Trait Implementations
impl<R> Freeze for MoeLayer<R>
impl<R> !RefUnwindSafe for MoeLayer<R>
impl<R> Send for MoeLayer<R>
impl<R> Sync for MoeLayer<R>
impl<R> Unpin for MoeLayer<R>
impl<R> UnsafeUnpin for MoeLayer<R>
impl<R> !UnwindSafe for MoeLayer<R>
Blanket Implementations
impl<T> ArchivePointee for T
type ArchivedMetadata = ()
The archived version of the pointer metadata for this type.
fn pointer_metadata( _: &<T as ArchivePointee>::ArchivedMetadata, ) -> <T as Pointee>::Metadata
Converts some archived metadata to the pointer metadata for itself.
impl<T> BorrowMut<T> for T
where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
impl<T> LayoutRaw for T
fn layout_raw(_: <T as Pointee>::Metadata) -> Result<Layout, LayoutError>
Returns the layout of the type.
impl<T, N1, N2> Niching<NichedOption<T, N1>> for N2
unsafe fn is_niched(niched: *const NichedOption<T, N1>) -> bool
Returns whether the given value has been niched.
fn resolve_niched(out: Place<NichedOption<T, N1>>)
Writes data to out indicating that a T is niched.