Struct MuonConfig

pub struct MuonConfig { /* private fields */ }

Muon configuration.

Muon is an optimizer specifically designed for 2D parameters of neural network hidden layers (weight matrices). Other parameters such as biases and embeddings should be optimized using a standard method such as AdamW.

§Learning Rate Adjustment

Muon scales the learning rate according to the parameter's shape so that the update RMS stays consistent across rectangular matrices. Two methods are available:

  • Original: scales the learning rate by sqrt(max(1, A/B)), where A and B are the first two dimensions of the parameter. This is Keller Jordan’s method and the default.

  • MatchRmsAdamW: scales the learning rate by 0.2 * sqrt(max(A, B)). This is Moonshot’s method, designed to match AdamW’s update RMS so that AdamW hyperparameters can be reused directly.
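To make the two formulas concrete, here is a small standalone sketch that computes the adjusted learning rate for a given parameter shape. The enum `LrAdjustSketch` and the function `adjusted_lr` are hypothetical names for illustration, not burn_optim's API, and the sketch assumes the adjusted rate is simply the base rate multiplied by the scale factor.

```rust
/// Hypothetical stand-in for the two adjustment methods described above;
/// this is NOT burn_optim's `AdjustLrFn` type.
#[derive(Clone, Copy, Debug, PartialEq)]
enum LrAdjustSketch {
    Original,
    MatchRmsAdamW,
}

/// Returns the learning rate applied to a parameter of the given shape,
/// where `shape[0]` and `shape[1]` play the roles of A and B.
/// Assumption: adjusted rate = base rate * scale factor.
fn adjusted_lr(lr: f64, shape: &[usize], method: LrAdjustSketch) -> f64 {
    assert!(shape.len() >= 2, "Muon targets parameters with at least two dimensions");
    let (a, b) = (shape[0] as f64, shape[1] as f64);
    let scale = match method {
        // Keller Jordan's rule: sqrt(max(1, A / B)).
        LrAdjustSketch::Original => (a / b).max(1.0).sqrt(),
        // Moonshot's rule: 0.2 * sqrt(max(A, B)).
        LrAdjustSketch::MatchRmsAdamW => 0.2 * a.max(b).sqrt(),
    };
    lr * scale
}
```

For a base rate of 0.02, a 4096 × 1024 weight gets 0.02 × sqrt(4) = 0.04 under Original while a 1024 × 4096 weight keeps 0.02; MatchRmsAdamW gives 0.02 × 0.2 × sqrt(4096) = 0.256 for either orientation, since it only looks at the larger dimension.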

§Example

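use burn_optim::{MuonConfig, AdjustLrFn};

// `init` is generic over the autodiff backend `B` and the module `M`;
// both are normally inferred from how the optimizer is used, or can be
// given explicitly with `init::<B, M>()`.

// Using the default (Original) method
let optimizer = MuonConfig::new().init();

// Using MatchRmsAdamW for AdamW-compatible hyperparameters
let optimizer = MuonConfig::new()
    .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
    .init();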

Implementations§

impl MuonConfig

pub fn new() -> Self

Create a new instance of the config.

§Arguments
§Optional Arguments
§weight_decay

Weight decay config.

  • Defaults to None
§Default Arguments
§momentum

Momentum config.

Muon always uses momentum. Default configuration:

  • momentum: 0.95
  • dampening: 0.0
  • nesterov: true
  • Defaults to "MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }"
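The three momentum fields can be read through the common SGD-with-momentum rule: a velocity buffer is updated as `momentum * v + (1 - dampening) * g`, and Nesterov momentum returns `g + momentum * v` instead of `v`. The sketch below assumes burn's `MomentumConfig` follows that rule and that the buffer starts at zero; implementations differ in how they treat the first step, so this is an illustration rather than burn_optim's code. `MomentumSketch` is a hypothetical name. In Keller Jordan's reference Muon, the returned direction is what the Newton-Schulz step then orthogonalizes.

```rust
/// Hypothetical momentum transform illustrating `momentum`, `dampening`
/// and `nesterov`; an assumption, not burn_optim's implementation.
struct MomentumSketch {
    momentum: f32,
    dampening: f32,
    nesterov: bool,
    velocity: Vec<f32>, // zero-initialized buffer, one entry per element
}

impl MomentumSketch {
    fn new(momentum: f32, dampening: f32, nesterov: bool, len: usize) -> Self {
        MomentumSketch { momentum, dampening, nesterov, velocity: vec![0.0; len] }
    }

    /// Updates the velocity with `grad` and returns the momentum direction.
    fn transform(&mut self, grad: &[f32]) -> Vec<f32> {
        assert_eq!(grad.len(), self.velocity.len());
        let (m, d, nesterov) = (self.momentum, self.dampening, self.nesterov);
        self.velocity
            .iter_mut()
            .zip(grad)
            .map(|(v, &g)| {
                // v <- momentum * v + (1 - dampening) * g
                *v = m * *v + (1.0 - d) * g;
                // Nesterov adds a look-ahead along the updated velocity.
                if nesterov { g + m * *v } else { *v }
            })
            .collect()
    }
}
```

With Muon's defaults (0.95, 0.0, true) and a constant gradient of 1.0, the first call returns 1.95 and the second returns 1 + 0.95 × 1.95 = 2.8525.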
§ns_coefficients

Newton-Schulz iteration coefficients (a, b, c).

These coefficients are selected to maximize the slope at zero for the quintic iteration. Default values are from Keller Jordan’s implementation.

  • Defaults to "(3.4445, -4.775, 2.0315)"
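A short derivation, added here for illustration, shows what "slope at zero" refers to. Writing the normalized matrix as X = UΣVᵀ, one quintic step keeps the singular vectors and applies the same odd polynomial φ to every singular value, so the linear coefficient a is the slope at zero:

```latex
X_{k+1} = a X_k + b (X_k X_k^\top) X_k + c (X_k X_k^\top)^2 X_k
        = U \, \varphi(\Sigma_k) \, V^\top,
\qquad \varphi(\sigma) = a\sigma + b\sigma^3 + c\sigma^5,
\qquad \varphi'(0) = a = 3.4445,
\qquad \varphi(1) = 3.4445 - 4.775 + 2.0315 = 0.701.
```

Because φ(1) ≠ 1, the defaults trade exact convergence for fast growth of small singular values: after a few steps the singular values sit in a band around 1 rather than converging to exactly 1.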
§epsilon

Epsilon for numerical stability.

  • Defaults to 1e-7
§ns_steps

Number of Newton-Schulz iteration steps.

  • Defaults to 5
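The iteration that `ns_coefficients`, `ns_steps` and `epsilon` parameterize can be sketched without dependencies on a plain `Vec<Vec<f32>>` matrix. The structure (normalize by the Frobenius norm plus epsilon, transpose tall matrices, then apply X ← aX + (bA + cA²)X with A = XXᵀ for `ns_steps` iterations) follows Keller Jordan's public reference code; the function names are hypothetical, and burn_optim's tensor-based implementation may differ in details.

```rust
type Matrix = Vec<Vec<f32>>;

fn transpose(m: &Matrix) -> Matrix {
    let (rows, cols) = (m.len(), m[0].len());
    (0..cols).map(|j| (0..rows).map(|i| m[i][j]).collect()).collect()
}

fn matmul(a: &Matrix, b: &Matrix) -> Matrix {
    let (n, k, p) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0.0f32; p]; n];
    for i in 0..n {
        for t in 0..k {
            for j in 0..p {
                out[i][j] += a[i][t] * b[t][j];
            }
        }
    }
    out
}

/// Element-wise `alpha * x + beta * y` for equally shaped matrices.
fn axpby(alpha: f32, x: &Matrix, beta: f32, y: &Matrix) -> Matrix {
    x.iter()
        .zip(y)
        .map(|(rx, ry)| rx.iter().zip(ry).map(|(p, q)| alpha * p + beta * q).collect())
        .collect()
}

/// Hypothetical quintic Newton-Schulz sketch approximating the orthogonal
/// factor of `g` (not burn_optim's API).
fn newton_schulz(g: &Matrix, coeffs: (f32, f32, f32), ns_steps: usize, epsilon: f32) -> Matrix {
    let (a, b, c) = coeffs;
    // Frobenius norm >= spectral norm, so this scales singular values to <= 1.
    let norm = g.iter().flat_map(|row| row.iter()).map(|v| v * v).sum::<f32>().sqrt();
    let mut x: Matrix = g.iter().map(|row| row.iter().map(|v| v / (norm + epsilon)).collect()).collect();
    // Iterate on the wide orientation so X X^T is the smaller Gram matrix.
    let tall = x.len() > x[0].len();
    if tall {
        x = transpose(&x);
    }
    for _ in 0..ns_steps {
        let gram = matmul(&x, &transpose(&x)); // A = X X^T
        let gram_sq = matmul(&gram, &gram); // A^2
        let poly = axpby(b, &gram, c, &gram_sq); // B = bA + cA^2
        x = axpby(a, &x, 1.0, &matmul(&poly, &x)); // X = aX + BX
    }
    if tall {
        x = transpose(&x);
    }
    x
}
```

With the default coefficients and five steps, a diagonal input such as diag(3, 1) stays diagonal and both entries land in a band around 1 (roughly 0.7 to 1.2) rather than at exactly 1; a tall input keeps its original shape.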
§adjust_lr_fn

Learning rate adjustment method.

Controls how the learning rate is adjusted based on parameter shape. See AdjustLrFn for available methods.

  • Defaults to "AdjustLrFn::Original"
impl MuonConfig

pub fn with_momentum(self, momentum: MomentumConfig) -> Self

Sets the value for the field momentum.

Momentum config.

Muon always uses momentum. Default configuration:

  • momentum: 0.95
  • dampening: 0.0
  • nesterov: true
  • Defaults to "MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }"

pub fn with_ns_coefficients(self, ns_coefficients: (f32, f32, f32)) -> Self

Sets the value for the field ns_coefficients.

Newton-Schulz iteration coefficients (a, b, c).

These coefficients are selected to maximize the slope at zero for the quintic iteration. Default values are from Keller Jordan’s implementation.

  • Defaults to "(3.4445, -4.775, 2.0315)"

pub fn with_epsilon(self, epsilon: f32) -> Self

Sets the value for the field epsilon.

Epsilon for numerical stability.

  • Defaults to 1e-7

pub fn with_ns_steps(self, ns_steps: usize) -> Self

Sets the value for the field ns_steps.

Number of Newton-Schulz iteration steps.

  • Defaults to 5

pub fn with_adjust_lr_fn(self, adjust_lr_fn: AdjustLrFn) -> Self

Sets the value for the field adjust_lr_fn.

Learning rate adjustment method.

Controls how the learning rate is adjusted based on parameter shape. See AdjustLrFn for available methods.

  • Defaults to "AdjustLrFn::Original"

pub fn with_weight_decay(self, weight_decay: Option<WeightDecayConfig>) -> Self

Sets the value for the field weight_decay.

Weight decay config.

  • Defaults to None
impl MuonConfig

pub fn build<B: Backend>(&self) -> Muon<B>

Build a Muon from the config.

pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B>

Initialize Muon optimizer.

§Returns

Returns an optimizer adaptor that can be used to optimize a module.

§Example
use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig, momentum::MomentumConfig};

// Basic configuration with default (Original) LR adjustment
let optimizer = MuonConfig::new()
    .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
    .init();

// With AdamW-compatible settings using MatchRmsAdamW
let optimizer = MuonConfig::new()
    .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
    .with_weight_decay(Some(WeightDecayConfig::new(0.1)))
    .init();

// Custom momentum and NS settings
let optimizer = MuonConfig::new()
    .with_momentum(MomentumConfig {
        momentum: 0.9,
        dampening: 0.1,
        nesterov: false,
    })
    .with_ns_steps(7)
    .init();

Trait Implementations§

impl Clone for MuonConfig

fn clone(&self) -> Self

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.
impl Config for MuonConfig

fn save<P>(&self, file: P) -> Result<(), Error>
where P: AsRef<Path>,

Available on crate feature std only.
Saves the configuration to a file.

fn load<P>(file: P) -> Result<Self, ConfigError>
where P: AsRef<Path>,

Available on crate feature std only.
Loads the configuration from a file.

fn load_binary(data: &[u8]) -> Result<Self, ConfigError>

Loads the configuration from a binary buffer.
impl Debug for MuonConfig

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.
impl<'de> Deserialize<'de> for MuonConfig

fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de>,

Deserialize this value from the given Serde deserializer.
impl Display for MuonConfig

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.
impl Serialize for MuonConfig

fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer,

Serialize this value into the given Serde serializer.

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.
impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.
impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.
impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.
impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.
impl<T> ToString for T
where T: Display + ?Sized,

fn to_string(&self) -> String

Converts the given value to a String.
impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
impl<T> DeserializeOwned for T
where T: for<'de> Deserialize<'de>,