use crate::error::{LmError, LmResult};
/// Arithmetic mean of `x`; an empty slice yields `0.0` rather than NaN.
fn mean(x: &[f32]) -> f32 {
    match x.len() {
        0 => 0.0,
        count => x.iter().sum::<f32>() / count as f32,
    }
}
/// Population (divide-by-`n`) variance of `x` around the supplied centre `mu`.
///
/// `mu` is used exactly as given; it is not recomputed from `x`. An empty slice
/// yields `0.0`.
fn variance(x: &[f32], mu: f32) -> f32 {
    let count = x.len();
    if count == 0 {
        return 0.0;
    }
    let sum_sq_dev: f32 = x
        .iter()
        .map(|&v| {
            let dev = v - mu;
            dev * dev
        })
        .sum();
    sum_sq_dev / count as f32
}
/// Root-mean-square normalisation (RMSNorm) with a per-feature scale.
///
/// Each token row `x` of length `dim` is mapped to
/// `y_i = x_i / sqrt(mean(x^2) + eps) * weight_i`. The row is not mean-centred
/// and there is no bias term.
#[derive(Debug, Clone)]
pub struct RmsNorm {
    /// Number of features per token (the row length `forward` expects).
    pub dim: usize,
    /// Added to the mean square before taking the square root.
    pub eps: f32,
    /// Per-feature scale; `forward` requires exactly `dim` entries.
    pub weight: Vec<f32>,
}

impl RmsNorm {
    /// Creates an `RmsNorm` over `dim` features with every weight set to `1.0`.
    ///
    /// `eps` is stored as given; it is not validated.
    ///
    /// # Errors
    ///
    /// Returns [`LmError::InvalidConfig`] if `dim == 0`.
    pub fn new(dim: usize, eps: f32) -> LmResult<Self> {
        if dim == 0 {
            return Err(LmError::InvalidConfig {
                msg: "RmsNorm dim must be > 0".into(),
            });
        }
        Ok(Self {
            dim,
            eps,
            weight: vec![1.0_f32; dim],
        })
    }

    /// Creates an `RmsNorm` from existing scale weights; `dim` becomes `weight.len()`.
    ///
    /// # Errors
    ///
    /// Returns [`LmError::InvalidConfig`] if `weight` is empty.
    pub fn from_weight(weight: Vec<f32>, eps: f32) -> LmResult<Self> {
        let dim = weight.len();
        if dim == 0 {
            return Err(LmError::InvalidConfig {
                msg: "RmsNorm weight must be non-empty".into(),
            });
        }
        Ok(Self { dim, eps, weight })
    }

    /// Normalises `n_tokens` rows stored back-to-back in `x` (row `t` occupies
    /// `x[t * dim..(t + 1) * dim]`) and returns a new buffer of the same length.
    ///
    /// # Errors
    ///
    /// - [`LmError::InvalidConfig`] if `n_tokens * dim` overflows `usize`.
    /// - [`LmError::DimensionMismatch`] if `x.len() != n_tokens * dim`, or if the
    ///   public `weight` field no longer holds exactly `dim` entries.
    pub fn forward(&self, x: &[f32], n_tokens: usize) -> LmResult<Vec<f32>> {
        // Checked so a huge `n_tokens` cannot wrap around in release builds and
        // slip past the length check below.
        let expected = n_tokens
            .checked_mul(self.dim)
            .ok_or_else(|| LmError::InvalidConfig {
                msg: "RmsNorm n_tokens * dim overflows usize".into(),
            })?;
        if x.len() != expected {
            return Err(LmError::DimensionMismatch {
                expected,
                got: x.len(),
            });
        }
        // `weight` is public and may be reassigned after construction; a shorter
        // vector would otherwise make the `zip` below silently leave trailing
        // outputs at 0.0.
        if self.weight.len() != self.dim {
            return Err(LmError::DimensionMismatch {
                expected: self.dim,
                got: self.weight.len(),
            });
        }
        // Also covers `dim == 0` (reachable via the public field), for which
        // `chunks_exact` would panic.
        if x.is_empty() {
            return Ok(Vec::new());
        }
        let mut out = vec![0.0_f32; x.len()];
        for (row, out_row) in x
            .chunks_exact(self.dim)
            .zip(out.chunks_exact_mut(self.dim))
        {
            let mean_sq = row.iter().map(|&v| v * v).sum::<f32>() / self.dim as f32;
            let inv_rms = 1.0 / (mean_sq + self.eps).sqrt();
            for ((o, &xi), &wi) in out_row.iter_mut().zip(row).zip(&self.weight) {
                *o = xi * inv_rms * wi;
            }
        }
        Ok(out)
    }
}
/// Layer normalisation with a per-feature affine transform.
///
/// Each token row `x` of length `dim` is mapped to
/// `y_i = (x_i - mean(x)) / sqrt(var(x) + eps) * weight_i + bias_i`, where `var`
/// is the population (divide-by-`dim`) variance.
#[derive(Debug, Clone)]
pub struct LayerNorm {
    /// Number of features per token (the row length `forward` expects).
    pub dim: usize,
    /// Added to the variance before taking the square root.
    pub eps: f32,
    /// Per-feature scale; `forward` requires exactly `dim` entries.
    pub weight: Vec<f32>,
    /// Per-feature shift; `forward` requires exactly `dim` entries.
    pub bias: Vec<f32>,
}

impl LayerNorm {
    /// Creates a `LayerNorm` over `dim` features with weights `1.0` and biases `0.0`.
    ///
    /// `eps` is stored as given; it is not validated.
    ///
    /// # Errors
    ///
    /// Returns [`LmError::InvalidConfig`] if `dim == 0`.
    pub fn new(dim: usize, eps: f32) -> LmResult<Self> {
        if dim == 0 {
            return Err(LmError::InvalidConfig {
                msg: "LayerNorm dim must be > 0".into(),
            });
        }
        Ok(Self {
            dim,
            eps,
            weight: vec![1.0_f32; dim],
            bias: vec![0.0_f32; dim],
        })
    }

    /// Creates a `LayerNorm` from existing scale and shift vectors; `dim` becomes
    /// `weight.len()`.
    ///
    /// # Errors
    ///
    /// - [`LmError::InvalidConfig`] if `weight` is empty.
    /// - [`LmError::DimensionMismatch`] if `bias.len() != weight.len()`.
    pub fn from_weights(weight: Vec<f32>, bias: Vec<f32>, eps: f32) -> LmResult<Self> {
        let dim = weight.len();
        if dim == 0 {
            return Err(LmError::InvalidConfig {
                msg: "LayerNorm weight must be non-empty".into(),
            });
        }
        if bias.len() != dim {
            return Err(LmError::DimensionMismatch {
                expected: dim,
                got: bias.len(),
            });
        }
        Ok(Self {
            dim,
            eps,
            weight,
            bias,
        })
    }

    /// Normalises `n_tokens` rows stored back-to-back in `x` (row `t` occupies
    /// `x[t * dim..(t + 1) * dim]`) and returns a new buffer of the same length.
    ///
    /// # Errors
    ///
    /// - [`LmError::InvalidConfig`] if `n_tokens * dim` overflows `usize`.
    /// - [`LmError::DimensionMismatch`] if `x.len() != n_tokens * dim`, or if the
    ///   public `weight` / `bias` fields no longer hold exactly `dim` entries.
    pub fn forward(&self, x: &[f32], n_tokens: usize) -> LmResult<Vec<f32>> {
        // Checked so a huge `n_tokens` cannot wrap around in release builds and
        // slip past the length check below.
        let expected = n_tokens
            .checked_mul(self.dim)
            .ok_or_else(|| LmError::InvalidConfig {
                msg: "LayerNorm n_tokens * dim overflows usize".into(),
            })?;
        if x.len() != expected {
            return Err(LmError::DimensionMismatch {
                expected,
                got: x.len(),
            });
        }
        // `weight` and `bias` are public and may be reassigned after construction;
        // a shorter vector would otherwise make the `zip` below silently leave
        // trailing outputs at 0.0.
        for param_len in [self.weight.len(), self.bias.len()] {
            if param_len != self.dim {
                return Err(LmError::DimensionMismatch {
                    expected: self.dim,
                    got: param_len,
                });
            }
        }
        // Also covers `dim == 0` (reachable via the public field), for which
        // `chunks_exact` would panic.
        if x.is_empty() {
            return Ok(Vec::new());
        }
        let mut out = vec![0.0_f32; x.len()];
        for (row, out_row) in x
            .chunks_exact(self.dim)
            .zip(out.chunks_exact_mut(self.dim))
        {
            let mu = mean(row);
            let var = variance(row, mu);
            let inv_std = 1.0 / (var + self.eps).sqrt();
            let params = self.weight.iter().zip(&self.bias);
            for ((o, &xi), (&wi, &bi)) in out_row.iter_mut().zip(row).zip(params) {
                *o = (xi - mu) * inv_std * wi + bi;
            }
        }
        Ok(out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rms_norm_ones_weight_identity_direction() {
        let n = RmsNorm::new(4, 1e-8).expect("dim=4 RmsNorm should be valid");
        let x = vec![3.0_f32, 4.0, 0.0, 0.0];
        let out = n
            .forward(&x, 1)
            .expect("1-token RmsNorm forward with matching dim should succeed");
        // rms([3, 4, 0, 0]) = sqrt(25 / 4) = 2.5, so the row is divided by 2.5.
        assert!((out[0] - 1.2).abs() < 1e-5, "out[0]={}", out[0]);
        assert!((out[1] - 1.6).abs() < 1e-5, "out[1]={}", out[1]);
        assert!(out[2].abs() < 1e-5, "out[2]={}", out[2]);
        assert!(out[3].abs() < 1e-5, "out[3]={}", out[3]);
    }

    #[test]
    fn rms_norm_scale_weight() {
        let mut n = RmsNorm::new(2, 1e-8).expect("dim=2 RmsNorm should be valid");
        n.weight = vec![2.0, 0.5];
        let x = vec![1.0_f32, 1.0];
        let out = n
            .forward(&x, 1)
            .expect("1-token RmsNorm forward with scale weight should succeed");
        assert!((out[0] - 2.0).abs() < 1e-5, "out[0]={}", out[0]);
        assert!((out[1] - 0.5).abs() < 1e-5, "out[1]={}", out[1]);
    }

    #[test]
    fn rms_norm_batch_tokens() {
        let n = RmsNorm::new(2, 1e-8).expect("dim=2 RmsNorm for batch test should be valid");
        let x = vec![1.0_f32, 1.0, 2.0, 2.0];
        let out = n
            .forward(&x, 2)
            .expect("2-token RmsNorm batch forward should succeed");
        assert_eq!(out.len(), 4);
        for &v in &out {
            assert!((v - 1.0).abs() < 1e-5, "v={v}");
        }
    }

    #[test]
    fn rms_norm_zero_tokens_returns_empty() {
        let n = RmsNorm::new(3, 1e-5).expect("dim=3 RmsNorm should be valid");
        let out = n
            .forward(&[], 0)
            .expect("0-token RmsNorm forward should succeed");
        assert!(out.is_empty());
    }

    #[test]
    fn rms_norm_dim_mismatch_error() {
        let n = RmsNorm::new(4, 1e-8).expect("dim=4 RmsNorm should be valid");
        let err = n.forward(&[1.0, 2.0], 1).unwrap_err();
        assert!(matches!(err, LmError::DimensionMismatch { .. }));
    }

    #[test]
    fn rms_norm_zero_dim_error() {
        assert!(RmsNorm::new(0, 1e-5).is_err());
    }

    #[test]
    fn rms_norm_from_weight() {
        let n = RmsNorm::from_weight(vec![1.0, 2.0, 3.0], 1e-8)
            .expect("non-empty weight vector should produce valid RmsNorm");
        assert_eq!(n.dim, 3);
        assert_eq!(n.weight, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn rms_norm_from_empty_weight_error() {
        assert!(RmsNorm::from_weight(Vec::new(), 1e-5).is_err());
    }

    #[test]
    fn layer_norm_zero_centered_unit_variance() {
        let ln = LayerNorm::new(4, 1e-8).expect("dim=4 LayerNorm should be valid");
        let x = vec![1.0_f32, 2.0, 3.0, 4.0];
        let out = ln
            .forward(&x, 1)
            .expect("1-token LayerNorm forward should succeed");
        let m = out.iter().sum::<f32>() / 4.0;
        let v = out.iter().map(|&v| (v - m) * (v - m)).sum::<f32>() / 4.0;
        assert!(m.abs() < 1e-5, "mean={m}");
        assert!((v - 1.0).abs() < 1e-4, "var={v}");
    }

    #[test]
    fn layer_norm_known_values() {
        let ln = LayerNorm::new(4, 1e-8).expect("dim=4 LayerNorm should be valid");
        let out = ln
            .forward(&[1.0, 2.0, 3.0, 4.0], 1)
            .expect("1-token LayerNorm forward should succeed");
        // mean = 2.5, population variance = 1.25.
        let s = 1.0_f32 / 1.25_f32.sqrt();
        let want = [-1.5 * s, -0.5 * s, 0.5 * s, 1.5 * s];
        for (i, (&got, &exp)) in out.iter().zip(want.iter()).enumerate() {
            assert!((got - exp).abs() < 1e-5, "out[{i}]={got}, expected {exp}");
        }
    }

    #[test]
    fn layer_norm_weight_and_bias() {
        let ln = LayerNorm::from_weights(vec![2.0, 2.0], vec![1.0, 1.0], 1e-8)
            .expect("matching weight and bias length=2 should be valid");
        let x = vec![0.0_f32, 0.0];
        let out = ln
            .forward(&x, 1)
            .expect("1-token LayerNorm with custom weight/bias should succeed");
        // A constant row normalises to 0, leaving only the bias.
        assert!((out[0] - 1.0).abs() < 1e-3, "out[0]={}", out[0]);
        assert!((out[1] - 1.0).abs() < 1e-3, "out[1]={}", out[1]);
    }

    #[test]
    fn layer_norm_batch_tokens() {
        let ln = LayerNorm::new(4, 1e-8).expect("dim=4 LayerNorm for batch test should be valid");
        let x = vec![1.0_f32, 2.0, 3.0, 4.0, -1.0, -2.0, -3.0, -4.0];
        let out = ln
            .forward(&x, 2)
            .expect("2-token LayerNorm batch forward should succeed");
        assert_eq!(out.len(), 8);
        let m1: f32 = out[..4].iter().sum::<f32>() / 4.0;
        let m2: f32 = out[4..].iter().sum::<f32>() / 4.0;
        assert!(m1.abs() < 1e-5, "token 0 mean={m1}");
        assert!(m2.abs() < 1e-5, "token 1 mean={m2}");
    }

    #[test]
    fn layer_norm_zero_tokens_returns_empty() {
        let ln = LayerNorm::new(4, 1e-5).expect("dim=4 LayerNorm should be valid");
        let out = ln
            .forward(&[], 0)
            .expect("0-token LayerNorm forward should succeed");
        assert!(out.is_empty());
    }

    #[test]
    fn layer_norm_dim_mismatch_error() {
        let ln = LayerNorm::new(4, 1e-8).expect("dim=4 LayerNorm should be valid");
        let err = ln.forward(&[1.0, 2.0], 1).unwrap_err();
        assert!(matches!(err, LmError::DimensionMismatch { .. }));
    }

    #[test]
    fn layer_norm_weight_bias_dim_mismatch_error() {
        let err = LayerNorm::from_weights(vec![1.0, 2.0], vec![0.0], 1e-5).unwrap_err();
        assert!(matches!(err, LmError::DimensionMismatch { .. }));
    }

    #[test]
    fn layer_norm_zero_dim_error() {
        assert!(LayerNorm::new(0, 1e-5).is_err());
    }
}