burn_optim/optim/muon.rs

use burn_core as burn;

use burn::{module::AutodiffModule, record::Record};

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use serde::{Deserialize, Serialize};

use super::{
    SimpleOptimizer,
    adaptor::OptimizerAdaptor,
    decay::WeightDecayConfig,
    momentum::{Momentum, MomentumConfig, MomentumState},
};
use crate::LearningRate;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Learning rate adjustment method for Muon optimizer.
///
/// Muon adjusts the learning rate based on parameter shape to maintain consistent
/// RMS across rectangular matrices.
///
/// # References
///
/// - Original: [Muon: An optimizer for hidden layers](https://kellerjordan.github.io/posts/muon/)
/// - Moonshot: [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982)
#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AdjustLrFn {
    /// Keller Jordan's original method: `lr * sqrt(max(1, A/B))`
    ///
    /// This scales the learning rate based on the aspect ratio of the weight matrix,
    /// ensuring that tall matrices (more rows than columns) get proportionally larger
    /// learning rates.
    ///
    /// # Example
    ///
    /// For a `[1024, 512]` matrix: `lr * sqrt(1024/512) = lr * 1.414`
    #[default]
    Original,

    /// Moonshot's method: `lr * 0.2 * sqrt(max(A, B))`
    ///
    /// This method is designed to match AdamW's RMS, allowing Muon to directly reuse
    /// learning rates and weight decay values tuned for AdamW without retuning.
    ///
    /// # Example
    ///
    /// For a `[1024, 512]` matrix: `lr * 0.2 * sqrt(1024) = lr * 6.4`
    MatchRmsAdamW,
}

impl AdjustLrFn {
    /// Calculate the learning rate adjustment ratio for a given parameter shape.
    ///
    /// # Arguments
    ///
    /// * `shape` - Parameter shape (uses first two dimensions)
    ///
    /// # Returns
    ///
    /// Adjustment ratio to multiply with the base learning rate
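    ///
    /// # Example
    ///
    /// A minimal sketch of the two methods applied to a `[1024, 512]` shape,
    /// following the formulas above (illustrative only):
    ///
    /// ```ignore
    /// // Original: sqrt(max(1, 1024/512)) = sqrt(2) ≈ 1.414
    /// let ratio = AdjustLrFn::Original.adjustment_ratio(&[1024, 512]);
    ///
    /// // MatchRmsAdamW: 0.2 * sqrt(max(1024, 512)) = 0.2 * 32 = 6.4
    /// let ratio = AdjustLrFn::MatchRmsAdamW.adjustment_ratio(&[1024, 512]);
    /// ```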
    fn adjustment_ratio(&self, shape: &[usize]) -> f64 {
        if shape.len() < 2 {
            return 1.0;
        }

        let a = shape[0] as f64;
        let b = shape[1] as f64;

        match self {
            Self::Original => {
                // sqrt(max(1, A/B))
                let ratio = a / b;
                ratio.max(1.0).sqrt()
            }
            Self::MatchRmsAdamW => {
                // 0.2 * sqrt(max(A, B))
                0.2 * a.max(b).sqrt()
            }
        }
    }
}

/// Muon configuration.
///
/// Muon is an optimizer specifically designed for 2D parameters of neural network
/// hidden layers (weight matrices). Other parameters such as biases and embeddings
/// should be optimized using a standard method such as AdamW.
///
/// # Learning Rate Adjustment
///
/// Muon adjusts the learning rate based on parameter shape to maintain consistent
/// RMS across rectangular matrices. Two methods are available:
///
/// - **Original**: Uses `sqrt(max(1, A/B))` where A and B are the first two dimensions.
///   This is Keller Jordan's method and is the default.
///
/// - **MatchRmsAdamW**: Uses `0.2 * sqrt(max(A, B))`. This is Moonshot's method
///   designed to match AdamW's RMS, allowing direct reuse of AdamW hyperparameters.
///
/// # Example
///
/// ```ignore
/// use burn_optim::{MuonConfig, AdjustLrFn};
///
/// // Using default (Original) method
/// let optimizer = MuonConfig::new().init();
///
/// // Using MatchRmsAdamW for AdamW-compatible hyperparameters
/// let optimizer = MuonConfig::new()
///     .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
///     .init();
/// ```
///
/// # References
///
/// - [Muon: An optimizer for hidden layers in neural networks](https://kellerjordan.github.io/posts/muon/)
/// - [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982)
/// - [PyTorch Implementation](https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py)
/// - [Original Implementation](https://github.com/KellerJordan/Muon)
#[derive(Config, Debug)]
pub struct MuonConfig {
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,

    /// [Momentum](MomentumConfig) config.
    ///
    /// Muon always uses momentum. Default configuration:
    /// - momentum: 0.95
    /// - dampening: 0.0
    /// - nesterov: true
    #[config(default = "MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }")]
    momentum: MomentumConfig,

    /// Newton-Schulz iteration coefficients (a, b, c).
    ///
    /// These coefficients are selected to maximize the slope at zero for the
    /// quintic iteration. Default values are from Keller Jordan's implementation.
    #[config(default = "(3.4445, -4.775, 2.0315)")]
    ns_coefficients: (f32, f32, f32),

    /// Epsilon for numerical stability.
    #[config(default = 1e-7)]
    epsilon: f32,

    /// Number of Newton-Schulz iteration steps.
    #[config(default = 5)]
    ns_steps: usize,

    /// Learning rate adjustment method.
    ///
    /// Controls how the learning rate is adjusted based on parameter shape.
    /// See [`AdjustLrFn`] for available methods.
    #[config(default = "AdjustLrFn::Original")]
    adjust_lr_fn: AdjustLrFn,
}

impl MuonConfig {
    /// Initialize Muon optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer adaptor that can be used to optimize a module.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig};
    ///
    /// // Basic configuration with default (Original) LR adjustment
    /// let optimizer = MuonConfig::new()
    ///     .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
    ///     .init();
    ///
    /// // With AdamW-compatible settings using MatchRmsAdamW
    /// let optimizer = MuonConfig::new()
    ///     .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
    ///     .with_weight_decay(Some(WeightDecayConfig::new(0.1)))
    ///     .init();
    ///
    /// // Custom momentum and NS settings
    /// let optimizer = MuonConfig::new()
    ///     .with_momentum(MomentumConfig {
    ///         momentum: 0.9,
    ///         dampening: 0.1,
    ///         nesterov: false,
    ///     })
    ///     .with_ns_steps(7)
    ///     .init();
    /// ```
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B> {
        let momentum = Momentum::new(&self.momentum);
        let weight_decay_penalty = self.weight_decay.as_ref().map(|wd| wd.penalty);

        let optim = Muon {
            momentum,
            ns_params: NewtonSchulzParams::new(self.ns_coefficients, self.ns_steps),
            weight_decay_penalty,
            epsilon: self.epsilon,
            adjust_lr_fn: self.adjust_lr_fn,
        };

        OptimizerAdaptor::from(optim)
    }
}

/// Parameters for Newton-Schulz orthogonalization.
#[derive(Clone, Copy)]
struct NewtonSchulzParams {
    a: f32,
    b: f32,
    c: f32,
    steps: usize,
}

impl NewtonSchulzParams {
    fn new(coefficients: (f32, f32, f32), steps: usize) -> Self {
        Self {
            a: coefficients.0,
            b: coefficients.1,
            c: coefficients.2,
            steps,
        }
    }
}

/// Muon optimizer.
///
/// Muon internally runs standard SGD with momentum, and then performs an orthogonalization
/// post-processing step, in which each 2D parameter's update is replaced with the
/// nearest semi-orthogonal matrix. For efficient orthogonalization we use a Newton-Schulz
/// iteration, which has the advantage that it can be run stably in bfloat16 on the GPU.

/// # Important Notes
///
/// 1. **Only for 2D parameters**: Muon is designed for weight matrices; this
///    implementation requires exactly 2D tensors. Use AdamW or SGD for biases,
///    embeddings, and layer norms.
///
/// 2. **Learning rate adjustment**: Muon automatically adjusts the learning rate based
///    on parameter shape. See [`AdjustLrFn`] for details.
///
/// 3. **Weight decay timing**: Unlike typical optimizers, Muon applies weight decay
///    AFTER orthogonalization but uses the original (unadjusted) learning rate for it.
#[derive(Clone)]
pub struct Muon<B: Backend> {
    momentum: Momentum<B>,
    ns_params: NewtonSchulzParams,
    weight_decay_penalty: Option<f32>,
    epsilon: f32,
    adjust_lr_fn: AdjustLrFn,
}

impl<B: Backend> Muon<B> {
    /// Adjust learning rate based on parameter shape.
    ///
    /// # Arguments
    ///
    /// * `lr` - Base learning rate
    /// * `shape` - Parameter shape (uses first two dimensions)
    ///
    /// # Returns
    ///
    /// Adjusted learning rate
    ///
    /// ```ignore
    /// // For a [1024, 512] weight matrix with lr=0.01:
    /// // Original: 0.01 * sqrt(1024/512) = 0.01 * 1.414 = 0.01414
    /// // MatchRmsAdamW: 0.01 * 0.2 * sqrt(1024) = 0.01 * 0.2 * 32 = 0.064
    /// ```
    fn adjust_lr(&self, lr: LearningRate, shape: &[usize]) -> LearningRate {
        lr * self.adjust_lr_fn.adjustment_ratio(shape)
    }

    /// Perform Newton-Schulz orthogonalization on a gradient tensor.
    ///
    /// This computes the zeroth power (orthogonalization) of the input matrix G
    /// using a quintic Newton-Schulz iteration.
    ///
    /// # Algorithm
    ///
    /// 1. Transpose if tall matrix (A > B)
    /// 2. Normalize: X = X / ||X||
    /// 3. For k steps:
    ///    - A = X @ X^T
    ///    - B = b*A + c*A^2
    ///    - X = a*X + B@X
    /// 4. Transpose back if needed
    ///
    /// # References
    ///
    /// - Original: <https://github.com/KellerJordan/Muon/blob/master/muon.py>
    /// - PyTorch: <https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py>
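    ///
    /// # Example
    ///
    /// A sketch of the intended effect (illustrative only): for an input with
    /// singular value decomposition `G = U * S * V^T`, the iteration drives the
    /// output towards the semi-orthogonal matrix `U * V^T`, so the product with
    /// its own transpose is approximately the identity:
    ///
    /// ```ignore
    /// let o = muon.zeropower_via_newtonschulz(g);
    /// // For a square or wide matrix g, o @ o^T ≈ I
    /// let approx_identity = o.clone().matmul(o.swap_dims(0, 1));
    /// ```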
    fn zeropower_via_newtonschulz<const D: usize>(&self, g: Tensor<B, D>) -> Tensor<B, D> {
        let shape = g.shape();
        let dim_m2 = shape[D - 2];
        let dim_m1 = shape[D - 1];

        // Step 1: Transpose if tall matrix (more rows than columns)
        let (mut x, needs_transpose) = if dim_m2 > dim_m1 {
            (g.swap_dims(D - 2, D - 1), true)
        } else {
            (g, false)
        };

        // Step 2: Normalize by Frobenius norm
        // X = X / max(||X||, epsilon)
        let norm = x
            .clone()
            .powf_scalar(2.0)
            .sum()
            .sqrt()
            .clamp_min(self.epsilon)
            .unsqueeze();

        x = x.div(norm);

        // Step 3: Newton-Schulz iteration
        // This is the quintic iteration with coefficients (a, b, c)
        let NewtonSchulzParams { a, b, c, steps } = self.ns_params;

        for _ in 0..steps {
            // A = X @ X^T
            let x_t = x.clone().swap_dims(D - 2, D - 1);
            let a_matrix = x.clone().matmul(x_t);

            // B = b*A + c*A@A
            let a_squared = a_matrix.clone().matmul(a_matrix.clone());
            let b_matrix = a_matrix.mul_scalar(b).add(a_squared.mul_scalar(c));

            // X = a*X + B@X
            x = x.clone().mul_scalar(a).add(b_matrix.matmul(x.clone()));
        }

        // Step 4: Restore transpose if it was a tall matrix
        if needs_transpose {
            x = x.swap_dims(D - 2, D - 1);
        }

        x
    }
}

/// Muon state.
#[derive(Record, Clone, new)]
pub struct MuonState<B: Backend, const D: usize> {
    /// Current momentum state
    pub momentum: MomentumState<B, D>,
}
impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
    type State<const D: usize> = MuonState<B, D>;

    /// Perform a single Muon optimization step.
    ///
    /// # Algorithm
    ///
    /// 1. Apply momentum to gradient
    /// 2. Orthogonalize update via Newton-Schulz
    /// 3. Adjust learning rate based on parameter shape
    /// 4. Apply weight decay (using original lr)
    /// 5. Update parameter (using adjusted lr)
    ///
    /// # Notes
    ///
    /// Unlike typical optimizers, the weight decay and parameter update use
    /// different learning rates:
    /// - Weight decay uses the original `lr`
    /// - Parameter update uses the shape-adjusted `lr`
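    ///
    /// In equation form, a sketch of what this step computes (`wd` denotes the
    /// optional weight decay penalty):
    ///
    /// ```ignore
    /// // m      = momentum(grad, state)       // Step 1
    /// // update = newton_schulz(m)            // Step 2
    /// // lr_adj = adjust_lr(lr, p.shape())    // Step 3
    /// // p      = p * (1 - lr * wd)           // Step 4 (original lr)
    /// // p      = p - lr_adj * update         // Step 5 (adjusted lr)
    /// ```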
    ///
    /// # Panics
    /// This function will panic if the input tensors are not 2D.
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        assert!(
            D == 2,
            "Newton-Schulz iteration requires 2D tensors, got {}D",
            D
        );

        // Step 1: Apply momentum
        let state_momentum = state.map(|s| s.momentum);
        let (grad, new_momentum_state) = self.momentum.transform(grad, state_momentum);

        // Step 2: Orthogonalize via Newton-Schulz
        let update = self.zeropower_via_newtonschulz(grad);

        // Step 3: Adjust learning rate based on parameter shape
        let adjusted_lr = self.adjust_lr(lr, &tensor.shape());

        // Step 4: Apply weight decay (using ORIGINAL lr, not adjusted)
        // Muon applies weight decay AFTER orthogonalization
        let tensor = if let Some(penalty) = self.weight_decay_penalty {
            let decay_factor = 1.0 - lr * penalty as f64;
            tensor.mul_scalar(decay_factor)
        } else {
            tensor
        };

        // Step 5: Update parameter (using ADJUSTED lr)
        let delta = update.mul_scalar(adjusted_lr);
        let new_state = MuonState::new(new_momentum_state);

        (tensor - delta, Some(new_state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type TestBackend = burn_ndarray::NdArray<f32>;

    const TOLERANCE: f64 = 1e-8;

    fn given_linear_layer_no_bias(weight: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: None, // No bias for Muon optimizer
        };

        LinearConfig::new(4, 4)
            .with_bias(false)
            .init(&device)
            .load_record(record)
    }

    #[test]
    fn test_adjust_lr_fn_original() {
        let method = AdjustLrFn::Original;

        // Square matrix [512, 512] -> sqrt(1) = 1.0
        let ratio = method.adjustment_ratio(&[512, 512]);
        assert!((ratio - 1.0).abs() < TOLERANCE);

        // Tall matrix [1024, 512] -> sqrt(2) ≈ 1.414
        let ratio = method.adjustment_ratio(&[1024, 512]);
        let expected = (2.0f64).sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);

        // Wide matrix [512, 1024] -> sqrt(max(1, 0.5)) = 1.0
        let ratio = method.adjustment_ratio(&[512, 1024]);
        assert!((ratio - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_adjust_lr_fn_match_rms_adamw() {
        let method = AdjustLrFn::MatchRmsAdamW;

        // [1024, 512] -> 0.2 * sqrt(1024) = 6.4
        let ratio = method.adjustment_ratio(&[1024, 512]);
        let expected = 0.2 * 1024.0f64.sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);

        // [512, 512] -> 0.2 * sqrt(512) ≈ 4.525
        let ratio = method.adjustment_ratio(&[512, 512]);
        let expected = 0.2 * 512.0f64.sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);
    }

    #[test]
    #[should_panic(expected = "Newton-Schulz iteration requires 2D tensors, got 1D")]
    fn test_1d_tensor_panics() {
        let device = Default::default();
        let config = MuonConfig::new();
        let optim: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let tensor_1d = Tensor::<TestBackend, 1>::zeros([512], &device);
        let grad_1d = Tensor::<TestBackend, 1>::ones([512], &device);

        let _ = optim.step(0.01, tensor_1d, grad_1d, None);
    }

    #[test]
    fn test_muon_optimizer_save_load_state() {
        let device = Default::default();
        // Use Linear layer WITHOUT bias for Muon optimizer
        let linear = LinearConfig::new(6, 6)
            .with_bias(false) // No bias - only 2D weight matrix
            .init::<TestAutodiffBackend>(&device);

        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);

        let mut optimizer =
            MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(0.01, linear, grads);

        let state_before = optimizer.to_record();
        let state_before_copy = optimizer.to_record();

        let optimizer_new =
            MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
        let optimizer_loaded = optimizer_new.load_record(state_before_copy);
        let state_after = optimizer_loaded.to_record();

        assert_eq!(state_before.len(), state_after.len());
    }

    #[test]
    fn test_muon_with_weight_decay() {
        let device = Default::default();
        // Create Linear layer WITHOUT bias for Muon
        let linear = given_linear_layer_no_bias(TensorData::from([
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
        ]));

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]],
            &device,
        )
        .require_grad();

        let mut optimizer = MuonConfig::new()
            .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
            .init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(0.01, linear, grads);

        let state = linear.into_record();
        let weight = state.weight.to_data();

        for val in weight.as_slice::<f32>().unwrap() {
            assert!(
                *val < 1.0,
                "Weight should be reduced by weight decay, got {}",
                val
            );
        }
    }

    #[test]
    fn test_newton_schulz_orthogonalization() {
        let device = Default::default();
        let matrix = Tensor::<TestBackend, 2>::from_floats([[1.0, 0.5], [0.5, 1.0]], &device);

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let orthogonalized = muon.zeropower_via_newtonschulz(matrix);
        let o_t = orthogonalized.clone().transpose();
        let product = orthogonalized.matmul(o_t);

        let data = product.into_data();
        let values = data.as_slice::<f32>().unwrap();

        assert!(
            (values[0] - 1.0).abs() < 0.1,
            "Product[0,0] should be ~1.0, got {}",
            values[0]
        );
        assert!(
            (values[3] - 1.0).abs() < 0.1,
            "Product[1,1] should be ~1.0, got {}",
            values[3]
        );
    }

    #[test]
    fn test_tall_matrix_transpose() {
        // Test that tall matrices (A > B) are transposed during Newton-Schulz iteration
        // and then transposed back
        let device = Default::default();

        // Create a tall matrix: [8, 4] (more rows than columns)
        let tall_matrix = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
                [0.1, 0.2, 0.3, 0.4],
                [0.4, 0.3, 0.2, 0.1],
                [0.2, 0.4, 0.1, 0.3],
                [0.3, 0.1, 0.4, 0.2],
            ],
            &device,
        );

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        // Perform Newton-Schulz orthogonalization
        let orthogonalized = muon.zeropower_via_newtonschulz(tall_matrix.clone());

        // Verify shape is preserved (should be transposed internally but returned in original shape)
        let original_shape = tall_matrix.shape();
        let result_shape = orthogonalized.shape();
        assert_eq!(
            original_shape.dims::<2>(),
            result_shape.dims::<2>(),
            "Shape should be preserved: [8, 4]"
        );

        // Verify output is different from input (orthogonalization happened)
        let original_data = tall_matrix.into_data();
        let result_data = orthogonalized.into_data();
        assert_ne!(
            original_data.as_slice::<f32>().unwrap(),
            result_data.as_slice::<f32>().unwrap(),
            "Orthogonalized matrix should differ from input"
        );

        // For comparison, test a wide matrix [4, 8] should NOT be transposed
        let wide_matrix = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2, 0.1, 0.4, 0.2, 0.3],
                [0.5, 1.0, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5, 0.3, 0.2, 0.1, 0.4],
                [0.2, 0.1, 0.5, 1.0, 0.4, 0.1, 0.3, 0.2],
            ],
            &device,
        );

        let orthogonalized_wide = muon.zeropower_via_newtonschulz(wide_matrix.clone());

        // Verify wide matrix shape is also preserved
        let wide_original_shape = wide_matrix.shape();
        let wide_result_shape = orthogonalized_wide.shape();
        assert_eq!(
            wide_original_shape.dims::<2>(),
            wide_result_shape.dims::<2>(),
            "Wide matrix shape should be preserved: [4, 8]"
        );
    }

    #[test]
    fn test_zero_gradient() {
        // Test that Muon handles zero gradients gracefully
        let device = Default::default();

        let tensor = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
            ],
            &device,
        );

        // Zero gradient - all zeros
        let zero_grad = Tensor::<TestBackend, 2>::zeros([4, 4], &device);

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        // Should not panic or produce NaN
        let (updated_tensor, state) = muon.step(0.01, tensor.clone(), zero_grad, None);

        // Verify state was created
        assert!(state.is_some());

        // With zero gradient and no weight decay, tensor should remain unchanged
        let original_data = tensor.into_data();
        let updated_data = updated_tensor.clone().into_data();

        let original_vals = original_data.as_slice::<f32>().unwrap();
        let updated_vals = updated_data.as_slice::<f32>().unwrap();

        for (orig, upd) in original_vals.iter().zip(updated_vals.iter()) {
            assert!(
                (orig - upd).abs() < 1e-6,
                "With zero gradient, tensor should remain unchanged (or very close)"
            );
        }

        // Verify no NaN values
        for val in updated_vals {
            assert!(
                !val.is_nan(),
                "Result should not contain NaN values with zero gradient"
            );
        }

        // Test with weight decay - should still work
        let muon_with_decay: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: Some(0.01),
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let tensor2 = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
            ],
            &device,
        );
        let zero_grad2 = Tensor::<TestBackend, 2>::zeros([4, 4], &device);

        let (updated_tensor_decay, _) =
            muon_with_decay.step(0.01, tensor2.clone(), zero_grad2, None);

        // With zero gradient but with weight decay, tensor should be slightly reduced
        let updated_decay_data = updated_tensor_decay.into_data();
        let updated_decay_vals = updated_decay_data.as_slice::<f32>().unwrap();

        for val in updated_decay_vals {
            assert!(
                !val.is_nan(),
                "Result should not contain NaN with zero gradient and weight decay"
            );
        }

        // With weight decay, values should be slightly smaller than original
        let original_vals2 = tensor2.into_data().as_slice::<f32>().unwrap().to_vec();
        for (orig, upd) in original_vals2.iter().zip(updated_decay_vals.iter()) {
            if orig.abs() > 1e-6 {
                // Non-zero values should be reduced by weight decay
                assert!(
                    upd.abs() < orig.abs(),
                    "Weight decay should reduce magnitude: original={}, updated={}",
                    orig,
                    upd
                );
            }
        }
    }
}