use burn::config::Config;
use burn::module::{Module, Param};
use burn::prelude::*;
/// Configuration for [`RmsNorm`], a root-mean-square normalization layer.
///
/// Build the module with [`RmsNormConfig::init`].
#[derive(Config, Debug)]
pub struct RmsNormConfig {
/// Size of the normalized (last) dimension of the input; also the length of the
/// learnable scale vector.
pub d_model: usize,
/// Small constant added to the mean of squares before the square root, so that an
/// all-zero input does not cause a division by zero.
#[config(default = 1e-6)]
pub eps: f64,
}
impl RmsNormConfig {
    /// Builds an [`RmsNorm`] module on `device`.
    ///
    /// The learnable scale (`gamma`) has length `d_model` and starts as all ones, so a
    /// freshly initialized module only normalizes its input without rescaling it.
    pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
        let initial_scale = Tensor::<B, 1>::ones([self.d_model], device);
        let gamma = Param::from_tensor(initial_scale);
        RmsNorm {
            gamma,
            eps: self.eps,
        }
    }
}
/// Root-mean-square normalization layer with a learnable per-feature scale.
///
/// Each vector along the last input dimension is divided by its RMS (with `eps` added
/// inside the square root) and then multiplied element-wise by `gamma`.
/// Construct it with [`RmsNormConfig::init`].
#[derive(Module, Debug)]
pub struct RmsNorm<B: Backend> {
/// Learnable scale of shape `[d_model]`, initialized to ones by `RmsNormConfig::init`.
gamma: Param<Tensor<B, 1>>,
/// Stabilizing constant added to the mean of squares before taking the square root.
eps: f64,
}
impl<B: Backend> RmsNorm<B> {
    /// Applies RMS normalization over the last dimension of a rank-3 tensor.
    ///
    /// The output has the same shape as `x`. This is [`RmsNorm::forward_nd`] with `D = 3`.
    ///
    /// # Panics
    ///
    /// Panics if the last dimension of `x` is not `d_model`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        self.forward_nd(x)
    }

    /// Applies RMS normalization over the last dimension of a rank-4 tensor.
    ///
    /// The output has the same shape as `x`. This is [`RmsNorm::forward_nd`] with `D = 4`.
    ///
    /// # Panics
    ///
    /// Panics if the last dimension of `x` is not `d_model`.
    pub fn forward_4d(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        self.forward_nd(x)
    }

    /// Applies RMS normalization over the last dimension of a tensor of any rank `D >= 1`.
    ///
    /// Computes `x / sqrt(mean(x^2, last_dim) + eps) * gamma`. The `[d_model]` scale
    /// `gamma` is broadcast across all leading dimensions, and the output keeps the
    /// shape of `x`.
    ///
    /// # Panics
    ///
    /// Panics if the last dimension of `x` is not `d_model`. Without this check a
    /// size-1 last dimension would silently broadcast against `gamma` and yield a
    /// wrongly shaped result.
    pub fn forward_nd<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        let last = D - 1;
        let gamma = self.gamma.val();
        let d_model = gamma.dims()[0];
        let width = x.dims()[last];
        assert_eq!(
            width, d_model,
            "RmsNorm: expected last dimension of size {d_model}, got {width}"
        );

        // `mean_dim` keeps the reduced axis (size 1), so `rms` broadcasts back over `x`.
        let rms = x
            .clone()
            .powf_scalar(2.0)
            .mean_dim(last)
            .add_scalar(self.eps)
            .sqrt();

        // `unsqueeze` prepends size-1 axes: [d_model] -> [1, ..., 1, d_model].
        (x / rms) * gamma.unsqueeze::<D>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Distribution;

    type TestBackend = NdArray;

    #[test]
    fn test_rmsnorm_output_shape_3d() {
        let device = Default::default();
        let norm = RmsNormConfig::new(64).init::<TestBackend>(&device);
        let x = Tensor::random([2, 16, 64], Distribution::Normal(0.0, 1.0), &device);
        let out = norm.forward(x);
        assert_eq!(out.dims(), [2, 16, 64]);
    }

    #[test]
    fn test_rmsnorm_output_shape_4d() {
        let device = Default::default();
        let norm = RmsNormConfig::new(64).init::<TestBackend>(&device);
        let x = Tensor::random([3, 2, 16, 64], Distribution::Normal(0.0, 1.0), &device);
        let out = norm.forward_4d(x);
        assert_eq!(out.dims(), [3, 2, 16, 64]);
    }

    #[test]
    fn test_rmsnorm_normalizes_magnitude() {
        let device = Default::default();
        let norm = RmsNormConfig::new(4).init::<TestBackend>(&device);
        let x = Tensor::<TestBackend, 3>::from_floats([[[10.0, 20.0, 30.0, 40.0]]], &device);
        let out = norm.forward(x);

        // mean(x^2) = (100 + 400 + 900 + 1600) / 4 = 750. With gamma = 1 every output
        // element is x / sqrt(750), so the output's own RMS must be ~1.
        let rms: f32 = out.clone().powf_scalar(2.0).mean().sqrt().into_scalar();
        assert!((rms - 1.0).abs() < 1e-4, "output RMS should be ~1, got {rms}");

        let expected_max = 40.0_f32 / 750.0_f32.sqrt();
        let max_val: f32 = out.abs().max().into_scalar();
        assert!(
            max_val < 2.0,
            "RMSNorm should reduce magnitude, got {max_val}"
        );
        assert!(
            (max_val - expected_max).abs() < 1e-4,
            "expected max {expected_max}, got {max_val}"
        );
    }

    #[test]
    fn test_rmsnorm_4d_matches_3d() {
        let device = Default::default();
        let norm = RmsNormConfig::new(8).init::<TestBackend>(&device);
        let x3 =
            Tensor::<TestBackend, 3>::random([2, 5, 8], Distribution::Normal(0.0, 1.0), &device);
        let x4 = x3.clone().unsqueeze_dim::<4>(0);

        let out3 = norm.forward(x3);
        let out4 = norm.forward_4d(x4).reshape([2, 5, 8]);

        let max_diff: f32 = (out3 - out4).abs().max().into_scalar();
        assert!(
            max_diff < 1e-6,
            "forward_4d should agree with forward, max diff {max_diff}"
        );
    }
}