use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorOps, TensorBase};
#[derive(Debug, Clone)]
pub struct LayerNorm {
pub weight: DenseTensor,
pub bias: DenseTensor,
pub epsilon: f64,
pub hidden_dim: usize,
}
impl LayerNorm {
pub fn new(weight: DenseTensor, bias: DenseTensor, epsilon: f64) -> Self {
let hidden_dim = weight.shape()[0];
Self {
weight,
bias,
epsilon,
hidden_dim,
}
}
pub fn default(hidden_dim: usize) -> Self {
let weight = DenseTensor::ones(vec![hidden_dim]);
let bias = DenseTensor::zeros(vec![hidden_dim]);
Self::new(weight, bias, 1e-5)
}
pub fn forward(&self, x: &DenseTensor) -> DenseTensor {
let ndim = x.ndim();
if ndim == 3 {
let batch_size = x.shape()[0];
let seq_len = x.shape()[1];
let hidden_dim = x.shape()[2];
let mean = x.mean_dim(-1);
let mean_expanded = mean.expand_last_dim(hidden_dim);
let var = x.var_dim(-1);
let var_expanded = var.expand_last_dim(hidden_dim);
let eps_tensor = DenseTensor::full(&[batch_size, seq_len, hidden_dim], self.epsilon);
let var_with_eps = var_expanded.add(&eps_tensor);
let std = var_with_eps.sqrt();
let centered = x.sub(&mean_expanded);
let normalized = centered.div(&std);
let weight_expanded = self.weight.expand_to_3d(batch_size, seq_len);
let bias_expanded = self.bias.expand_to_3d(batch_size, seq_len);
normalized.mul(&weight_expanded).add(&bias_expanded)
} else if ndim == 2 {
let seq_len = x.shape()[0];
let hidden_dim = x.shape()[1];
let mean = x.mean_dim(-1);
let mean_expanded = mean.expand_last_dim_2d(hidden_dim);
let var = x.var_dim(-1);
let var_expanded = var.expand_last_dim_2d(hidden_dim);
let eps_tensor = DenseTensor::full(&[seq_len, hidden_dim], self.epsilon);
let var_with_eps = var_expanded.add(&eps_tensor);
let std = var_with_eps.sqrt();
let centered = x.sub(&mean_expanded);
let normalized = centered.div(&std);
let weight_expanded = self.weight.expand_to_2d(seq_len);
let bias_expanded = self.bias.expand_to_2d(seq_len);
normalized.mul(&weight_expanded).add(&bias_expanded)
} else {
panic!("LayerNorm only supports 2D or 3D tensors");
}
}
pub fn hidden_dim(&self) -> usize {
self.hidden_dim
}
}
#[derive(Debug, Clone)]
pub struct RMSNorm {
pub weight: DenseTensor,
pub epsilon: f64,
pub hidden_dim: usize,
}
impl RMSNorm {
pub fn new(weight: DenseTensor, epsilon: f64) -> Self {
let hidden_dim = weight.shape()[0];
Self {
weight,
epsilon,
hidden_dim,
}
}
pub fn default(hidden_dim: usize) -> Self {
let weight = DenseTensor::ones(vec![hidden_dim]);
Self::new(weight, 1e-6)
}
pub fn forward(&self, x: &DenseTensor) -> DenseTensor {
let ndim = x.ndim();
if ndim == 3 {
let batch_size = x.shape()[0];
let seq_len = x.shape()[1];
let hidden_dim = x.shape()[2];
let x_squared = x.mul(x);
let mean_square = x_squared.mean_dim(-1);
let mean_square_expanded = mean_square.expand_last_dim(hidden_dim);
let eps_tensor = DenseTensor::full(&[batch_size, seq_len, hidden_dim], self.epsilon);
let rms_input = mean_square_expanded.add(&eps_tensor);
let rms = rms_input.sqrt();
let normalized = x.div(&rms);
let weight_expanded = self.weight.expand_to_3d(batch_size, seq_len);
normalized.mul(&weight_expanded)
} else if ndim == 2 {
let seq_len = x.shape()[0];
let hidden_dim = x.shape()[1];
let x_squared = x.mul(x);
let mean_square = x_squared.mean_dim(-1);
let mean_square_expanded = mean_square.expand_last_dim_2d(hidden_dim);
let eps_tensor = DenseTensor::full(&[seq_len, hidden_dim], self.epsilon);
let rms_input = mean_square_expanded.add(&eps_tensor);
let rms = rms_input.sqrt();
let normalized = x.div(&rms);
let weight_expanded = self.weight.expand_to_2d(seq_len);
normalized.mul(&weight_expanded)
} else {
panic!("RMSNorm only supports 2D or 3D tensors");
}
}
pub fn hidden_dim(&self) -> usize {
self.hidden_dim
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_layer_norm() {
let hidden_dim = 8;
let layer_norm = LayerNorm::default(hidden_dim);
let batch_size = 2;
let seq_len = 4;
let x = DenseTensor::new(
(0..batch_size * seq_len * hidden_dim).map(|i| (i % 10) as f64).collect(),
vec![batch_size, seq_len, hidden_dim],
);
let output = layer_norm.forward(&x);
assert_eq!(output.shape(), &[batch_size, seq_len, hidden_dim]);
let mean = output.mean_dim(-1);
let var = output.var_dim(-1);
for i in 0..mean.shape()[0] {
for j in 0..mean.shape()[1] {
let m = mean.data()[i * mean.shape()[1] + j];
assert!(m.abs() < 1e-5, "Mean should be close to 0, got {}", m);
}
}
for i in 0..var.shape()[0] {
for j in 0..var.shape()[1] {
let v = var.data()[i * var.shape()[1] + j];
assert!((v - 1.0).abs() < 0.1, "Variance should be close to 1, got {}", v);
}
}
}
#[test]
fn test_rms_norm() {
let hidden_dim = 8;
let rms_norm = RMSNorm::default(hidden_dim);
let batch_size = 2;
let seq_len = 4;
let x = DenseTensor::new(
(0..batch_size * seq_len * hidden_dim).map(|i| (i % 10) as f64).collect(),
vec![batch_size, seq_len, hidden_dim],
);
let output = rms_norm.forward(&x);
assert_eq!(output.shape(), &[batch_size, seq_len, hidden_dim]);
let output_squared = output.mul(&output);
let mean_square = output_squared.mean_dim(-1);
for i in 0..mean_square.shape()[0] {
for j in 0..mean_square.shape()[1] {
let ms = mean_square.data()[i * mean_square.shape()[1] + j];
assert!((ms - 1.0).abs() < 0.1, "RMS^2 should be close to 1, got {}", ms);
}
}
}
#[test]
fn test_layer_norm_with_custom_params() {
let hidden_dim = 4;
let weight = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![hidden_dim]);
let bias = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![hidden_dim]);
let layer_norm = LayerNorm::new(weight, bias, 1e-5);
let x = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, hidden_dim]);
let output = layer_norm.forward(&x);
assert_eq!(output.shape(), &[1, hidden_dim]);
}
}