
burn_optim/optim/muon.rs

use burn_core as burn;

use burn::{module::AutodiffModule, record::Record};

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use serde::{Deserialize, Serialize};

use super::{
    SimpleOptimizer,
    adaptor::OptimizerAdaptor,
    decay::WeightDecayConfig,
    momentum::{Momentum, MomentumConfig, MomentumState},
};
use crate::LearningRate;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Learning rate adjustment method for Muon optimizer.
///
/// Muon adjusts the learning rate based on parameter shape to maintain consistent
/// RMS across rectangular matrices.
///
/// # References
///
/// - Original: [Muon: An optimizer for hidden layers](https://kellerjordan.github.io/posts/muon/)
/// - Moonshot: [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982)
#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AdjustLrFn {
    /// Keller Jordan's original method: `lr * sqrt(max(1, A/B))`
    ///
    /// This scales the learning rate based on the aspect ratio of the weight matrix,
    /// ensuring that tall matrices (more rows than columns) get proportionally larger
    /// learning rates.
    ///
    /// # Example
    ///
    /// For a [1024, 512] matrix: `lr * sqrt(1024/512) = lr * 1.414`
    #[default]
    Original,

    /// Moonshot's method: `lr * 0.2 * sqrt(max(A, B))`
    ///
    /// This method is designed to match AdamW's RMS, allowing Muon to directly reuse
    /// learning rates and weight decay values tuned for AdamW without retuning.
    ///
    /// # Example
    ///
    /// For a [1024, 512] matrix: `lr * 0.2 * sqrt(1024) = lr * 6.4`
    MatchRmsAdamW,
}

impl AdjustLrFn {
    /// Calculate the learning rate adjustment ratio for a given parameter shape.
    ///
    /// # Arguments
    ///
    /// * `shape` - Parameter shape (uses first two dimensions)
    ///
    /// # Returns
    ///
    /// Adjustment ratio to multiply with the base learning rate
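    ///
    /// # Example
    ///
    /// Illustrative values, computed from the formulas below:
    ///
    /// ```ignore
    /// // [1024, 512] with Original      -> sqrt(1024 / 512)  ≈ 1.414
    /// // [1024, 512] with MatchRmsAdamW -> 0.2 * sqrt(1024)  = 6.4
    /// // [512] (fewer than two dims)    -> 1.0 for either method
    /// ```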
    fn adjustment_ratio(&self, shape: &[usize]) -> f64 {
        if shape.len() < 2 {
            return 1.0;
        }

        let a = shape[0] as f64;
        let b = shape[1] as f64;

        match self {
            Self::Original => {
                // sqrt(max(1, A/B))
                let ratio = a / b;
                ratio.max(1.0).sqrt()
            }
            Self::MatchRmsAdamW => {
                // 0.2 * sqrt(max(A, B))
                0.2 * a.max(b).sqrt()
            }
        }
    }
}

/// Muon configuration.
///
/// Muon is an optimizer specifically designed for 2D parameters of neural network
/// hidden layers (weight matrices). Other parameters such as biases and embeddings
/// should be optimized using a standard method such as AdamW.
///
/// # Learning Rate Adjustment
///
/// Muon adjusts the learning rate based on parameter shape to maintain consistent
/// RMS across rectangular matrices. Two methods are available:
///
/// - **Original**: Uses `sqrt(max(1, A/B))` where A and B are the first two dimensions.
///   This is Keller Jordan's method and is the default.
///
/// - **MatchRmsAdamW**: Uses `0.2 * sqrt(max(A, B))`. This is Moonshot's method
///   designed to match AdamW's RMS, allowing direct reuse of AdamW hyperparameters.
///
/// # Example
///
/// ```ignore
/// use burn_optim::{MuonConfig, AdjustLrFn};
///
/// // Using default (Original) method
/// let optimizer = MuonConfig::new().init();
///
/// // Using MatchRmsAdamW for AdamW-compatible hyperparameters
/// let optimizer = MuonConfig::new()
///     .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
///     .init();
/// ```
///
/// # References
///
/// - [Muon: An optimizer for hidden layers in neural networks](https://kellerjordan.github.io/posts/muon/)
/// - [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982)
/// - [PyTorch Implementation](https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py)
/// - [Original Implementation](https://github.com/KellerJordan/Muon)
#[derive(Config, Debug)]
pub struct MuonConfig {
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,

    /// [Momentum](MomentumConfig) config.
    ///
    /// Muon always uses momentum. Default configuration:
    /// - momentum: 0.95
    /// - dampening: 0.0
    /// - nesterov: true
    #[config(default = "MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }")]
    momentum: MomentumConfig,

    /// Newton-Schulz iteration coefficients (a, b, c).
    ///
    /// These coefficients are selected to maximize the slope at zero for the
    /// quintic iteration. Default values are from Keller Jordan's implementation.
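    ///
    /// Each iteration applies the odd quintic `p(x) = a*x + b*x^3 + c*x^5` to the
    /// singular values of the update (`X <- a*X + (b*A + c*A^2) @ X` with `A = X @ X^T`),
    /// so the slope at zero is simply the coefficient `a`.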
    #[config(default = "(3.4445, -4.775, 2.0315)")]
    ns_coefficients: (f32, f32, f32),

    /// Epsilon for numerical stability.
    #[config(default = 1e-7)]
    epsilon: f32,

    /// Number of Newton-Schulz iteration steps.
    #[config(default = 5)]
    ns_steps: usize,

    /// Learning rate adjustment method.
    ///
    /// Controls how the learning rate is adjusted based on parameter shape.
    /// See [`AdjustLrFn`] for available methods.
    #[config(default = "AdjustLrFn::Original")]
    adjust_lr_fn: AdjustLrFn,
}

impl MuonConfig {
    /// Build a [`Muon`] from the config.
    pub fn build<B: Backend>(&self) -> Muon<B> {
        let momentum = Momentum::new(&self.momentum);
        let weight_decay_penalty = self.weight_decay.as_ref().map(|wd| wd.penalty);

        Muon {
            momentum,
            ns_params: NewtonSchulzParams::new(self.ns_coefficients, self.ns_steps),
            weight_decay_penalty,
            epsilon: self.epsilon,
            adjust_lr_fn: self.adjust_lr_fn,
        }
    }

    /// Initialize Muon optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer adaptor that can be used to optimize a module.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig};
    ///
    /// // Basic configuration with default (Original) LR adjustment
    /// let optimizer = MuonConfig::new()
    ///     .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
    ///     .init();
    ///
    /// // With AdamW-compatible settings using MatchRmsAdamW
    /// let optimizer = MuonConfig::new()
    ///     .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
    ///     .with_weight_decay(Some(WeightDecayConfig::new(0.1)))
    ///     .init();
    ///
    /// // Custom momentum and NS settings
    /// let optimizer = MuonConfig::new()
    ///     .with_momentum(MomentumConfig {
    ///         momentum: 0.9,
    ///         dampening: 0.1,
    ///         nesterov: false,
    ///     })
    ///     .with_ns_steps(7)
    ///     .init();
    /// ```
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B> {
        OptimizerAdaptor::from(self.build())
    }
}

/// Parameters for Newton-Schulz orthogonalization.
#[derive(Clone, Copy)]
struct NewtonSchulzParams {
    a: f32,
    b: f32,
    c: f32,
    steps: usize,
}

impl NewtonSchulzParams {
    fn new(coefficients: (f32, f32, f32), steps: usize) -> Self {
        Self {
            a: coefficients.0,
            b: coefficients.1,
            c: coefficients.2,
            steps,
        }
    }
}

/// Muon optimizer.
///
/// Muon internally runs standard SGD-momentum, and then performs an orthogonalization
/// post-processing step, in which each 2D parameter's update is replaced with the
/// nearest semi-orthogonal matrix. For efficient orthogonalization we use a Newton-Schulz
/// iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.
///
/// # Important Notes
///
/// 1. **Only for 2D parameters**: Muon is designed for weight matrices; this
///    implementation panics for tensors of any other rank. Use AdamW or SGD for
///    biases, embeddings, and layer norms.
///
/// 2. **Learning rate adjustment**: Muon automatically adjusts the learning rate based
///    on parameter shape. See [`AdjustLrFn`] for details.
///
/// 3. **Weight decay timing**: Unlike typical optimizers, Muon applies weight decay
///    AFTER orthogonalization but uses the original (unadjusted) learning rate for it.
#[derive(Clone)]
pub struct Muon<B: Backend> {
    momentum: Momentum<B>,
    ns_params: NewtonSchulzParams,
    weight_decay_penalty: Option<f32>,
    epsilon: f32,
    adjust_lr_fn: AdjustLrFn,
}

impl<B: Backend> Muon<B> {
    /// Adjust learning rate based on parameter shape.
    ///
    /// # Arguments
    ///
    /// * `lr` - Base learning rate
    /// * `shape` - Parameter shape (uses first two dimensions)
    ///
    /// # Returns
    ///
    /// Adjusted learning rate
    ///
    /// ```ignore
    /// // For a [1024, 512] weight matrix with lr=0.01:
    /// // Original: 0.01 * sqrt(1024/512) = 0.01 * 1.414 = 0.01414
    /// // MatchRmsAdamW: 0.01 * 0.2 * sqrt(1024) = 0.01 * 0.2 * 32 = 0.064
    /// ```
    fn adjust_lr(&self, lr: LearningRate, shape: &[usize]) -> LearningRate {
        lr * self.adjust_lr_fn.adjustment_ratio(shape)
    }

    /// Perform Newton-Schulz orthogonalization on a gradient tensor.
    ///
    /// This computes the zeroth power (orthogonalization) of the input matrix G
    /// using a quintic Newton-Schulz iteration.
    ///
    /// # Algorithm
    ///
    /// 1. Transpose if tall matrix (A > B)
    /// 2. Normalize: X = X / ||X||
    /// 3. For k steps:
    ///    - A = X @ X^T
    ///    - B = b*A + c*A^2
    ///    - X = a*X + B@X
    /// 4. Transpose back if needed
    ///
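    /// Normalizing by the Frobenius norm in step 2 bounds every singular value by 1
    /// (the spectral norm never exceeds the Frobenius norm), which keeps the quintic
    /// iteration in the range where it drives singular values toward 1.
    ///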
    /// # References
    ///
    /// - Original: https://github.com/KellerJordan/Muon/blob/master/muon.py
    /// - PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py
    fn zeropower_via_newtonschulz<const D: usize>(&self, g: Tensor<B, D>) -> Tensor<B, D> {
        let shape = g.shape();
        let dim_m2 = shape[D - 2];
        let dim_m1 = shape[D - 1];

        // Step 1: Transpose if tall matrix (more rows than columns)
        let (mut x, needs_transpose) = if dim_m2 > dim_m1 {
            (g.swap_dims(D - 2, D - 1), true)
        } else {
            (g, false)
        };

        // Step 2: Normalize by the Frobenius norm
        // X = X / max(||X||, epsilon)
        let norm = x
            .clone()
            .powf_scalar(2.0)
            .sum()
            .sqrt()
            .clamp_min(self.epsilon)
            .unsqueeze();

        x = x.div(norm);

        // Step 3: Newton-Schulz iteration
        // This is the quintic iteration with coefficients (a, b, c)
        let NewtonSchulzParams { a, b, c, steps } = self.ns_params;

        for _ in 0..steps {
            // A = X @ X^T
            let x_t = x.clone().swap_dims(D - 2, D - 1);
            let a_matrix = x.clone().matmul(x_t);

            // B = b*A + c*A@A
            let a_squared = a_matrix.clone().matmul(a_matrix.clone());
            let b_matrix = a_matrix.mul_scalar(b).add(a_squared.mul_scalar(c));

            // X = a*X + B@X
            x = x.clone().mul_scalar(a).add(b_matrix.matmul(x.clone()));
        }

        // Step 4: Restore transpose if it was a tall matrix
        if needs_transpose {
            x = x.swap_dims(D - 2, D - 1);
        }

        x
    }
}

/// Muon state.
#[derive(Record, Clone, new)]
pub struct MuonState<B: Backend, const D: usize> {
    /// Current momentum state
    pub momentum: MomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
    type State<const D: usize> = MuonState<B, D>;

    /// Perform a single Muon optimization step.
    ///
    /// # Algorithm
    ///
    /// 1. Apply momentum to gradient
    /// 2. Orthogonalize update via Newton-Schulz
    /// 3. Adjust learning rate based on parameter shape
    /// 4. Apply weight decay (using original lr)
    /// 5. Update parameter (using adjusted lr)
    ///
    /// # Notes
    ///
    /// Unlike typical optimizers, the weight decay and parameter update use
    /// different learning rates:
    /// - Weight decay uses the original `lr`
    /// - Parameter update uses the shape-adjusted `lr`
    ///
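    /// For example, with `lr = 0.02`, a weight decay penalty of `0.1`, and a
    /// `[1024, 512]` parameter under the default [`AdjustLrFn::Original`] adjustment:
    ///
    /// ```ignore
    /// // decay factor: 1 - lr * penalty      = 1 - 0.002    = 0.998
    /// // adjusted lr:  lr * sqrt(1024 / 512) ≈ 0.02 * 1.414 ≈ 0.0283
    /// // w <- 0.998 * w - 0.0283 * orthogonalized_update
    /// ```
    ///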
    /// # Panics
    /// This function will panic if the input tensors are not 2D.
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        assert!(
            D == 2,
            "Newton-Schulz iteration requires 2D tensors, got {}D",
            D
        );

        // Step 1: Apply momentum
        let state_momentum = state.map(|s| s.momentum);
        let (grad, new_momentum_state) = self.momentum.transform(grad, state_momentum);

        // Step 2: Orthogonalize via Newton-Schulz
        let update = self.zeropower_via_newtonschulz(grad);

        // Step 3: Adjust learning rate based on parameter shape
        let adjusted_lr = self.adjust_lr(lr, &tensor.shape());

        // Step 4: Apply weight decay (using ORIGINAL lr, not adjusted)
        // Muon applies weight decay AFTER orthogonalization
        let tensor = if let Some(penalty) = self.weight_decay_penalty {
            let decay_factor = 1.0 - lr * penalty as f64;
            tensor.mul_scalar(decay_factor)
        } else {
            tensor
        };

        // Step 5: Update parameter (using ADJUSTED lr)
        let delta = update.mul_scalar(adjusted_lr);
        let new_state = MuonState::new(new_momentum_state);

        (tensor - delta, Some(new_state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type TestBackend = burn_flex::Flex;

    const TOLERANCE: f64 = 1e-8;

    fn given_linear_layer_no_bias(weight: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: None, // No bias for Muon optimizer
        };

        LinearConfig::new(4, 4)
            .with_bias(false)
            .init(&device)
            .load_record(record)
    }

    #[test]
    fn test_adjust_lr_fn_original() {
        let method = AdjustLrFn::Original;

        // Square matrix [512, 512] -> sqrt(1) = 1.0
        let ratio = method.adjustment_ratio(&[512, 512]);
        assert!((ratio - 1.0).abs() < TOLERANCE);

        // Tall matrix [1024, 512] -> sqrt(2) ≈ 1.414
        let ratio = method.adjustment_ratio(&[1024, 512]);
        let expected = (2.0f64).sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);

        // Wide matrix [512, 1024] -> sqrt(max(1, 0.5)) = 1.0
        let ratio = method.adjustment_ratio(&[512, 1024]);
        assert!((ratio - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_adjust_lr_fn_match_rms_adamw() {
        let method = AdjustLrFn::MatchRmsAdamW;

        // [1024, 512] -> 0.2 * sqrt(1024) = 6.4
        let ratio = method.adjustment_ratio(&[1024, 512]);
        let expected = 0.2 * 1024.0f64.sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);

        // [512, 512] -> 0.2 * sqrt(512) ≈ 4.525
        let ratio = method.adjustment_ratio(&[512, 512]);
        let expected = 0.2 * 512.0f64.sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);
    }
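
    // A small additional sketch: shapes with fewer than two dimensions are left
    // unscaled by both adjustment methods (see `adjustment_ratio` above).
    #[test]
    fn test_adjust_lr_fn_sub_2d_shapes_unscaled() {
        let ratio = AdjustLrFn::Original.adjustment_ratio(&[512]);
        assert!((ratio - 1.0).abs() < TOLERANCE);

        let ratio = AdjustLrFn::MatchRmsAdamW.adjustment_ratio(&[512]);
        assert!((ratio - 1.0).abs() < TOLERANCE);
    }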

    #[test]
    #[should_panic(expected = "Newton-Schulz iteration requires 2D tensors, got 1D")]
    fn test_1d_tensor_panics() {
        let device = Default::default();
        let config = MuonConfig::new();
        let optim: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let tensor_1d = Tensor::<TestBackend, 1>::zeros([512], &device);
        let grad_1d = Tensor::<TestBackend, 1>::ones([512], &device);

        let _ = optim.step(0.01, tensor_1d, grad_1d, None);
    }

    #[test]
    fn test_muon_optimizer_save_load_state() {
        let device = Default::default();
        // Use Linear layer WITHOUT bias for Muon optimizer
        let linear = LinearConfig::new(6, 6)
            .with_bias(false) // No bias - only 2D weight matrix
            .init::<TestAutodiffBackend>(&device);

        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);

        let mut optimizer =
            MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(0.01, linear, grads);

        let state_before = optimizer.to_record();
        let state_before_copy = optimizer.to_record();

        let optimizer_new =
            MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
        let optimizer_loaded = optimizer_new.load_record(state_before_copy);
        let state_after = optimizer_loaded.to_record();

        assert_eq!(state_before.len(), state_after.len());
    }

    #[test]
    fn test_muon_with_weight_decay() {
        let device = Default::default();
        // Create Linear layer WITHOUT bias for Muon
        let linear = given_linear_layer_no_bias(TensorData::from([
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
        ]));

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]],
            &device,
        )
        .require_grad();

        let mut optimizer = MuonConfig::new()
            .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
            .init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(0.01, linear, grads);

        let state = linear.into_record();
        let weight = state.weight.to_data();

        for val in weight.as_slice::<f32>().unwrap() {
            assert!(
                *val < 1.0,
                "Weight should be reduced by weight decay, got {}",
                val
            );
        }
    }

    #[test]
    fn test_newton_schulz_orthogonalization() {
        let device = Default::default();
        let matrix = Tensor::<TestBackend, 2>::from_floats([[1.0, 0.5], [0.5, 1.0]], &device);

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let orthogonalized = muon.zeropower_via_newtonschulz(matrix);
        let o_t = orthogonalized.clone().transpose();
        let product = orthogonalized.matmul(o_t);

        let data = product.into_data();
        let values = data.as_slice::<f32>().unwrap();

        assert!(
            (values[0] - 1.0).abs() < 0.1,
            "Product[0,0] should be ~1.0, got {}",
            values[0]
        );
        assert!(
            (values[3] - 1.0).abs() < 0.1,
            "Product[1,1] should be ~1.0, got {}",
            values[3]
        );
    }

    #[test]
    fn test_tall_matrix_transpose() {
        // Test that tall matrices (A > B) are transposed during Newton-Schulz iteration
        // and then transposed back
        let device = Default::default();

        // Create a tall matrix: [8, 4] (more rows than columns)
        let tall_matrix = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
                [0.1, 0.2, 0.3, 0.4],
                [0.4, 0.3, 0.2, 0.1],
                [0.2, 0.4, 0.1, 0.3],
                [0.3, 0.1, 0.4, 0.2],
            ],
            &device,
        );

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        // Perform Newton-Schulz orthogonalization
        let orthogonalized = muon.zeropower_via_newtonschulz(tall_matrix.clone());

        // Verify shape is preserved (should be transposed internally but returned in original shape)
        let original_shape = tall_matrix.shape();
        let result_shape = orthogonalized.shape();
        assert_eq!(
            original_shape.dims::<2>(),
            result_shape.dims::<2>(),
            "Shape should be preserved: [8, 4]"
        );

        // Verify output is different from input (orthogonalization happened)
        let original_data = tall_matrix.into_data();
        let result_data = orthogonalized.into_data();
        assert_ne!(
            original_data.as_slice::<f32>().unwrap(),
            result_data.as_slice::<f32>().unwrap(),
            "Orthogonalized matrix should differ from input"
        );

        // For comparison, a wide matrix [4, 8] should NOT be transposed internally
        let wide_matrix = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2, 0.1, 0.4, 0.2, 0.3],
                [0.5, 1.0, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5, 0.3, 0.2, 0.1, 0.4],
                [0.2, 0.1, 0.5, 1.0, 0.4, 0.1, 0.3, 0.2],
            ],
            &device,
        );

        let orthogonalized_wide = muon.zeropower_via_newtonschulz(wide_matrix.clone());

        // Verify wide matrix shape is also preserved
        let wide_original_shape = wide_matrix.shape();
        let wide_result_shape = orthogonalized_wide.shape();
        assert_eq!(
            wide_original_shape.dims::<2>(),
            wide_result_shape.dims::<2>(),
            "Wide matrix shape should be preserved: [4, 8]"
        );
    }

    #[test]
    fn test_zero_gradient() {
        // Test that Muon handles zero gradients gracefully
        let device = Default::default();

        let tensor = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
            ],
            &device,
        );

        // Zero gradient - all zeros
        let zero_grad = Tensor::<TestBackend, 2>::zeros([4, 4], &device);

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        // Should not panic or produce NaN
        let (updated_tensor, state) = muon.step(0.01, tensor.clone(), zero_grad, None);

        // Verify state was created
        assert!(state.is_some());

        // With zero gradient and no weight decay, tensor should remain unchanged
        let original_data = tensor.into_data();
        let updated_data = updated_tensor.clone().into_data();

        let original_vals = original_data.as_slice::<f32>().unwrap();
        let updated_vals = updated_data.as_slice::<f32>().unwrap();

        for (orig, upd) in original_vals.iter().zip(updated_vals.iter()) {
            assert!(
                (orig - upd).abs() < 1e-6,
                "With zero gradient, tensor should remain unchanged (or very close)"
            );
        }

        // Verify no NaN values
        for val in updated_vals {
            assert!(
                !val.is_nan(),
                "Result should not contain NaN values with zero gradient"
            );
        }

        // Test with weight decay - should still work
        let muon_with_decay: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: Some(0.01),
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let tensor2 = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
            ],
            &device,
        );
        let zero_grad2 = Tensor::<TestBackend, 2>::zeros([4, 4], &device);

        let (updated_tensor_decay, _) =
            muon_with_decay.step(0.01, tensor2.clone(), zero_grad2, None);

        // With zero gradient but with weight decay, tensor should be slightly reduced
        let updated_decay_data = updated_tensor_decay.into_data();
        let updated_decay_vals = updated_decay_data.as_slice::<f32>().unwrap();

        for val in updated_decay_vals {
            assert!(
                !val.is_nan(),
                "Result should not contain NaN with zero gradient and weight decay"
            );
        }

        // With weight decay, values should be slightly smaller than original
        let original_vals2 = tensor2.into_data().as_slice::<f32>().unwrap().to_vec();
        for (orig, upd) in original_vals2.iter().zip(updated_decay_vals.iter()) {
            if orig.abs() > 1e-6 {
                // Non-zero values should be reduced by weight decay
                assert!(
                    upd.abs() < orig.abs(),
                    "Weight decay should reduce magnitude: original={}, updated={}",
                    orig,
                    upd
                );
            }
        }
    }
}