use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;
/// Root-mean-square normalization layer.
///
/// `forward` scales its input by the reciprocal RMS along the last axis
/// (`x * rsqrt(mean(x^2) + eps)`) and then multiplies by `scale`.
/// With the `py` feature enabled, the type is also exposed as a pyo3 class.
// NOTE(review): `Module` (from `zyx_derive`) presumably walks these fields to
// collect parameters; confirm before renaming or reordering them.
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct RMSNorm {
/// Gain multiplied into the normalized output; `new` initializes it to ones of length `dim`.
pub scale: Tensor,
/// Added to the mean square before `rsqrt` so an all-zero input does not yield an
/// infinite factor; `new` sets it to `1e-6`.
pub eps: f64,
}
impl RMSNorm {
    /// Builds an RMS normalization layer for a feature axis of length `dim`.
    ///
    /// `scale` starts as a tensor of ones in the requested `dtype`, and `eps`
    /// defaults to `1e-6`.
    pub fn new(dim: u64, dtype: DType) -> RMSNorm {
        let scale = Tensor::ones(dim, dtype);
        RMSNorm { scale, eps: 1e-6 }
    }

    /// Applies RMS normalization to `x`.
    ///
    /// Computes `x * rsqrt(mean(x^2, axis = -1) + eps) * scale`. The mean is taken
    /// with `mean_keepdim`, so the reduced axis is retained for broadcasting back
    /// against `x`. `eps` is cast to the input's dtype before the addition.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Tensor::pow` or `Tensor::mean_keepdim`.
    pub fn forward(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let input: Tensor = x.into();
        let dtype = input.dtype();
        // Mean of squares over the last axis.
        let mean_square = input.pow(2)?.mean_keepdim([-1])?;
        // Match eps to the input dtype so the sum stays in that dtype.
        let eps = Tensor::from(self.eps).cast(dtype);
        let inv_rms = (mean_square + eps).rsqrt();
        let normalized = &input * inv_rms;
        // NOTE(review): assumes `scale` broadcasts against the last axis of `x` — confirm.
        Ok(normalized * &self.scale)
    }
}