use burn_core as burn;

use burn::{module::AutodiffModule, record::Record};

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use serde::{Deserialize, Serialize};

use super::{
    SimpleOptimizer,
    adaptor::OptimizerAdaptor,
    decay::WeightDecayConfig,
    momentum::{Momentum, MomentumConfig, MomentumState},
};
use crate::LearningRate;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Strategy for adjusting the learning rate based on the shape of the parameter
/// tensor being updated.
#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AdjustLrFn {
    /// Scale the learning rate by `sqrt(max(1, shape[0] / shape[1]))` (the default).
    #[default]
    Original,

    /// Scale the learning rate by `0.2 * sqrt(max(shape[0], shape[1]))`, intended to
    /// bring the RMS of the update closer to that of AdamW.
    MatchRmsAdamW,
}

impl AdjustLrFn {
    /// Compute the factor by which the learning rate is multiplied for a parameter
    /// with the given shape. Tensors with fewer than two dimensions are not scaled.
    fn adjustment_ratio(&self, shape: &[usize]) -> f64 {
        if shape.len() < 2 {
            return 1.0;
        }

        let a = shape[0] as f64;
        let b = shape[1] as f64;

        match self {
            Self::Original => {
                let ratio = a / b;
                ratio.max(1.0).sqrt()
            }
            Self::MatchRmsAdamW => 0.2 * a.max(b).sqrt(),
        }
    }
}

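/// Configuration to create the [Muon](Muon) optimizer.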
#[derive(Config, Debug)]
pub struct MuonConfig {
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,

    /// [Momentum](MomentumConfig) config; defaults to Nesterov momentum with a factor
    /// of 0.95.
    #[config(default = "MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }")]
    momentum: MomentumConfig,

    /// Coefficients `(a, b, c)` of the polynomial used by the Newton-Schulz iteration.
    #[config(default = "(3.4445, -4.775, 2.0315)")]
    ns_coefficients: (f32, f32, f32),

    /// Small constant used when normalizing the gradient to avoid division by zero.
    #[config(default = 1e-7)]
    epsilon: f32,

    /// Number of Newton-Schulz iteration steps.
    #[config(default = 5)]
    ns_steps: usize,

    /// Learning rate adjustment strategy, based on the parameter shape.
    #[config(default = "AdjustLrFn::Original")]
    adjust_lr_fn: AdjustLrFn,
}

impl MuonConfig {
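    /// Initialize the Muon optimizer for a module, wrapped in an [`OptimizerAdaptor`]
    /// so it can be used through the high-level [`Optimizer`](crate::Optimizer) API.
    ///
    /// A usage sketch (illustrative; `MyBackend` and `MyModel` are placeholder types):
    ///
    /// ```rust,ignore
    /// let optimizer = MuonConfig::new()
    ///     .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
    ///     .init::<MyBackend, MyModel<MyBackend>>();
    /// ```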
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B> {
        let momentum = Momentum::new(&self.momentum);
        let weight_decay_penalty = self.weight_decay.as_ref().map(|wd| wd.penalty);

        let optim = Muon {
            momentum,
            ns_params: NewtonSchulzParams::new(self.ns_coefficients, self.ns_steps),
            weight_decay_penalty,
            epsilon: self.epsilon,
            adjust_lr_fn: self.adjust_lr_fn,
        };

        OptimizerAdaptor::from(optim)
    }
}

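/// Coefficients and number of steps used by the Newton-Schulz iteration.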
#[derive(Clone, Copy)]
struct NewtonSchulzParams {
    a: f32,
    b: f32,
    c: f32,
    steps: usize,
}

impl NewtonSchulzParams {
    fn new(coefficients: (f32, f32, f32), steps: usize) -> Self {
        Self {
            a: coefficients.0,
            b: coefficients.1,
            c: coefficients.2,
            steps,
        }
    }
}

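/// Muon (MomentUm Orthogonalized by Newton-schulz) optimizer.
///
/// Muon applies momentum to the gradient, orthogonalizes the resulting update with a
/// Newton-Schulz iteration, and then takes a step scaled by a shape-dependent learning
/// rate, with optional decoupled weight decay. Only 2D parameter tensors are supported:
/// `step` panics for any other rank.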
#[derive(Clone)]
pub struct Muon<B: Backend> {
    momentum: Momentum<B>,
    ns_params: NewtonSchulzParams,
    weight_decay_penalty: Option<f32>,
    epsilon: f32,
    adjust_lr_fn: AdjustLrFn,
}

impl<B: Backend> Muon<B> {
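    /// Scale the learning rate according to the configured [`AdjustLrFn`] and the
    /// shape of the parameter being updated.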
    fn adjust_lr(&self, lr: LearningRate, shape: &[usize]) -> LearningRate {
        lr * self.adjust_lr_fn.adjustment_ratio(shape)
    }

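    /// Approximately orthogonalize `g` with a Newton-Schulz iteration.
    ///
    /// The matrix is normalized by its Frobenius norm and then repeatedly updated with
    /// `X <- a * X + (b * A + c * A^2) * X`, where `A = X * X^T`. Tall matrices are
    /// transposed before the iteration and transposed back afterwards.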
    fn zeropower_via_newtonschulz<const D: usize>(&self, g: Tensor<B, D>) -> Tensor<B, D> {
        let shape = g.shape();
        let dim_m2 = shape[D - 2];
        let dim_m1 = shape[D - 1];

        // Work in the wide orientation so that `X * X^T` is as small as possible.
        let (mut x, needs_transpose) = if dim_m2 > dim_m1 {
            (g.swap_dims(D - 2, D - 1), true)
        } else {
            (g, false)
        };

        // Normalize by the Frobenius norm so the spectral norm is at most 1.
        let norm = x
            .clone()
            .powf_scalar(2.0)
            .sum()
            .sqrt()
            .clamp_min(self.epsilon)
            .unsqueeze();

        x = x.div(norm);

        let NewtonSchulzParams { a, b, c, steps } = self.ns_params;

        // Iterate `X <- a * X + (b * A + c * A^2) * X` with `A = X * X^T`.
        for _ in 0..steps {
            let x_t = x.clone().swap_dims(D - 2, D - 1);
            let a_matrix = x.clone().matmul(x_t);

            let a_squared = a_matrix.clone().matmul(a_matrix.clone());
            let b_matrix = a_matrix.mul_scalar(b).add(a_squared.mul_scalar(c));

            x = x.clone().mul_scalar(a).add(b_matrix.matmul(x.clone()));
        }

        // Restore the original orientation.
        if needs_transpose {
            x = x.swap_dims(D - 2, D - 1);
        }

        x
    }
}

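/// State of the [Muon](Muon) optimizer for a single parameter tensor.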
#[derive(Record, Clone, new)]
pub struct MuonState<B: Backend, const D: usize> {
    /// Momentum state for this parameter.
    pub momentum: MomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
    type State<const D: usize> = MuonState<B, D>;

    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        assert!(
            D == 2,
            "Newton-Schulz iteration requires 2D tensors, got {}D",
            D
        );

        // Apply momentum to the raw gradient.
        let state_momentum = state.map(|s| s.momentum);
        let (grad, new_momentum_state) = self.momentum.transform(grad, state_momentum);

        // Orthogonalize the momentum-smoothed gradient.
        let update = self.zeropower_via_newtonschulz(grad);

        // Scale the learning rate according to the parameter shape.
        let adjusted_lr = self.adjust_lr(lr, &tensor.shape());

        // Decoupled weight decay, applied with the unadjusted learning rate.
        let tensor = if let Some(penalty) = self.weight_decay_penalty {
            let decay_factor = 1.0 - lr * penalty as f64;
            tensor.mul_scalar(decay_factor)
        } else {
            tensor
        };

        let delta = update.mul_scalar(adjusted_lr);
        let new_state = MuonState::new(new_momentum_state);

        (tensor - delta, Some(new_state))
    }

    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type TestBackend = burn_ndarray::NdArray<f32>;

    const TOLERANCE: f64 = 1e-8;

    fn given_linear_layer_no_bias(weight: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: None,
        };

        LinearConfig::new(4, 4)
            .with_bias(false)
            .init(&device)
            .load_record(record)
    }

    #[test]
    fn test_adjust_lr_fn_original() {
        let method = AdjustLrFn::Original;

        // Square matrix: ratio is 1.0.
        let ratio = method.adjustment_ratio(&[512, 512]);
        assert!((ratio - 1.0).abs() < TOLERANCE);

        // Tall matrix: ratio is sqrt(1024 / 512) = sqrt(2).
        let ratio = method.adjustment_ratio(&[1024, 512]);
        let expected = (2.0f64).sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);

        // Wide matrix: ratio is clamped to 1.0.
        let ratio = method.adjustment_ratio(&[512, 1024]);
        assert!((ratio - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_adjust_lr_fn_match_rms_adamw() {
        let method = AdjustLrFn::MatchRmsAdamW;

        // Ratio is 0.2 * sqrt(max(1024, 512)).
        let ratio = method.adjustment_ratio(&[1024, 512]);
        let expected = 0.2 * 1024.0f64.sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);

        // Square matrix: ratio is 0.2 * sqrt(512).
        let ratio = method.adjustment_ratio(&[512, 512]);
        let expected = 0.2 * 512.0f64.sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);
    }

    #[test]
    #[should_panic(expected = "Newton-Schulz iteration requires 2D tensors, got 1D")]
    fn test_1d_tensor_panics() {
        let device = Default::default();
        let config = MuonConfig::new();
        let optim: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let tensor_1d = Tensor::<TestBackend, 1>::zeros([512], &device);
        let grad_1d = Tensor::<TestBackend, 1>::ones([512], &device);

        let _ = optim.step(0.01, tensor_1d, grad_1d, None);
    }

    #[test]
    fn test_muon_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6)
            .with_bias(false)
            .init::<TestAutodiffBackend>(&device);

        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);

        let mut optimizer =
            MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(0.01, linear, grads);

        let state_before = optimizer.to_record();
        let state_before_copy = optimizer.to_record();

        let optimizer_new =
            MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
        let optimizer_loaded = optimizer_new.load_record(state_before_copy);
        let state_after = optimizer_loaded.to_record();

        assert_eq!(state_before.len(), state_after.len());
    }

    #[test]
    fn test_muon_with_weight_decay() {
        let device = Default::default();
        // Every weight starts at 1.0; after one step with weight decay they should all
        // drop below 1.0.
        let linear = given_linear_layer_no_bias(TensorData::from([
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
        ]));

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]],
            &device,
        )
        .require_grad();

        let mut optimizer = MuonConfig::new()
            .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
            .init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(0.01, linear, grads);

        let state = linear.into_record();
        let weight = state.weight.to_data();

        for val in weight.as_slice::<f32>().unwrap() {
            assert!(
                *val < 1.0,
                "Weight should be reduced by weight decay, got {}",
                val
            );
        }
    }

    #[test]
    fn test_newton_schulz_orthogonalization() {
        let device = Default::default();
        let matrix = Tensor::<TestBackend, 2>::from_floats([[1.0, 0.5], [0.5, 1.0]], &device);

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let orthogonalized = muon.zeropower_via_newtonschulz(matrix);
        let o_t = orthogonalized.clone().transpose();
        let product = orthogonalized.matmul(o_t);

        let data = product.into_data();
        let values = data.as_slice::<f32>().unwrap();

        assert!(
            (values[0] - 1.0).abs() < 0.1,
            "Product[0,0] should be ~1.0, got {}",
            values[0]
        );
        assert!(
            (values[3] - 1.0).abs() < 0.1,
            "Product[1,1] should be ~1.0, got {}",
            values[3]
        );
    }

    #[test]
    fn test_tall_matrix_transpose() {
        let device = Default::default();

        // Tall matrix (8x4): transposed internally before the Newton-Schulz iteration.
        let tall_matrix = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
                [0.1, 0.2, 0.3, 0.4],
                [0.4, 0.3, 0.2, 0.1],
                [0.2, 0.4, 0.1, 0.3],
                [0.3, 0.1, 0.4, 0.2],
            ],
            &device,
        );

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let orthogonalized = muon.zeropower_via_newtonschulz(tall_matrix.clone());

        let original_shape = tall_matrix.shape();
        let result_shape = orthogonalized.shape();
        assert_eq!(
            original_shape.dims::<2>(),
            result_shape.dims::<2>(),
            "Shape should be preserved: [8, 4]"
        );

        let original_data = tall_matrix.into_data();
        let result_data = orthogonalized.into_data();
        assert_ne!(
            original_data.as_slice::<f32>().unwrap(),
            result_data.as_slice::<f32>().unwrap(),
            "Orthogonalized matrix should differ from input"
        );

        // Wide matrix (4x8): no internal transpose, but the shape must also be preserved.
        let wide_matrix = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2, 0.1, 0.4, 0.2, 0.3],
                [0.5, 1.0, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5, 0.3, 0.2, 0.1, 0.4],
                [0.2, 0.1, 0.5, 1.0, 0.4, 0.1, 0.3, 0.2],
            ],
            &device,
        );

        let orthogonalized_wide = muon.zeropower_via_newtonschulz(wide_matrix.clone());

        let wide_original_shape = wide_matrix.shape();
        let wide_result_shape = orthogonalized_wide.shape();
        assert_eq!(
            wide_original_shape.dims::<2>(),
            wide_result_shape.dims::<2>(),
            "Wide matrix shape should be preserved: [4, 8]"
        );
    }

    #[test]
    fn test_zero_gradient() {
        let device = Default::default();

        let tensor = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
            ],
            &device,
        );

        let zero_grad = Tensor::<TestBackend, 2>::zeros([4, 4], &device);

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let (updated_tensor, state) = muon.step(0.01, tensor.clone(), zero_grad, None);

        assert!(state.is_some());

        // Without weight decay, a zero gradient should leave the parameters untouched.
        let original_data = tensor.into_data();
        let updated_data = updated_tensor.into_data();

        let original_vals = original_data.as_slice::<f32>().unwrap();
        let updated_vals = updated_data.as_slice::<f32>().unwrap();

        for (orig, upd) in original_vals.iter().zip(updated_vals.iter()) {
            assert!(
                (orig - upd).abs() < 1e-6,
                "With zero gradient, tensor should remain unchanged (or very close)"
            );
        }

        // The epsilon clamp on the norm must prevent NaN values for a zero gradient.
        for val in updated_vals {
            assert!(
                !val.is_nan(),
                "Result should not contain NaN values with zero gradient"
            );
        }

        // With weight decay, a zero gradient should still shrink the parameters.
        let muon_with_decay: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: Some(0.01),
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let tensor2 = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
            ],
            &device,
        );
        let zero_grad2 = Tensor::<TestBackend, 2>::zeros([4, 4], &device);

        let (updated_tensor_decay, _) =
            muon_with_decay.step(0.01, tensor2.clone(), zero_grad2, None);

        let updated_decay_data = updated_tensor_decay.into_data();
        let updated_decay_vals = updated_decay_data.as_slice::<f32>().unwrap();

        for val in updated_decay_vals {
            assert!(
                !val.is_nan(),
                "Result should not contain NaN with zero gradient and weight decay"
            );
        }

        let original_vals2 = tensor2.into_data().as_slice::<f32>().unwrap().to_vec();
        for (orig, upd) in original_vals2.iter().zip(updated_decay_vals.iter()) {
            if orig.abs() > 1e-6 {
                assert!(
                    upd.abs() < orig.abs(),
                    "Weight decay should reduce magnitude: original={}, updated={}",
                    orig,
                    upd
                );
            }
        }
    }
}