use crate::error::Result;
use crate::tensor::Tensor;
/// Layer normalization: normalizes an input to zero mean / unit variance
/// over its feature dimension, then applies a learned per-feature affine
/// transform (`gain` * x + `bias`).
#[derive(Debug, Clone)]
pub struct LayerNorm {
/// Per-feature multiplicative parameter (gamma); created with length `feature_dim`.
pub gain: Tensor,
/// Per-feature additive parameter (beta); created with length `feature_dim`.
pub bias: Tensor,
/// Added to the variance before the square root for numerical stability.
pub eps: f64,
/// Number of features in the normalized dimension.
pub feature_dim: usize,
}
impl LayerNorm {
    /// Builds a layer norm over `feature_dim` features with gain = 1,
    /// bias = 0 and the default epsilon of `1e-5`.
    pub fn new(feature_dim: usize) -> Self {
        let gain = Tensor::ones(&[feature_dim]);
        let bias = Tensor::zeros(&[feature_dim]);
        Self {
            gain,
            bias,
            eps: 1e-5,
            feature_dim,
        }
    }

    /// Same as [`LayerNorm::new`], but with a caller-supplied epsilon.
    pub fn with_eps(feature_dim: usize, eps: f64) -> Self {
        let mut layer = Self::new(feature_dim);
        layer.eps = eps;
        layer
    }

    /// Normalizes a single 1-D tensor: subtract the mean, divide by
    /// sqrt(variance + eps), then apply the elementwise affine transform.
    ///
    /// # Errors
    /// Propagates any error from the underlying `Tensor` operations
    /// (e.g. a shape mismatch in `mul`/`add`).
    pub fn forward_1d(&self, x: &Tensor) -> Result<Tensor> {
        let mu = x.mean()?;
        let var = x.variance()?;
        // Single reciprocal std-dev so the per-element closure is a
        // multiply, not a divide. `recip()` is exactly `1.0 / v`.
        let inv_std = (var + self.eps).sqrt().recip();
        let centered = x.map(|v| (v - mu) * inv_std);
        centered.mul(&self.gain)?.add(&self.bias)
    }

    /// Normalizes `x`: rank-1 inputs are handled directly; otherwise each
    /// row is normalized independently and the rows are re-assembled.
    ///
    /// NOTE(review): for rank > 2 only `shape[0]`/`shape[1]` are consulted;
    /// this assumes `x` is at most 2-D — TODO confirm with callers.
    ///
    /// # Errors
    /// Propagates any error from row extraction, normalization, or
    /// construction of the result tensor.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let dims = x.shape();
        if dims.len() == 1 {
            self.forward_1d(x)
        } else {
            let (rows, cols) = (dims[0], dims[1]);
            let mut flat = Vec::with_capacity(rows * cols);
            for i in 0..rows {
                let normed = self.forward_1d(&x.row(i)?)?;
                flat.extend_from_slice(normed.data());
            }
            Tensor::new(flat, vec![rows, cols])
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A normalized 1-D vector should have (near-)zero mean.
    #[test]
    fn test_layer_norm_1d() {
        let layer = LayerNorm::new(4);
        let input = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let output = layer.forward_1d(&input).unwrap();
        let mean = output.mean().unwrap();
        assert!(mean.abs() < 1e-6, "mean={mean}");
    }

    /// Row-wise normalization must preserve the 2-D shape.
    #[test]
    fn test_layer_norm_2d() {
        let layer = LayerNorm::new(4);
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let input = Tensor::new(data, vec![2, 4]).unwrap();
        let output = layer.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 4]);
    }
}