use burn_core as burn;

use burn::{module::AutodiffModule, record::Record};

use burn::config::Config;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use burn::tensor::{backend::Backend, ops::Device};
use serde::{Deserialize, Serialize};

use super::{
    SimpleOptimizer,
    adaptor::OptimizerAdaptor,
    decay::WeightDecayConfig,
    momentum::{Momentum, MomentumConfig, MomentumState},
};
use crate::LearningRate;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AdjustLrFn {
    /// `lr * sqrt(max(1, rows / cols))`, as in the original Muon implementation.
    #[default]
    Original,

    /// `lr * 0.2 * sqrt(max(rows, cols))`, matching the update RMS of AdamW.
    MatchRmsAdamW,
}

impl AdjustLrFn {
    /// Computes the factor the learning rate is multiplied by for a parameter
    /// of the given shape. Tensors with fewer than two dimensions are left
    /// unadjusted.
    fn adjustment_ratio(&self, shape: &[usize]) -> f64 {
        if shape.len() < 2 {
            return 1.0;
        }

        let a = shape[0] as f64;
        let b = shape[1] as f64;

        match self {
            Self::Original => {
                // Boost tall matrices; wide ones keep the base rate.
                let ratio = a / b;
                ratio.max(1.0).sqrt()
            }
            Self::MatchRmsAdamW => 0.2 * a.max(b).sqrt(),
        }
    }
}

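/// Configuration for the [Muon](Muon) optimizer, which combines SGD-style
/// momentum with a Newton-Schulz orthogonalization of each 2D update.
///
/// A minimal usage sketch (assuming an autodiff backend `B` and a module
/// `MyModel<B>` whose parameters are 2D; the names are illustrative):
///
/// ```rust,ignore
/// let mut optim = MuonConfig::new()
///     .with_ns_steps(5)
///     .init::<B, MyModel<B>>();
/// ```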
#[derive(Config, Debug)]
pub struct MuonConfig {
    /// Optional [weight decay](WeightDecayConfig), applied in decoupled
    /// (AdamW-style) fashion.
    weight_decay: Option<WeightDecayConfig>,

    /// [Momentum](MomentumConfig) applied to the raw gradient before
    /// orthogonalization.
    #[config(default = "MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }")]
    momentum: MomentumConfig,

    /// Coefficients `(a, b, c)` of the quintic polynomial used by the
    /// Newton-Schulz iteration.
    #[config(default = "(3.4445, -4.775, 2.0315)")]
    ns_coefficients: (f32, f32, f32),

    /// Small constant used to avoid division by zero when normalizing.
    #[config(default = 1e-7)]
    epsilon: f32,

    /// Number of Newton-Schulz iterations.
    #[config(default = 5)]
    ns_steps: usize,

    /// Learning-rate adjustment strategy; see [`AdjustLrFn`].
    #[config(default = "AdjustLrFn::Original")]
    adjust_lr_fn: AdjustLrFn,
}

impl MuonConfig {
    /// Initializes a [Muon](Muon) optimizer from this configuration.
    pub fn build<B: Backend>(&self) -> Muon<B> {
        let momentum = Momentum::new(&self.momentum);
        let weight_decay_penalty = self.weight_decay.as_ref().map(|wd| wd.penalty);

        Muon {
            momentum,
            ns_params: NewtonSchulzParams::new(self.ns_coefficients, self.ns_steps),
            weight_decay_penalty,
            epsilon: self.epsilon,
            adjust_lr_fn: self.adjust_lr_fn,
        }
    }

    /// Initializes the optimizer wrapped in an [OptimizerAdaptor] so it can
    /// be used with autodiff modules.
    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
        &self,
    ) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B> {
        OptimizerAdaptor::from(self.build())
    }
}

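/// Coefficients `(a, b, c)` of the quintic Newton-Schulz polynomial and the
/// number of iterations to run.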
#[derive(Clone, Copy)]
struct NewtonSchulzParams {
    a: f32,
    b: f32,
    c: f32,
    steps: usize,
}

impl NewtonSchulzParams {
    fn new(coefficients: (f32, f32, f32), steps: usize) -> Self {
        Self {
            a: coefficients.0,
            b: coefficients.1,
            c: coefficients.2,
            steps,
        }
    }
}

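/// The Muon optimizer: momentum followed by an approximate orthogonalization
/// of each 2D update matrix via Newton-Schulz iteration.
///
/// Muon is designed for the 2D weight matrices of hidden layers; other
/// parameters (embeddings, biases, norms) are typically trained with a
/// different optimizer such as AdamW.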
#[derive(Clone)]
pub struct Muon<B: Backend> {
    momentum: Momentum<B>,
    ns_params: NewtonSchulzParams,
    weight_decay_penalty: Option<f32>,
    epsilon: f32,
    adjust_lr_fn: AdjustLrFn,
}

impl<B: Backend> Muon<B> {
    /// Rescales the learning rate for the given parameter shape; see
    /// [`AdjustLrFn`].
    fn adjust_lr(&self, lr: LearningRate, shape: &[usize]) -> LearningRate {
        lr * self.adjust_lr_fn.adjustment_ratio(shape)
    }

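    /// Approximates the semi-orthogonal factor `U Vᵀ` of the SVD of `g` with
    /// a quintic Newton-Schulz iteration:
    ///
    /// ```text
    /// A = X Xᵀ
    /// X ← a·X + (b·A + c·A²)·X
    /// ```
    ///
    /// `X` is first divided by its Frobenius norm so that its spectral norm
    /// is at most 1, which keeps the iteration stable. Tall matrices are
    /// transposed first so that `A` is the smaller Gram matrix, then
    /// transposed back at the end.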
    fn zeropower_via_newtonschulz<const D: usize>(&self, g: Tensor<B, D>) -> Tensor<B, D> {
        let shape = g.shape();
        let dim_m2 = shape[D - 2];
        let dim_m1 = shape[D - 1];

        // Work in the orientation where rows <= cols.
        let (mut x, needs_transpose) = if dim_m2 > dim_m1 {
            (g.swap_dims(D - 2, D - 1), true)
        } else {
            (g, false)
        };

        // Frobenius norm, clamped to avoid division by zero.
        let norm = x
            .clone()
            .powf_scalar(2.0)
            .sum()
            .sqrt()
            .clamp_min(self.epsilon)
            .unsqueeze();

        x = x.div(norm);

        let NewtonSchulzParams { a, b, c, steps } = self.ns_params;

        for _ in 0..steps {
            let x_t = x.clone().swap_dims(D - 2, D - 1);
            let a_matrix = x.clone().matmul(x_t);

            let a_squared = a_matrix.clone().matmul(a_matrix.clone());
            let b_matrix = a_matrix.mul_scalar(b).add(a_squared.mul_scalar(c));

            // X <- a*X + (b*A + c*A^2)*X
            x = x.clone().mul_scalar(a).add(b_matrix.matmul(x.clone()));
        }

        if needs_transpose {
            x = x.swap_dims(D - 2, D - 1);
        }

        x
    }
}

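/// State of the [Muon](Muon) optimizer for one parameter: the momentum buffer.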
#[derive(Record, Clone, new)]
pub struct MuonState<B: Backend, const D: usize> {
    pub momentum: MomentumState<B, D>,
}

impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
    type State<const D: usize> = MuonState<B, D>;

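    // A single Muon step, per parameter:
    //   m <- momentum(grad)                 (momentum buffer update)
    //   o <- newton_schulz(m)               (orthogonalized update)
    //   W <- W * (1 - lr * wd)              (decoupled weight decay, raw lr)
    //   W <- W - adjust_lr(lr, W.shape) * o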
    fn step<const D: usize>(
        &self,
        lr: LearningRate,
        tensor: Tensor<B, D>,
        grad: Tensor<B, D>,
        state: Option<Self::State<D>>,
    ) -> (Tensor<B, D>, Option<Self::State<D>>) {
        assert!(
            D == 2,
            "Newton-Schulz iteration requires 2D tensors, got {}D",
            D
        );

        // Update the momentum buffer and transform the gradient.
        let state_momentum = state.map(|s| s.momentum);
        let (grad, new_momentum_state) = self.momentum.transform(grad, state_momentum);

        // Orthogonalize the momentum-transformed gradient.
        let update = self.zeropower_via_newtonschulz(grad);

        // Shape-dependent learning-rate adjustment.
        let adjusted_lr = self.adjust_lr(lr, &tensor.shape());

        // Decoupled weight decay, applied with the unadjusted learning rate.
        let tensor = if let Some(penalty) = self.weight_decay_penalty {
            let decay_factor = 1.0 - lr * penalty as f64;
            tensor.mul_scalar(decay_factor)
        } else {
            tensor
        };

        let delta = update.mul_scalar(adjusted_lr);
        let new_state = MuonState::new(new_momentum_state);

        (tensor - delta, Some(new_state))
    }

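    // Moves the optimizer state to the given device.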
    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {
        state.momentum = state.momentum.to_device(device);
        state
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use crate::{GradientsParams, Optimizer};
    use burn::module::{Module, Param};
    use burn::tensor::{Distribution, Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    type TestBackend = <TestAutodiffBackend as AutodiffBackend>::InnerBackend;

    const TOLERANCE: f64 = 1e-8;

    fn given_linear_layer_no_bias(weight: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: None,
        };

        LinearConfig::new(4, 4)
            .with_bias(false)
            .init(&device)
            .load_record(record)
    }

    #[test]
    fn test_adjust_lr_fn_original() {
        let method = AdjustLrFn::Original;

        // Square matrix: ratio 1.0.
        let ratio = method.adjustment_ratio(&[512, 512]);
        assert!((ratio - 1.0).abs() < TOLERANCE);

        // Tall matrix: sqrt(1024 / 512) = sqrt(2).
        let ratio = method.adjustment_ratio(&[1024, 512]);
        let expected = (2.0f64).sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);

        // Wide matrix: ratio clamped to 1.0.
        let ratio = method.adjustment_ratio(&[512, 1024]);
        assert!((ratio - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_adjust_lr_fn_match_rms_adamw() {
        let method = AdjustLrFn::MatchRmsAdamW;

        // 0.2 * sqrt(max(1024, 512)).
        let ratio = method.adjustment_ratio(&[1024, 512]);
        let expected = 0.2 * 1024.0f64.sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);

        // Unlike `Original`, square matrices are also rescaled.
        let ratio = method.adjustment_ratio(&[512, 512]);
        let expected = 0.2 * 512.0f64.sqrt();
        assert!((ratio - expected).abs() < TOLERANCE);
    }

    #[test]
    #[should_panic(expected = "Newton-Schulz iteration requires 2D tensors, got 1D")]
    fn test_1d_tensor_panics() {
        let device = Default::default();
        let config = MuonConfig::new();
        let optim: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let tensor_1d = Tensor::<TestBackend, 1>::zeros([512], &device);
        let grad_1d = Tensor::<TestBackend, 1>::ones([512], &device);

        let _ = optim.step(0.01, tensor_1d, grad_1d, None);
    }

    #[test]
    fn test_muon_optimizer_save_load_state() {
        let device = Default::default();
        let linear = LinearConfig::new(6, 6)
            .with_bias(false)
            .init::<TestAutodiffBackend>(&device);

        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);

        let mut optimizer =
            MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let _linear = optimizer.step(0.01, linear, grads);

        let state_before = optimizer.to_record();
        let state_before_copy = optimizer.to_record();

        let optimizer_new =
            MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();
        let optimizer_loaded = optimizer_new.load_record(state_before_copy);
        let state_after = optimizer_loaded.to_record();

        assert_eq!(state_before.len(), state_after.len());
    }

    #[test]
    fn test_muon_with_weight_decay() {
        let device = Default::default();
        let linear = given_linear_layer_no_bias(TensorData::from([
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0, 1.0],
        ]));

        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(
            [[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]],
            &device,
        )
        .require_grad();

        let mut optimizer = MuonConfig::new()
            .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
            .init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();

        let grads = linear.forward(x).backward();
        let grads = GradientsParams::from_grads(grads, &linear);
        let linear = optimizer.step(0.01, linear, grads);

        let state = linear.into_record();
        let weight = state.weight.to_data();

        for val in weight.as_slice::<f32>().unwrap() {
            assert!(
                *val < 1.0,
                "Weight should be reduced by weight decay, got {}",
                val
            );
        }
    }

    #[test]
    fn test_newton_schulz_orthogonalization() {
        let device = Default::default();
        let matrix = Tensor::<TestBackend, 2>::from_floats([[1.0, 0.5], [0.5, 1.0]], &device);

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let orthogonalized = muon.zeropower_via_newtonschulz(matrix);
        let o_t = orthogonalized.clone().transpose();
        let product = orthogonalized.matmul(o_t);

        // O * O^T should be close to the identity on the diagonal.
        let data = product.into_data();
        let values = data.as_slice::<f32>().unwrap();

        assert!(
            (values[0] - 1.0).abs() < 0.1,
            "Product[0,0] should be ~1.0, got {}",
            values[0]
        );
        assert!(
            (values[3] - 1.0).abs() < 0.1,
            "Product[1,1] should be ~1.0, got {}",
            values[3]
        );
    }

    #[test]
    fn test_tall_matrix_transpose() {
        let device = Default::default();

        // 8x4 (tall): internally transposed for the iteration.
        let tall_matrix = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
                [0.1, 0.2, 0.3, 0.4],
                [0.4, 0.3, 0.2, 0.1],
                [0.2, 0.4, 0.1, 0.3],
                [0.3, 0.1, 0.4, 0.2],
            ],
            &device,
        );

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let orthogonalized = muon.zeropower_via_newtonschulz(tall_matrix.clone());

        let original_shape = tall_matrix.shape();
        let result_shape = orthogonalized.shape();
        assert_eq!(
            original_shape.dims::<2>(),
            result_shape.dims::<2>(),
            "Shape should be preserved: [8, 4]"
        );

        let original_data = tall_matrix.into_data();
        let result_data = orthogonalized.into_data();
        assert_ne!(
            original_data.as_slice::<f32>().unwrap(),
            result_data.as_slice::<f32>().unwrap(),
            "Orthogonalized matrix should differ from input"
        );

        // 4x8 (wide): handled without the transpose path.
        let wide_matrix = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2, 0.1, 0.4, 0.2, 0.3],
                [0.5, 1.0, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5, 0.3, 0.2, 0.1, 0.4],
                [0.2, 0.1, 0.5, 1.0, 0.4, 0.1, 0.3, 0.2],
            ],
            &device,
        );

        let orthogonalized_wide = muon.zeropower_via_newtonschulz(wide_matrix.clone());

        let wide_original_shape = wide_matrix.shape();
        let wide_result_shape = orthogonalized_wide.shape();
        assert_eq!(
            wide_original_shape.dims::<2>(),
            wide_result_shape.dims::<2>(),
            "Wide matrix shape should be preserved: [4, 8]"
        );
    }

    #[test]
    fn test_zero_gradient() {
        let device = Default::default();

        let tensor = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
            ],
            &device,
        );

        let zero_grad = Tensor::<TestBackend, 2>::zeros([4, 4], &device);

        let config = MuonConfig::new();
        let muon: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: None,
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let (updated_tensor, state) = muon.step(0.01, tensor.clone(), zero_grad, None);

        assert!(state.is_some());

        let original_data = tensor.into_data();
        let updated_data = updated_tensor.clone().into_data();

        let original_vals = original_data.as_slice::<f32>().unwrap();
        let updated_vals = updated_data.as_slice::<f32>().unwrap();

        for (orig, upd) in original_vals.iter().zip(updated_vals.iter()) {
            assert!(
                (orig - upd).abs() < 1e-6,
                "With zero gradient, tensor should remain unchanged (or very close)"
            );
        }

        // The epsilon clamp must keep the normalization finite.
        for val in updated_vals {
            assert!(
                !val.is_nan(),
                "Result should not contain NaN values with zero gradient"
            );
        }

        // With weight decay, a zero gradient should still shrink the weights.
        let muon_with_decay: Muon<TestBackend> = Muon {
            momentum: Momentum::new(&config.momentum),
            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
            weight_decay_penalty: Some(0.01),
            epsilon: config.epsilon,
            adjust_lr_fn: config.adjust_lr_fn,
        };

        let tensor2 = Tensor::<TestBackend, 2>::from_floats(
            [
                [1.0, 0.5, 0.3, 0.2],
                [0.5, 1.0, 0.4, 0.1],
                [0.3, 0.4, 1.0, 0.5],
                [0.2, 0.1, 0.5, 1.0],
            ],
            &device,
        );
        let zero_grad2 = Tensor::<TestBackend, 2>::zeros([4, 4], &device);

        let (updated_tensor_decay, _) =
            muon_with_decay.step(0.01, tensor2.clone(), zero_grad2, None);

        let updated_decay_data = updated_tensor_decay.into_data();
        let updated_decay_vals = updated_decay_data.as_slice::<f32>().unwrap();

        for val in updated_decay_vals {
            assert!(
                !val.is_nan(),
                "Result should not contain NaN with zero gradient and weight decay"
            );
        }

        let original_vals2 = tensor2.into_data().as_slice::<f32>().unwrap().to_vec();
        for (orig, upd) in original_vals2.iter().zip(updated_decay_vals.iter()) {
            if orig.abs() > 1e-6 {
                assert!(
                    upd.abs() < orig.abs(),
                    "Weight decay should reduce magnitude: original={}, updated={}",
                    orig,
                    upd
                );
            }
        }
    }
}