use candle_core::{DType, Device, Tensor};
use crate::error::Result;
/// Root Mean Square layer normalization (scale-only; no mean subtraction, no bias).
///
/// Normalizes over the last dimension: `y = x / sqrt(mean(x^2) + eps) * weight`.
pub struct RmsNorm {
    /// Learnable per-channel scale; created as a ones tensor of shape
    /// `(hidden_size,)` with dtype F32 in `new`.
    weight: Tensor,
    /// Small constant added under the square root for numerical stability.
    eps: f64,
}
impl RmsNorm {
    /// Creates an RMSNorm layer with a ones-initialized F32 weight of length
    /// `hidden_size` on `device`.
    ///
    /// # Errors
    /// Returns an error if the weight tensor cannot be allocated on `device`.
    pub fn new(hidden_size: usize, eps: f64, device: &Device) -> Result<Self> {
        let weight = Tensor::ones((hidden_size,), DType::F32, device)?;
        Ok(Self { weight, eps })
    }

    /// Applies RMS normalization over the last dimension of `x`.
    ///
    /// Dispatches on the input's device; both paths compute identical math.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cuda() {
            self.forward_cuda(x)
        } else {
            self.forward_cpu(x)
        }
    }

    /// Shared normalization: `x / sqrt(mean(x^2, last_dim) + eps) * weight`.
    ///
    /// NOTE(review): `x.rank() - 1` assumes `x` has rank >= 1; a scalar input
    /// would underflow — confirm callers never pass rank-0 tensors.
    fn normalize(&self, x: &Tensor) -> Result<Tensor> {
        let x_sq = x.sqr()?;
        // Mean of squares over the last dimension, kept for broadcasting.
        let mean_sq = x_sq.mean_keepdim(x.rank() - 1)?;
        // eps goes under the sqrt so tiny inputs cannot divide by ~0.
        let rms = (mean_sq + self.eps)?.sqrt()?;
        let normalized = x.broadcast_div(&rms)?;
        let output = normalized.broadcast_mul(&self.weight)?;
        Ok(output)
    }

    fn forward_cpu(&self, x: &Tensor) -> Result<Tensor> {
        self.normalize(x)
    }

    fn forward_cuda(&self, x: &Tensor) -> Result<Tensor> {
        tracing::debug!("Using CUDA RMSNorm for input shape {:?}", x.shape());
        // Previously this path re-implemented the reduction with a manual
        // flatten/sum plus two per-call 1-element scale/eps tensors, and cast
        // eps to f32 — numerically diverging from the f64 eps used on CPU.
        // Candle's reduction and broadcast ops are device-aware, so the same
        // expression runs on CUDA; delegate to the shared implementation.
        self.normalize(x)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rmsnorm_creation() {
        // Construction on CPU should succeed for a typical hidden size.
        assert!(RmsNorm::new(768, 1e-5, &Device::Cpu).is_ok());
    }

    #[test]
    fn test_rmsnorm_forward() {
        // The output shape must match the input shape exactly.
        let device = Device::Cpu;
        let input = Tensor::randn(0.0f32, 1.0, (2, 10, 768), &device).unwrap();
        let layer = RmsNorm::new(768, 1e-5, &device).unwrap();
        let result = layer.forward(&input).unwrap();
        assert_eq!(result.shape().dims(), &[2, 10, 768]);
    }

    #[test]
    fn test_rmsnorm_normalizes_values() {
        // With a ones weight, the output's RMS should land near 1.
        let device = Device::Cpu;
        let layer = RmsNorm::new(64, 1e-5, &device).unwrap();
        let input = Tensor::randn(0.0f32, 5.0, (1, 1, 64), &device).unwrap();
        let mean_sq = layer
            .forward(&input)
            .unwrap()
            .sqr()
            .unwrap()
            .mean_all()
            .unwrap()
            .to_scalar::<f32>()
            .unwrap();
        let rms = mean_sq.sqrt();
        assert!(
            (rms - 1.0).abs() < 0.5,
            "RMS should be approximately 1, got {}",
            rms
        );
    }

    #[test]
    fn test_rmsnorm_numerical_stability() {
        let device = Device::Cpu;
        let layer = RmsNorm::new(128, 1e-5, &device).unwrap();

        // Tiny inputs must not error out, thanks to eps under the sqrt.
        let small_input = Tensor::full(1e-6f32, (1, 1, 128), &device).unwrap();
        assert!(layer.forward(&small_input).is_ok());

        // Large inputs must yield finite, non-NaN values everywhere.
        let large_input = Tensor::randn(0.0f32, 100.0, (1, 1, 128), &device).unwrap();
        let values: Vec<f32> = layer
            .forward(&large_input)
            .unwrap()
            .flatten_all()
            .unwrap()
            .to_vec1()
            .unwrap();
        for v in values {
            assert!(!v.is_nan(), "Output contains NaN");
            assert!(!v.is_infinite(), "Output contains Inf");
        }
    }
}