/// Normalizes `buf` in place by its root-mean-square (RMSNorm without a learned gain).
///
/// Each element `x` is replaced by `x / sqrt(mean(x²) + eps)`.
///
/// The sum of squares, the mean and the per-element scaling are computed in `f64`.
/// Squaring an `f32` larger than about `1.8e19` overflows `f32`, which would drive the
/// scale factor to zero and silently zero the buffer. A long `f32` running sum also
/// loses precision. Neither problem can occur in `f64` for finite `f32` inputs.
///
/// An empty slice is left untouched. An all-zero slice is also left untouched: it is
/// already its own normalization, and returning early avoids producing `NaN` from
/// `0 * (1 / sqrt(0))` when `eps` is `0.0`.
///
/// `eps` is expected to be non-negative. A negative `eps` that makes `mean(x²) + eps`
/// negative yields `NaN` outputs.
pub fn rms_norm_f32(buf: &mut [f32], eps: f32) {
    if buf.is_empty() {
        return;
    }

    // In f64 the square of any finite f32 neither overflows nor underflows to zero,
    // so `sum_sq == 0.0` holds exactly when every element is +/-0.0.
    let sum_sq: f64 = buf
        .iter()
        .map(|&x| {
            let x = f64::from(x);
            x * x
        })
        .sum();
    if sum_sq == 0.0 {
        return;
    }

    let mean_sq = sum_sq / buf.len() as f64;
    let inv_rms = (mean_sq + f64::from(eps)).sqrt().recip();
    for x in buf.iter_mut() {
        // Scale in f64: for inputs near f32::MAX, `inv_rms` is below f32's normal range.
        *x = (f64::from(*x) * inv_rms) as f32;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance comparison; every expected value in these tests is O(1).
    fn close_enough(got: f32, want: f32) -> bool {
        (got - want).abs() < 1e-5
    }

    /// A constant buffer has RMS equal to that constant, so it normalizes to all ones.
    #[test]
    fn rms_norm_constant() {
        let mut buf = [1.0; 16];
        rms_norm_f32(&mut buf, 0.0);
        for v in buf {
            assert!(close_enough(v, 1.0), "got {v}, want 1.0");
        }
    }

    /// mean([1, 4, 9, 16]) = 7.5, so each element is scaled by 1 / sqrt(7.5).
    #[test]
    fn rms_norm_one_two_three_four() {
        let mut buf = [1.0_f32, 2.0, 3.0, 4.0];
        rms_norm_f32(&mut buf, 0.0);
        let inv = (7.5_f32).sqrt().recip();
        for (i, v) in buf.iter().enumerate() {
            let want = (i + 1) as f32 * inv;
            assert!(close_enough(*v, want), "i={i}: got {v}, want {want}");
        }
    }

    /// Distinguishes where eps is applied. With all ones and eps = 3, eps under the root
    /// gives 1 / sqrt(1 + 3) = 0.5, whereas eps outside the root would give
    /// 1 / (1 + 3) = 0.25.
    #[test]
    fn rms_norm_eps_added_under_root() {
        let mut buf = [1.0_f32; 4];
        rms_norm_f32(&mut buf, 3.0);
        for v in buf {
            assert!(close_enough(v, 0.5), "got {v}, want 0.5");
        }
    }

    /// With a positive eps, an all-zero input must stay finite and exactly zero.
    #[test]
    fn rms_norm_zero_input_with_eps() {
        let mut buf = [0.0_f32; 4];
        rms_norm_f32(&mut buf, 1e-5);
        for v in buf {
            assert!(v.is_finite());
            assert_eq!(v, 0.0);
        }
    }

    /// An empty slice must be accepted without panicking.
    #[test]
    fn rms_norm_empty() {
        let mut buf: [f32; 0] = [];
        rms_norm_f32(&mut buf, 1e-5);
    }
}