use crate::autograd::scale;
use crate::Tensor;
use std::collections::HashMap;
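/// Root-mean-square layer normalization (RMSNorm).
///
/// Each input row is scaled by the reciprocal of its RMS and then multiplied
/// element-wise by a learned weight vector:
///
///   y_i = weight_i * x_i / sqrt(mean(x^2) + eps)
///
/// Unlike `LayerNorm` below, no mean is subtracted and no bias is added.
///
/// # Example (sketch; mirrors `test_rms_norm_normalization_property`)
///
/// ```ignore
/// let norm = RMSNorm::new(4, 1e-6);
/// let x = Tensor::from_vec(vec![2.0, 2.0, 2.0, 2.0], true);
/// let y = norm.forward(&x); // every element is ~1.0 with unit weights
/// ```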
pub struct RMSNorm {
pub weight: Tensor,
eps: f32,
}
impl RMSNorm {
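/// Creates a layer with `hidden_size` scale weights initialized to one.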
pub fn new(hidden_size: usize, eps: f32) -> Self {
Self { weight: Tensor::ones(hidden_size, true), eps }
}
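/// Loads the weight vector from `params` under the key `"{prefix}.weight"`.
///
/// Returns `None` if the key is missing or the stored weight does not have
/// exactly `hidden_size` elements (logged as PMAT-332).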
pub fn from_params(
params: &HashMap<String, Tensor>,
prefix: &str,
eps: f32,
hidden_size: usize,
) -> Option<Self> {
let weight = params.get(&format!("{prefix}.weight"))?.clone();
if weight.len() != hidden_size {
eprintln!(
"[PMAT-332] {prefix}.weight: length mismatch — got {}, expected {hidden_size}",
weight.len()
);
return None;
}
Some(Self { weight, eps })
}
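/// Normalizes a single vector: `x / sqrt(mean(x^2) + eps)`, scaled by `weight`.
///
/// Worked example with unit weights: for `x = [2, 2, 2, 2]`, `mean(x^2) = 4`,
/// so `rms ≈ 2` and the output is `≈ [1, 1, 1, 1]`.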
pub fn forward(&self, x: &Tensor) -> Tensor {
let n = x.len() as f32;
let sq_sum: f32 = x.data().iter().map(|v| v * v).sum();
let rms = (sq_sum / n + self.eps).sqrt();
let normalized = scale(x, 1.0 / rms);
crate::autograd::mul(&normalized, &self.weight)
}
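/// Normalizes each of the `seq_len` rows of a row-major `(seq_len, hidden_size)`
/// input independently in one fused loop.
///
/// When either the input or the weight requires gradients, a custom
/// `RMSNormBatchedBackward` node is attached so the cached per-row RMS values
/// can be reused in the backward pass.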
pub fn forward_batched(&self, x: &Tensor, seq_len: usize, hidden_size: usize) -> Tensor {
let mut output = vec![0.0; seq_len * hidden_size];
let mut rms_values = Vec::with_capacity(seq_len);
let x_data = x.data();
let x_slice = x_data.as_slice().expect("norm input must be contiguous");
let w_data = self.weight.data();
let w_slice = w_data.as_slice().expect("norm weight must be contiguous");
for s in 0..seq_len {
let start = s * hidden_size;
let end = start + hidden_size;
let slice = &x_slice[start..end];
let sq_sum: f32 = slice.iter().map(|v| v * v).sum();
let rms = (sq_sum / hidden_size as f32 + self.eps).sqrt();
rms_values.push(rms);
for (i, &val) in slice.iter().enumerate() {
output[start + i] = (val / rms) * w_slice[i];
}
}
let requires_grad = x.requires_grad() || self.weight.requires_grad();
let mut result = Tensor::from_vec(output, requires_grad);
if requires_grad {
use crate::autograd::BackwardOp;
use ndarray::Array1;
use std::cell::RefCell;
use std::rc::Rc;
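// Backward node for the fused batched RMSNorm. It captures the inputs and the
// cached per-row RMS values so the backward pass does not recompute them.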
struct RMSNormBatchedBackward {
x: Tensor,
weight: Tensor,
rms_values: Vec<f32>,
seq_len: usize,
hidden_size: usize,
result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
impl BackwardOp for RMSNormBatchedBackward {
fn backward(&self) {
if let Some(grad_output) = self.result_grad.borrow().as_ref() {
let h = self.hidden_size;
let x_data = self.x.data();
let x_sl = x_data.as_slice().expect("x contiguous");
let w_data = self.weight.data();
let w_sl = w_data.as_slice().expect("weight contiguous");
let go = grad_output.as_slice().expect("grad contiguous");
if self.x.requires_grad() {
let mut grad_x = vec![0.0_f32; self.seq_len * h];
let n = h as f32;
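// For y_i = w_i * x_i / rms with rms = sqrt(mean(x^2) + eps), the chain rule gives
//   dL/dx_j = g_j * w_j / rms - x_j * (sum_i g_i * w_i * x_i) / (n * rms^3).
// `c` below is the shared factor dot / (n * rms^2); the trailing division by
// rms completes both terms.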
for s in 0..self.seq_len {
let off = s * h;
let rms = self.rms_values[s];
let mut dot = 0.0_f32;
for i in 0..h {
dot += go[off + i] * w_sl[i] * x_sl[off + i];
}
let c = dot / (n * rms * rms);
for j in 0..h {
grad_x[off + j] =
(go[off + j] * w_sl[j] - x_sl[off + j] * c) / rms;
}
}
self.x.accumulate_grad(Array1::from(grad_x));
}
if self.weight.requires_grad() {
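// dL/dw_i accumulates g[s, i] * (x[s, i] / rms_s) over every position s,
// i.e. the incoming gradient times the normalized (pre-scale) activation.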
let mut grad_w = vec![0.0_f32; h];
for s in 0..self.seq_len {
let off = s * h;
let rms = self.rms_values[s];
for i in 0..h {
grad_w[i] += go[off + i] * x_sl[off + i] / rms;
}
}
self.weight.accumulate_grad(Array1::from(grad_w));
}
if let Some(op) = self.x.backward_op() {
op.backward();
}
if let Some(op) = self.weight.backward_op() {
op.backward();
}
}
}
}
let backward_op = Rc::new(RMSNormBatchedBackward {
x: x.clone(),
weight: self.weight.clone(),
rms_values,
seq_len,
hidden_size,
result_grad: result.grad_cell(),
});
result.set_backward_op(backward_op);
}
result
}
}
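/// Standard layer normalization with a learned scale and bias:
///
///   y_i = weight_i * (x_i - mean(x)) / sqrt(var(x) + eps) + bias_i
///
/// # Example (sketch; mirrors `enc_005_layernorm_zero_mean_unit_var`)
///
/// ```ignore
/// let ln = LayerNorm::new(8, 1e-12);
/// let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], true);
/// let y = ln.forward_batched(&x, 1, 8); // per-row mean ~0, variance ~1
/// ```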
pub struct LayerNorm {
pub weight: Tensor,
pub bias: Tensor,
eps: f32,
hidden_size: usize,
}
impl LayerNorm {
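/// Creates a layer with unit weights and zero bias for `hidden_size` features.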
pub fn new(hidden_size: usize, eps: f32) -> Self {
Self {
weight: Tensor::ones(hidden_size, true),
bias: Tensor::from_vec(vec![0.0; hidden_size], true),
eps,
hidden_size,
}
}
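/// Loads `"{prefix}.weight"` and `"{prefix}.bias"` from `params`.
///
/// Returns `None` if either key is missing or either vector does not have
/// exactly `hidden_size` elements (logged as ENC-005).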
pub fn from_params(
params: &HashMap<String, Tensor>,
prefix: &str,
eps: f32,
hidden_size: usize,
) -> Option<Self> {
let weight = params.get(&format!("{prefix}.weight"))?.clone();
let bias = params.get(&format!("{prefix}.bias"))?.clone();
if weight.len() != hidden_size || bias.len() != hidden_size {
eprintln!(
"[ENC-005] {prefix}: shape mismatch — weight={}, bias={}, expected {hidden_size}",
weight.len(),
bias.len()
);
return None;
}
Some(Self { weight, bias, eps, hidden_size })
}
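/// Normalizes each row of a row-major `(seq_len, hidden_size)` input to zero
/// mean and unit variance (biased variance, divided by `hidden_size`), then
/// applies the learned scale and bias.
///
/// Note: unlike `RMSNorm::forward_batched`, this path does not attach a
/// backward op; the output only carries the `requires_grad` flag.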
pub fn forward_batched(&self, x: &Tensor, seq_len: usize, hidden_size: usize) -> Tensor {
let mut output = vec![0.0_f32; seq_len * hidden_size];
let x_data = x.data();
let x_slice = x_data.as_slice().expect("input contiguous");
let w_data = self.weight.data();
let w_slice = w_data.as_slice().expect("weight contiguous");
let b_data = self.bias.data();
let b_slice = b_data.as_slice().expect("bias contiguous");
for s in 0..seq_len {
let start = s * hidden_size;
let end = start + hidden_size;
let row = &x_slice[start..end];
let mean: f32 = row.iter().sum::<f32>() / hidden_size as f32;
let var: f32 =
row.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / hidden_size as f32;
let inv_std = 1.0 / (var + self.eps).sqrt();
for (i, &val) in row.iter().enumerate() {
output[start + i] = (val - mean) * inv_std * w_slice[i] + b_slice[i];
}
}
let requires_grad = x.requires_grad() || self.weight.requires_grad() || self.bias.requires_grad();
Tensor::from_vec(output, requires_grad)
}
pub fn hidden_size(&self) -> usize {
self.hidden_size
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rms_norm_forward() {
let norm = RMSNorm::new(4, 1e-6);
let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true);
let output = norm.forward(&x);
assert_eq!(output.len(), 4);
let data = output.data();
assert!(data.iter().all(|&v| v.is_finite()));
}
#[test]
fn test_rms_norm_batched() {
let norm = RMSNorm::new(4, 1e-6);
let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], true);
let output = norm.forward_batched(&x, 2, 4);
assert_eq!(output.len(), 8);
}
#[test]
fn test_rms_norm_normalization_property() {
let norm = RMSNorm::new(4, 1e-6);
let x = Tensor::from_vec(vec![2.0, 2.0, 2.0, 2.0], true);
let output = norm.forward(&x);
let data = output.data();
for &val in data.iter() {
assert!((val - 1.0).abs() < 1e-5, "Expected ~1.0, got {val}");
}
}
#[test]
fn test_rms_norm_with_zeros() {
let norm = RMSNorm::new(4, 1e-6);
let x = Tensor::from_vec(vec![0.0, 0.0, 0.0, 0.0], true);
let output = norm.forward(&x);
let data = output.data();
assert!(data.iter().all(|&v| v.is_finite()));
}
#[test]
fn test_rms_norm_weight_requires_grad() {
let norm = RMSNorm::new(4, 1e-6);
assert!(norm.weight.requires_grad());
}
#[test]
fn test_rms_norm_from_params() {
let mut params = HashMap::new();
params.insert("test.weight".to_string(), Tensor::from_vec(vec![1.0, 1.0, 1.0, 1.0], true));
let norm = RMSNorm::from_params(&params, "test", 1e-6, 4);
assert!(norm.is_some());
let norm = norm.expect("operation should succeed");
assert_eq!(norm.weight.len(), 4);
}
#[test]
fn test_rms_norm_from_params_missing() {
let params: HashMap<String, Tensor> = HashMap::new();
let norm = RMSNorm::from_params(&params, "missing", 1e-6, 4);
assert!(norm.is_none());
}
#[test]
fn test_rms_norm_backward_gradient_exists() {
let norm = RMSNorm::new(8, 1e-6);
let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], true);
let mut output = norm.forward(&x);
let grad_out = ndarray::Array1::ones(8);
crate::autograd::backward(&mut output, Some(grad_out));
assert!(norm.weight.grad().is_some());
let grad = norm.weight.grad().expect("gradient should be available");
assert!(grad.iter().all(|&v| v.is_finite()));
}
#[test]
fn test_rms_norm_batched_backward_gradient_exists() {
let norm = RMSNorm::new(4, 1e-6);
let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], true);
let mut output = norm.forward_batched(&x, 2, 4);
let grad_out = ndarray::Array1::ones(8);
crate::autograd::backward(&mut output, Some(grad_out));
assert!(norm.weight.grad().is_some(), "ALB-038: norm weight must have gradient");
let wgrad = norm.weight.grad().expect("gradient available");
assert!(wgrad.iter().all(|&v| v.is_finite()), "Weight gradients must be finite");
assert!(wgrad.iter().any(|&v| v.abs() > 1e-10), "Weight gradients must be non-zero");
assert!(x.grad().is_some(), "ALB-038: input x must have gradient");
let xgrad = x.grad().expect("gradient available");
assert!(xgrad.iter().all(|&v| v.is_finite()), "Input gradients must be finite");
assert!(xgrad.iter().any(|&v| v.abs() > 1e-10), "Input gradients must be non-zero");
}
#[test]
fn test_rms_norm_batched_backward_weight_grad_matches() {
let hidden = 4;
let data = vec![1.0_f32, -2.0, 3.0, -0.5];
let norm1 = RMSNorm::new(hidden, 1e-6);
let x1 = Tensor::from_vec(data.clone(), true);
let mut out1 = norm1.forward(&x1);
crate::autograd::backward(&mut out1, Some(ndarray::Array1::ones(hidden)));
let wgrad1 = norm1.weight.grad().expect("gradient available");
let norm2 = RMSNorm::new(hidden, 1e-6);
let x2 = Tensor::from_vec(data, true);
let mut out2 = norm2.forward_batched(&x2, 1, hidden);
crate::autograd::backward(&mut out2, Some(ndarray::Array1::ones(hidden)));
let wgrad2 = norm2.weight.grad().expect("gradient available");
for i in 0..hidden {
assert!(
(wgrad1[i] - wgrad2[i]).abs() < 1e-5,
"Weight grad mismatch at [{i}]: unbatched={}, batched={}",
wgrad1[i],
wgrad2[i]
);
}
}
#[test]
fn enc_005_layernorm_output_shape() {
let ln = LayerNorm::new(8, 1e-5);
let x = Tensor::from_vec(vec![1.0; 3 * 8], true);
let output = ln.forward_batched(&x, 3, 8);
assert_eq!(output.len(), 3 * 8);
}
#[test]
fn enc_005_layernorm_zero_mean_unit_var() {
let ln = LayerNorm::new(8, 1e-12);
let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], true);
let output = ln.forward_batched(&x, 1, 8);
let data = output.data();
let slice = data.as_slice().expect("contiguous");
let mean: f32 = slice.iter().sum::<f32>() / 8.0;
assert!(mean.abs() < 1e-5, "LayerNorm output mean={mean}, expected ~0");
let var: f32 = slice.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / 8.0;
assert!((var - 1.0).abs() < 0.01, "LayerNorm output var={var}, expected ~1");
}
#[test]
fn enc_005_layernorm_with_bias() {
let mut ln = LayerNorm::new(4, 1e-12);
ln.bias = Tensor::from_vec(vec![5.0; 4], true);
let x = Tensor::from_vec(vec![1.0, 1.0, 1.0, 1.0], true);
let output = ln.forward_batched(&x, 1, 4);
let data = output.data();
for &v in data.iter() {
assert!((v - 5.0).abs() < 1e-3, "Expected ~5.0 with bias, got {v}");
}
}
#[test]
fn enc_005_layernorm_from_params() {
let mut params = HashMap::new();
params.insert("ln.weight".to_string(), Tensor::from_vec(vec![1.0; 32], true));
params.insert("ln.bias".to_string(), Tensor::from_vec(vec![0.0; 32], true));
let ln = LayerNorm::from_params(&params, "ln", 1e-5, 32);
assert!(ln.is_some());
assert_eq!(ln.expect("should succeed").hidden_size(), 32);
}
#[test]
fn enc_005_layernorm_from_params_rejects_mismatch() {
let mut params = HashMap::new();
params.insert("ln.weight".to_string(), Tensor::from_vec(vec![1.0; 32], true));
params.insert("ln.bias".to_string(), Tensor::from_vec(vec![0.0; 16], true)); let ln = LayerNorm::from_params(¶ms, "ln", 1e-5, 32);
assert!(ln.is_none());
}
#[test]
fn enc_005_layernorm_finite_output() {
let ln = LayerNorm::new(4, 1e-5);
let x = Tensor::from_vec(vec![1e6, -1e6, 0.0, 1.0], true);
let output = ln.forward_batched(&x, 1, 4);
assert!(output.data().iter().all(|v| v.is_finite()));
}
#[test]
fn falsify_n1e_from_params_rejects_wrong_length_norm() {
let mut params = HashMap::new();
params.insert("test.weight".to_string(), Tensor::from_vec(vec![1.0; 7], true));
let norm = RMSNorm::from_params(&params, "test", 1e-6, 4);
assert!(
norm.is_none(),
"FALSIFY-N1e: PMAT-332 fix — from_params MUST reject wrong-length norm weight"
);
}
#[test]
fn falsify_n2e_norm_output_finite() {
let norm = RMSNorm::new(8, 1e-6);
let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], true);
let output = norm.forward(&x);
assert!(
output.data().iter().all(|v| v.is_finite()),
"FALSIFY-N2e: RMSNorm output must be finite for valid input"
);
}
#[test]
fn falsify_n3e_new_constructor_correct_length() {
let hidden_sizes = [64, 128, 256, 896, 4096];
for &hidden in &hidden_sizes {
let norm = RMSNorm::new(hidden, 1e-6);
assert_eq!(
norm.weight.len(),
hidden,
"FALSIFY-N3e: RMSNorm::new({hidden}) weight must have {hidden} elements"
);
}
}
#[test]
fn falsify_n4e_batched_forward_preserves_dims() {
let hidden = 8;
let seq_len = 4;
let norm = RMSNorm::new(hidden, 1e-6);
let x = Tensor::from_vec(vec![0.5; seq_len * hidden], true);
let output = norm.forward_batched(&x, seq_len, hidden);
assert_eq!(
output.len(),
seq_len * hidden,
"FALSIFY-N4e: Batched norm must preserve seq_len * hidden dimension"
);
assert!(
output.data().iter().all(|v| v.is_finite()),
"FALSIFY-N4e: Batched norm output must be finite"
);
}
#[test]
fn falsify_n5e_extreme_input_still_finite() {
let norm = RMSNorm::new(4, 1e-6);
let x = Tensor::from_vec(vec![1e30, -1e30, 1e30, -1e30], true);
let output = norm.forward(&x);
assert!(
output.data().iter().all(|v| v.is_finite()),
"FALSIFY-N5e: RMSNorm must handle extreme values without Inf/NaN"
);
}
#[test]
fn falsify_rn_001_finiteness() {
let norm = RMSNorm::new(4, 1e-6);
let test_cases: Vec<(&str, Vec<f32>)> = vec![
("normal", vec![1.0, 2.0, 3.0, 4.0]),
("small", vec![1e-7, 1e-7, 1e-7, 1e-7]),
("large", vec![1e6, 1e6, 1e6, 1e6]),
("mixed_sign", vec![-3.0, 2.0, -1.0, 4.0]),
("near_zero", vec![1e-20, 0.0, 1e-20, 0.0]),
];
for (name, data) in &test_cases {
let x = Tensor::from_vec(data.clone(), true);
let y = norm.forward(&x);
for (i, &val) in y.data().iter().enumerate() {
assert!(
val.is_finite(),
"FALSIFIED RN-001: output[{i}] = {val} not finite for case '{name}'"
);
}
}
}
#[test]
fn falsify_rn_002_scale_invariance() {
let norm = RMSNorm::new(4, 1e-6);
let x = Tensor::from_vec(vec![1.0, -2.0, 3.0, -0.5], true);
let y_base = norm.forward(&x);
for &alpha in &[2.0_f32, 0.5, -1.0, 10.0] {
let x_scaled = Tensor::from_vec(x.data().iter().map(|&v| v * alpha).collect(), true);
let y_scaled = norm.forward(&x_scaled);
let sign = alpha.signum();
for (i, (&ys, &yb)) in y_scaled.data().iter().zip(y_base.data().iter()).enumerate() {
let expected = sign * yb;
let diff = (ys - expected).abs();
assert!(
diff < 1e-3,
"FALSIFIED RN-002: RMSNorm({alpha}·x)[{i}] = {ys}, expected {expected}"
);
}
}
}
#[test]
fn falsify_rn_004_zero_vector() {
let norm = RMSNorm::new(4, 1e-6);
let x = Tensor::from_vec(vec![0.0, 0.0, 0.0, 0.0], true);
let y = norm.forward(&x);
for (i, &val) in y.data().iter().enumerate() {
assert!(val.is_finite(), "FALSIFIED RN-004: RMSNorm(0)[{i}] = {val} (expected finite)");
}
}
#[test]
fn falsify_rn_005_unit_gamma_normalized_rms() {
let norm = RMSNorm::new(8, 1e-6);
let x = Tensor::from_vec(vec![1.0, -2.0, 3.0, -0.5, 4.0, -1.0, 2.5, -3.0], true);
let y = norm.forward(&x);
let y_data = y.data();
let rms_out: f32 =
(y_data.iter().map(|&v| v * v).sum::<f32>() / y_data.len() as f32).sqrt();
assert!(
(rms_out - 1.0).abs() < 0.01,
"FALSIFIED RN-005: RMS(RMSNorm(x)) = {rms_out}, expected ≈ 1.0"
);
}
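// Property-based variants of the RN-001/RN-002/RN-005 checks above, sampled
// over several hidden sizes and input scales.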
mod rn_proptest_falsify {
use super::*;
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(200))]
#[test]
fn falsify_rn_001_prop_finiteness(
dim in prop::sample::select(vec![4_usize, 8, 16, 32, 64]),
scale in 0.001_f32..1000.0,
) {
let norm = RMSNorm::new(dim, 1e-6);
let data: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.13 * scale).sin()).collect();
let x = Tensor::from_vec(data, true);
let y = norm.forward(&x);
for (i, &val) in y.data().iter().enumerate() {
prop_assert!(
val.is_finite(),
"FALSIFIED RN-001-prop: output[{}]={} not finite (d={}, scale={})",
i, val, dim, scale
);
}
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn falsify_rn_002_prop_scale_invariance(
dim in prop::sample::select(vec![4_usize, 8, 16, 32]),
alpha in prop::sample::select(vec![-10.0_f32, -1.0, 0.5, 2.0, 100.0]),
) {
let norm = RMSNorm::new(dim, 1e-6);
let data: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.37).sin() * 5.0).collect();
let x = Tensor::from_vec(data.clone(), true);
let y_base = norm.forward(&x);
let x_scaled = Tensor::from_vec(
data.iter().map(|&v| v * alpha).collect(),
true,
);
let y_scaled = norm.forward(&x_scaled);
let sign = alpha.signum();
for (i, (&ys, &yb)) in y_scaled.data().iter().zip(y_base.data().iter()).enumerate() {
let expected = sign * yb;
prop_assert!(
(ys - expected).abs() < 1e-3,
"FALSIFIED RN-002-prop: [{i}] got {ys}, expected {expected} (alpha={alpha}, d={dim})"
);
}
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn falsify_rn_005_prop_unit_gamma_rms(
dim in prop::sample::select(vec![8_usize, 16, 32, 64]),
) {
let norm = RMSNorm::new(dim, 1e-6);
let data: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.23).sin() * 10.0).collect();
let x = Tensor::from_vec(data, true);
let y = norm.forward(&x);
let y_data = y.data();
let rms_out: f32 = (y_data.iter().map(|&v| v * v).sum::<f32>() / y_data.len() as f32).sqrt();
prop_assert!(
(rms_out - 1.0).abs() < 0.05,
"FALSIFIED RN-005-prop: RMS(output)={} != 1.0 (d={})",
rms_out, dim
);
}
}
}
}