Muon

Struct Muon 

pub struct Muon<B: Backend> { /* private fields */ }

Muon optimizer.

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which each 2D parameter’s update is replaced with the nearest orthogonal matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.
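
To make the orthogonalization step concrete, here is a minimal, dependency-free sketch of a Newton-Schulz iteration on plain nested vectors. It uses the classic cubic iteration X <- 1.5 X - 0.5 (X X^T) X, which is an assumption chosen for clarity; the polynomial, coefficients, step count, and precision used inside Muon may differ. The `Matrix` alias and the helper names are hypothetical and are not part of this crate's API.

```rust
/// Dense row-major matrix stored as nested vectors (illustrative only).
type Matrix = Vec<Vec<f64>>;

/// Plain triple-loop matrix product `a * b`.
fn matmul(a: &Matrix, b: &Matrix) -> Matrix {
    let (n, k, m) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0.0; m]; n];
    for i in 0..n {
        for j in 0..m {
            for p in 0..k {
                out[i][j] += a[i][p] * b[p][j];
            }
        }
    }
    out
}

/// Returns the transpose of `a`.
fn transpose(a: &Matrix) -> Matrix {
    let (n, m) = (a.len(), a[0].len());
    let mut out = vec![vec![0.0; n]; m];
    for i in 0..n {
        for j in 0..m {
            out[j][i] = a[i][j];
        }
    }
    out
}

/// Approximates the nearest orthogonal matrix to a square, full-rank `g`
/// with the cubic Newton-Schulz iteration (an assumed variant, see above).
fn newton_schulz_orthogonalize(g: &Matrix, steps: usize) -> Matrix {
    // Scale by the Frobenius norm so every singular value is at most 1,
    // which keeps the iteration inside its convergence region.
    let norm = g.iter().flatten().map(|v| v * v).sum::<f64>().sqrt();
    let mut x: Matrix = g
        .iter()
        .map(|row| row.iter().map(|v| v / (norm + 1e-12)).collect())
        .collect();
    for _ in 0..steps {
        let xxt = matmul(&x, &transpose(&x));
        let xxtx = matmul(&xxt, &x);
        // X <- 1.5 X - 0.5 (X X^T) X pushes each singular value toward 1.
        for (row, corr) in x.iter_mut().zip(&xxtx) {
            for (v, c) in row.iter_mut().zip(corr) {
                *v = 1.5 * *v - 0.5 * c;
            }
        }
    }
    x
}
```

Calling `newton_schulz_orthogonalize(&g, 30)` on a small square matrix should return a matrix `q` for which `q * q^T` is close to the identity. The iteration uses only matrix products, which is why it suits GPU execution.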

Important Notes

  1. Only for 2D parameters: Muon is designed for weight matrices, and step panics if the input tensors are not 2D. Use AdamW or SGD for biases, embeddings, and layer norms.

  2. Learning rate adjustment: Muon automatically adjusts the learning rate based on parameter shape. See AdjustLrFn for details.

  3. Weight decay timing: Unlike typical optimizers, Muon applies weight decay AFTER orthogonalization but uses the original (unadjusted) learning rate for it.
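
As an illustration of note 2, the sketch below scales a learning rate by the shape of a weight matrix using the rule lr * sqrt(max(1, rows / cols)). This rule is an assumption drawn from common descriptions of Muon. The function name is hypothetical, and the variants actually offered by AdjustLrFn should be checked in its own documentation.

```rust
/// Hypothetical helper: scales `lr` for a weight matrix of shape
/// `rows x cols` using the assumed rule lr * sqrt(max(1, rows / cols)).
fn shape_adjusted_lr(lr: f64, rows: usize, cols: usize) -> f64 {
    let ratio = rows as f64 / cols as f64;
    // Tall matrices get a larger step; wide or square ones keep `lr`.
    lr * ratio.max(1.0).sqrt()
}
```

Under this assumed rule, a 1024 x 256 matrix would use twice the base learning rate, while a 256 x 1024 matrix would use the base learning rate unchanged.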

Trait Implementations

impl<B: Clone + Backend> Clone for Muon<B>

fn clone(&self) -> Muon<B>

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl<B: Backend> SimpleOptimizer<B> for Muon<B>

fn step<const D: usize>(&self, lr: LearningRate, tensor: Tensor<B, D>, grad: Tensor<B, D>, state: Option<Self::State<D>>) -> (Tensor<B, D>, Option<Self::State<D>>)

Perform a single Muon optimization step.

Algorithm

  1. Apply momentum to gradient
  2. Orthogonalize update via Newton-Schulz
  3. Adjust learning rate based on parameter shape
  4. Apply weight decay (using original lr)
  5. Update parameter (using adjusted lr)

Notes

Unlike typical optimizers, the weight decay and parameter update use different learning rates:

  • Weight decay uses the original lr
  • Parameter update uses the shape-adjusted lr
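
The ordering above, and the two learning rates, can be sketched on a flattened parameter as follows. This is an illustration under stated assumptions, not the implementation behind step: it assumes plain momentum (no dampening or Nesterov), decoupled multiplicative weight decay, and a caller-supplied `orthogonalize` closure standing in for the Newton-Schulz step. `SketchConfig` and `sketch_step` are hypothetical names, and the shape-adjusted learning rate is passed in precomputed.

```rust
/// Hypothetical hyperparameters for this sketch (not the crate's config type).
struct SketchConfig {
    lr: f64,          // original learning rate, used for weight decay
    adjusted_lr: f64, // shape-adjusted learning rate, used for the update
    momentum: f64,
    weight_decay: f64,
}

/// One illustrative optimization step over a flattened parameter.
fn sketch_step(
    param: &mut [f64],
    grad: &[f64],
    momentum_buf: &mut [f64],
    cfg: &SketchConfig,
    orthogonalize: impl Fn(&[f64]) -> Vec<f64>,
) {
    // 1. Apply momentum to the gradient (assumed: buf = momentum * buf + grad).
    for (m, g) in momentum_buf.iter_mut().zip(grad) {
        *m = cfg.momentum * *m + g;
    }
    // 2. Orthogonalize the update (stand-in for Newton-Schulz).
    let update = orthogonalize(momentum_buf);
    // 3. The shape-adjusted learning rate arrives as `cfg.adjusted_lr`.
    for (p, u) in param.iter_mut().zip(&update) {
        // 4. Weight decay uses the original learning rate.
        *p *= 1.0 - cfg.lr * cfg.weight_decay;
        // 5. The parameter update uses the adjusted learning rate.
        *p -= cfg.adjusted_lr * u;
    }
}
```

Keeping the two rates separate means the decay strength does not change with the parameter's shape, even though the update step size does.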

Panics

This function will panic if the input tensors are not 2D.

type State<const D: usize> = MuonState<B, D>

The state of the optimizer. It also implements record, so that it can be saved.

fn to_device<const D: usize>(state: Self::State<D>, device: &Device<B>) -> Self::State<D>

Change the device of the state.

Auto Trait Implementations

impl<B> Freeze for Muon<B>
where <B as Backend>::FloatElem: Freeze,

impl<B> RefUnwindSafe for Muon<B>

impl<B> Send for Muon<B>

impl<B> Sync for Muon<B>

impl<B> Unpin for Muon<B>
where <B as Backend>::FloatElem: Unpin,

impl<B> UnwindSafe for Muon<B>
where <B as Backend>::FloatElem: UnwindSafe,

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V