use crate::autograd::Variable;
use crate::tensor::{Device, DType, Result, Tensor, TensorOptions};
use super::parameter::Parameter;
use super::Module;
pub struct RMSNorm {
pub weight: Parameter,
eps: f64,
}
impl RMSNorm {
pub fn new(size: i64) -> Result<Self> {
Self::on_device(size, Device::CPU)
}
pub fn on_device(size: i64, device: Device) -> Result<Self> {
let opts = TensorOptions { dtype: DType::Float32, device };
let weight = Variable::new(Tensor::ones(&[size], opts)?, true);
Ok(RMSNorm {
weight: Parameter {
variable: weight,
name: "weight".into(),
},
eps: 1e-5,
})
}
pub fn eps(mut self, eps: f64) -> Self {
self.eps = eps;
self
}
}
impl Module for RMSNorm {
fn name(&self) -> &str { "rmsnorm" }
fn forward(&self, input: &Variable) -> Result<Variable> {
let x2 = input.pow_scalar(2.0)?;
let mean_x2 = x2.mean_dim(-1, true)?;
let rms = mean_x2.add_scalar(self.eps)?.sqrt()?;
let normalized = input.div(&rms)?;
normalized.mul(&self.weight.variable)
}
fn parameters(&self) -> Vec<Parameter> {
vec![self.weight.clone()]
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::test_device;
#[test]
fn test_rmsnorm_shape() {
let device = test_device();
let norm = RMSNorm::on_device(4, device).unwrap();
let x = Variable::new(
Tensor::randn(&[2, 3, 4], crate::tensor::test_opts()).unwrap(),
false,
);
let y = norm.forward(&x).unwrap();
assert_eq!(y.shape(), vec![2, 3, 4]);
}
#[test]
fn test_rmsnorm_unit_rms() {
let device = test_device();
let norm = RMSNorm::on_device(4, device).unwrap();
let x = Variable::new(
Tensor::ones(&[1, 4], crate::tensor::test_opts()).unwrap(),
false,
);
let y = norm.forward(&x).unwrap().data().to_f32_vec().unwrap();
for v in &y {
assert!((v - 1.0).abs() < 1e-3);
}
}
#[test]
fn test_rmsnorm_gradient() {
let device = test_device();
let norm = RMSNorm::on_device(4, device).unwrap();
let x = Variable::new(
Tensor::randn(&[2, 4], crate::tensor::test_opts()).unwrap(),
true,
);
let y = norm.forward(&x).unwrap();
let loss = y.sum().unwrap();
loss.backward().unwrap();
let grad = x.grad().unwrap();
assert_eq!(grad.shape(), vec![2, 4]);
let w_grad = norm.weight.variable.grad().unwrap();
assert_eq!(w_grad.shape(), vec![4]);
}
#[test]
fn test_rmsnorm_parameters() {
let norm = RMSNorm::new(8).unwrap();
let params = norm.parameters();
assert_eq!(params.len(), 1);
assert_eq!(params[0].name, "weight");
}
}