pub struct MuonConfig { /* private fields */ }
Muon configuration.
Muon is an optimizer specifically designed for 2D parameters of neural network hidden layers (weight matrices). Other parameters such as biases and embeddings should be optimized using a standard method such as AdamW.
§Learning Rate Adjustment
Muon adjusts the learning rate based on parameter shape to maintain consistent RMS across rectangular matrices. Two methods are available:
- Original: Uses sqrt(max(1, A/B)) where A and B are the first two dimensions. This is Keller Jordan’s method and is the default.
- MatchRmsAdamW: Uses 0.2 * sqrt(max(A, B)). This is Moonshot’s method, designed to match AdamW’s RMS and allow direct reuse of AdamW hyperparameters.
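The two formulas above can be sketched directly. The helper names below are illustrative only and are not part of the burn_optim API:

```rust
/// Keller Jordan's method (the `Original` default):
/// lr * sqrt(max(1, A/B)) for an A x B weight matrix.
fn adjust_lr_original(lr: f64, a: usize, b: usize) -> f64 {
    lr * (a as f64 / b as f64).max(1.0).sqrt()
}

/// Moonshot's method (`MatchRmsAdamW`): lr * 0.2 * sqrt(max(A, B)).
fn adjust_lr_match_rms_adamw(lr: f64, a: usize, b: usize) -> f64 {
    lr * 0.2 * (a.max(b) as f64).sqrt()
}

fn main() {
    // A 1024 x 256 weight matrix: A = 1024, B = 256.
    let (a, b) = (1024, 256);
    // Original: 0.02 * sqrt(1024/256) = 0.02 * 2
    println!("{}", adjust_lr_original(0.02, a, b));
    // MatchRmsAdamW: 0.001 * 0.2 * sqrt(1024) = 0.001 * 0.2 * 32
    println!("{}", adjust_lr_match_rms_adamw(0.001, a, b));
}
```

Note that `Original` only scales the rate up for tall matrices (the factor is clamped at 1), while `MatchRmsAdamW` always scales with the larger dimension so that an AdamW learning rate can be reused unchanged.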
§Example
use burn_optim::{MuonConfig, AdjustLrFn};

// Using default (Original) method
let optimizer = MuonConfig::new().init();

// Using MatchRmsAdamW for AdamW-compatible hyperparameters
let optimizer = MuonConfig::new()
    .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
    .init();

§References
§Implementations

impl MuonConfig

pub fn new() -> Self
Create a new instance of the config.
§Arguments
§Optional Arguments
§weight_decay
Weight decay config.
- Defaults to None
§Default Arguments
§momentum
Momentum config.
Muon always uses momentum. Default configuration:
- momentum: 0.95
- dampening: 0.0
- nesterov: true
- Defaults to MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }
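The default configuration above can be illustrated with a scalar sketch of an SGD-style momentum buffer with dampening and Nesterov look-ahead. This mirrors the common PyTorch-style update rule as an assumption about the semantics of these fields, not the crate's internal code:

```rust
// Scalar sketch of a momentum buffer matching the defaults above
// (momentum = 0.95, dampening = 0.0, nesterov = true). Illustrative only.
struct Momentum {
    momentum: f64,
    dampening: f64,
    nesterov: bool,
    buf: f64,
}

impl Momentum {
    fn update(&mut self, grad: f64) -> f64 {
        // buf <- momentum * buf + (1 - dampening) * grad
        self.buf = self.momentum * self.buf + (1.0 - self.dampening) * grad;
        if self.nesterov {
            // Nesterov: look one step ahead through the buffer.
            grad + self.momentum * self.buf
        } else {
            self.buf
        }
    }
}

fn main() {
    let mut m = Momentum { momentum: 0.95, dampening: 0.0, nesterov: true, buf: 0.0 };
    let g1 = m.update(1.0); // buf = 1.0, returns 1.0 + 0.95 * 1.0  = 1.95
    let g2 = m.update(1.0); // buf = 1.95, returns 1.0 + 0.95 * 1.95 ≈ 2.8525
    println!("{g1} {g2}");
}
```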
§ns_coefficients
Newton-Schulz iteration coefficients (a, b, c).
These coefficients are selected to maximize the slope at zero for the quintic iteration. Default values are from Keller Jordan’s implementation.
- Defaults to (3.4445, -4.775, 2.0315)
§epsilon
Epsilon for numerical stability.
- Defaults to 1e-7
§ns_steps
Number of Newton-Schulz iteration steps.
- Defaults to 5
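The effect of the default coefficients and step count can be sketched on a single singular value. Newton-Schulz orthogonalization repeatedly applies the odd quintic p(x) = a·x + b·x³ + c·x⁵ to each singular value of the norm-scaled gradient, pushing it toward 1; with Keller Jordan’s coefficients the iterates oscillate near 1 rather than converging exactly, which is sufficient in practice. A scalar sketch (assumed behavior, not crate internals):

```rust
// Scalar sketch: how the quintic Newton-Schulz iteration acts on one
// singular value x of the normalized gradient. Illustrative only.
const NS_COEFFICIENTS: (f64, f64, f64) = (3.4445, -4.775, 2.0315);
const NS_STEPS: usize = 5;

fn ns_quintic(x: f64) -> f64 {
    let (a, b, c) = NS_COEFFICIENTS;
    a * x + b * x.powi(3) + c * x.powi(5)
}

fn main() {
    // A small singular value is pushed toward 1 within a few steps.
    let mut x: f64 = 0.1;
    for step in 1..=NS_STEPS {
        x = ns_quintic(x);
        println!("step {step}: {x:.4}");
    }
    // With these coefficients the iterates end up oscillating near 1
    // (roughly within [0.7, 1.1]) instead of converging exactly.
}
```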
§adjust_lr_fn
Learning rate adjustment method.
Controls how the learning rate is adjusted based on parameter shape.
See AdjustLrFn for available methods.
- Defaults to AdjustLrFn::Original
impl MuonConfig

pub fn with_momentum(self, momentum: MomentumConfig) -> Self

pub fn with_ns_coefficients(self, ns_coefficients: (f32, f32, f32)) -> Self
Sets the value for the field ns_coefficients.
Newton-Schulz iteration coefficients (a, b, c).
These coefficients are selected to maximize the slope at zero for the quintic iteration. Default values are from Keller Jordan’s implementation.
- Defaults to (3.4445, -4.775, 2.0315)
pub fn with_epsilon(self, epsilon: f32) -> Self

pub fn with_ns_steps(self, ns_steps: usize) -> Self

pub fn with_adjust_lr_fn(self, adjust_lr_fn: AdjustLrFn) -> Self
Sets the value for the field adjust_lr_fn.
Learning rate adjustment method.
Controls how the learning rate is adjusted based on parameter shape.
See AdjustLrFn for available methods.
- Defaults to AdjustLrFn::Original
pub fn with_weight_decay(self, weight_decay: Option<WeightDecayConfig>) -> Self

impl MuonConfig

pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
    &self,
) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B>
Initialize Muon optimizer.
§Returns
Returns an optimizer adaptor that can be used to optimize a module.
§Example
use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig, momentum::MomentumConfig};

// Basic configuration with default (Original) LR adjustment
let optimizer = MuonConfig::new()
    .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
    .init();

// With AdamW-compatible settings using MatchRmsAdamW
let optimizer = MuonConfig::new()
    .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
    .with_weight_decay(Some(WeightDecayConfig::new(0.1)))
    .init();

// Custom momentum and Newton-Schulz settings
let optimizer = MuonConfig::new()
    .with_momentum(MomentumConfig {
        momentum: 0.9,
        dampening: 0.1,
        nesterov: false,
    })
    .with_ns_steps(7)
    .init();

§Trait Implementations
impl Clone for MuonConfig

impl Config for MuonConfig

fn save<P>(&self, file: P) -> Result<(), Error>

Available on crate feature std only.

fn load<P>(file: P) -> Result<Self, ConfigError>

Available on crate feature std only.