use burn::prelude::*;
use burn::nn::{Linear, RmsNorm, RmsNormConfig};
use crate::model::linear_zeros;
/// Thin module wrapper around burn's built-in [`RmsNorm`] layer.
///
/// Construction and the forward pass are both delegated to the wrapped layer;
/// see the `impl` block for the `new` / `forward` entry points.
#[derive(Module, Debug)]
pub struct RMSNorm<B: Backend> {
/// The wrapped burn RMS-norm layer that performs the normalisation.
pub inner: RmsNorm<B>,
}
impl<B: Backend> RMSNorm<B> {
pub fn new(dim: usize, eps: f64, device: &B::Device) -> Self {
Self {
inner: RmsNormConfig::new(dim).with_epsilon(eps).init(device),
}
}
pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
self.inner.forward(x)
}
}
/// RMS normalisation whose multiplicative scale is computed from a
/// conditioning tensor rather than stored as a fixed parameter.
///
/// In `forward`, the input is divided by its root-mean-square along
/// dimension 2 and then multiplied by `weight.forward(c)`.
#[derive(Module, Debug)]
pub struct AdaRMSNorm<B: Backend> {
/// Linear layer applied to the conditioning tensor to produce the scale;
/// built with `linear_zeros` in `new`.
pub weight: Linear<B>,
/// Constant added to the mean of squares before taking the square root.
pub eps: f64,
}
impl<B: Backend> AdaRMSNorm<B> {
pub fn new(emb_dim: usize, dim: usize, eps: f64, device: &B::Device) -> Self {
Self {
weight: linear_zeros(emb_dim, dim, true, device),
eps,
}
}
pub fn forward(&self, x: Tensor<B, 3>, c: Tensor<B, 3>) -> Tensor<B, 3> {
let eps = self.eps as f32;
let rms = (x.clone().powf_scalar(2.0f32).mean_dim(2) + eps).sqrt();
let normed = x / rms;
normed * self.weight.forward(c)
}
}