use super::NormError;
/// Applies RMS normalization to `x` in place, then scales each element by `gamma`.
///
/// Each element becomes `x[i] * (1 / sqrt(mean(x²) + eps)) * gamma[i]`.
///
/// The sum of squares is accumulated in `f64`, so it cannot overflow for any finite
/// `f32` input. A mean square that is itself above `f32::MAX` (RMS above roughly
/// 1.8e19) still rounds to `+inf` when narrowed back to `f32`.
///
/// # Errors
///
/// - [`NormError::Empty`] if `x` is empty.
/// - [`NormError::ShapeMismatch`] if `gamma.len() != x.len()`.
/// - [`NormError::InvalidEps`] if `eps` is not finite or is not strictly positive.
///
/// All checks happen before any write, so on error `x` is left unmodified.
pub fn rms_norm_in_place(x: &mut [f32], gamma: &[f32], eps: f32) -> Result<(), NormError> {
    if x.is_empty() {
        return Err(NormError::Empty);
    }
    if gamma.len() != x.len() {
        return Err(NormError::ShapeMismatch);
    }
    if !eps.is_finite() || eps <= 0.0 {
        return Err(NormError::InvalidEps);
    }

    // Accumulate in f64. An f32 running sum overflows to +inf once the *sum* passes
    // f32::MAX, even when the mean is representable. That would make `inv_rms` zero
    // and silently zero the output. Every f32 squared fits in f64, and the wider
    // accumulator also reduces rounding error on long vectors.
    let sum_sq: f64 = x.iter().map(|&v| f64::from(v) * f64::from(v)).sum();
    let ms = (sum_sq / x.len() as f64) as f32;
    let inv_rms = 1.0 / crate::math::sqrtf(ms + eps);

    // Lengths were checked equal above, so zip visits every element.
    for (v, &g) in x.iter_mut().zip(gamma) {
        *v = *v * inv_rms * g;
    }
    Ok(())
}
/// Out-of-place RMS normalization: copies `input` into `out`, then normalizes `out`
/// with [`rms_norm_in_place`] using the same `gamma` and `eps`.
///
/// # Errors
///
/// Returns [`NormError::ShapeMismatch`] without touching `out` when
/// `out.len() != input.len()`. Otherwise any error from [`rms_norm_in_place`] is
/// propagated unchanged. In that case `out` has already been overwritten with a
/// copy of `input`.
pub fn rms_norm(
    input: &[f32],
    gamma: &[f32],
    eps: f32,
    out: &mut [f32],
) -> Result<(), NormError> {
    if input.len() == out.len() {
        // Stage the raw input in the caller's buffer, then normalize it there.
        out.copy_from_slice(input);
        rms_norm_in_place(out, gamma, eps)
    } else {
        Err(NormError::ShapeMismatch)
    }
}