use candle_core::{Result, Tensor, DType};
/// Normalizes `hidden_states` by the root mean square taken over its last dimension.
///
/// Computes `x / sqrt(mean(x^2, last dim) + variance_epsilon)`; no learned scale is
/// applied here. The arithmetic runs in `F32`, except that `F64` inputs are kept in
/// `F64` so they are not rounded through a narrower type. The result is cast back to
/// the dtype of `hidden_states`.
///
/// # Errors
///
/// Propagates any error returned by the underlying tensor operations.
pub fn rms_norm(hidden_states: &Tensor, variance_epsilon: f64) -> Result<Tensor> {
    let input_dtype = hidden_states.dtype();
    // Only F64 is exempt from the F32 round-trip: converting it down would discard
    // precision the caller explicitly asked for.
    let compute_dtype = match input_dtype {
        DType::F64 => DType::F64,
        _ => DType::F32,
    };
    let needs_cast = input_dtype != compute_dtype;

    let upcast = if needs_cast {
        hidden_states.to_dtype(compute_dtype)?
    } else {
        hidden_states.clone()
    };

    // Mean of squares along the last axis, kept as a size-1 dim so it broadcasts below.
    let variance = upcast.sqr()?.mean_keepdim(candle_core::D::Minus1)?;
    let rms = (variance + variance_epsilon)?.sqrt()?;
    let normalized = upcast.broadcast_div(&rms)?;

    if needs_cast {
        normalized.to_dtype(input_dtype)
    } else {
        Ok(normalized)
    }
}
/// RMS normalization layer: applies [`rms_norm`] and then scales by a weight vector.
#[derive(Debug, Clone)]
pub struct RMSNorm {
    /// Scale vector of shape `(hidden_size,)`, multiplied into the normalized input.
    weight: Tensor,
    /// Constant added to the mean square before taking the square root.
    eps: f64,
}
impl RMSNorm {
    /// Creates a layer whose scale is the tensor named `"weight"` with shape
    /// `(hidden_size,)`, fetched from `vb`.
    ///
    /// # Errors
    ///
    /// Returns the error from `vb.get` if the tensor cannot be obtained.
    pub fn new(hidden_size: usize, eps: f64, vb: candle_nn::VarBuilder) -> Result<Self> {
        let weight = vb.get((hidden_size,), "weight")?;
        Ok(Self { weight, eps })
    }

    /// Creates a layer with an all-ones `F32` scale of shape `(hidden_size,)` on
    /// `device`, so `forward` reduces to plain RMS normalization.
    ///
    /// # Errors
    ///
    /// Returns the error from `Tensor::ones` if the tensor cannot be created.
    pub fn new_no_weight(
        hidden_size: usize,
        eps: f64,
        device: &candle_core::Device,
    ) -> Result<Self> {
        let weight = Tensor::ones((hidden_size,), DType::F32, device)?;
        Ok(Self { weight, eps })
    }

    /// Normalizes `x` with [`rms_norm`] and multiplies by the scale vector.
    ///
    /// The output has the dtype of `x`. The scale is cast to that dtype when it
    /// differs (for example the `F32` scale from `new_no_weight` applied to a
    /// half-precision input), because the element-wise multiply needs both operands
    /// in the same dtype.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying tensor operations.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let normalized = rms_norm(x, self.eps)?;
        let target_dtype = normalized.dtype();
        let weight = if self.weight.dtype() == target_dtype {
            self.weight.clone()
        } else {
            self.weight.to_dtype(target_dtype)?
        };
        normalized.broadcast_mul(&weight)
    }
}
#[cfg(test)]
mod tests {
    // `Result`, `Tensor` and `DType` come in through the parent module's imports.
    use super::*;
    use candle_core::Device;

    /// After normalization the mean of squares over a row should be ~1.
    #[test]
    fn test_rms_norm_basic() -> Result<()> {
        let device = Device::Cpu;
        let x = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)?.reshape((1, 4))?;
        let normalized = rms_norm(&x, 1e-6)?;
        let mean_sq = normalized.sqr()?.mean_all()?.to_scalar::<f32>()?;
        // With eps = 1e-6 the exact value is 7.5 / (7.5 + 1e-6); 1e-4 leaves ample
        // room for f32 rounding while still catching a wrong normalizer.
        assert!(
            (mean_sq - 1.0).abs() < 1e-4,
            "mean square should be close to 1.0, got {}",
            mean_sq
        );
        Ok(())
    }

    /// Normalization must not change the tensor's dimensions.
    #[test]
    fn test_rms_norm_preserves_shape() -> Result<()> {
        let device = Device::Cpu;
        let x = Tensor::randn(0f32, 1.0, (2, 8, 64), &device)?;
        let normalized = rms_norm(&x, 1e-6)?;
        assert_eq!(x.dims(), normalized.dims());
        Ok(())
    }

    /// The output dtype must match the input dtype even when it is not `F32`.
    #[test]
    fn test_rms_norm_preserves_dtype() -> Result<()> {
        let device = Device::Cpu;
        let x = Tensor::new(&[1.0f64, 2.0, 3.0, 4.0], &device)?.reshape((1, 4))?;
        let normalized = rms_norm(&x, 1e-6)?;
        assert_eq!(normalized.dtype(), DType::F64);
        assert_eq!(x.dims(), normalized.dims());
        Ok(())
    }
}