use burn::prelude::*;
use burn::module::{Param, ParamId};
/// Root-mean-square normalization layer with a per-feature scale vector.
///
/// `forward` divides its input by `sqrt(mean(x^2) + eps)` taken over the last
/// axis and then multiplies the result by `weight`.
// NOTE(review): only `Debug` is derived here even though `weight` is wrapped in
// `Param`; if this layer is meant to participate in burn's `Module` machinery
// (parameter visiting, record loading), confirm that is provided elsewhere.
#[derive(Debug)]
pub struct RmsNorm<B: Backend> {
/// Rank-1 scale vector; `new` fills it with ones.
pub weight: Param<Tensor<B, 1>>,
/// Constant added to the mean of squares before the square root is taken.
pub eps: f64,
}
impl<B: Backend> RmsNorm<B> {
    /// Creates a layer whose scale vector holds `size` ones.
    ///
    /// * `size` - length of the rank-1 `weight` tensor.
    /// * `eps` - constant added to the mean of squares in [`RmsNorm::forward`].
    /// * `device` - device on which the weight tensor is created.
    pub fn new(size: usize, eps: f64, device: &B::Device) -> Self {
        // All-ones scale: the final multiply in `forward` leaves the normalized
        // values unchanged until the weight is overwritten.
        let weight = Tensor::ones([size], device);
        Self {
            weight: Param::initialized(ParamId::new(), weight),
            eps,
        }
    }

    /// Normalizes `x` by its root mean square along axis 2 and applies `weight`.
    ///
    /// Computes `x / sqrt(mean(x^2, axis = 2) + eps) * weight` and returns a
    /// rank-3 tensor.
    ///
    /// The size of `x`'s last axis is not checked here; it presumably has to
    /// match the `size` passed to [`RmsNorm::new`] for the final product to
    /// broadcast (confirm against the backend's broadcasting rules).
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Mean of squares over the last axis. No mean is subtracted, so this is
        // the second moment rather than a variance. The result stays rank 3 so
        // it can be multiplied with `x` below.
        let mean_sq = x.clone().powf_scalar(2.0).mean_dim(2);

        // Reciprocal RMS; `eps` keeps the square root away from zero.
        let inv_rms = (mean_sq + self.eps).sqrt().recip();

        // `val()` already hands back an owned tensor, so no extra `clone()` is
        // needed before lifting the rank-1 weight to rank 3.
        let gain = self.weight.val().unsqueeze::<3>();

        x * inv_rms * gain
    }
}