use super::init::{constant, zeros};
use super::module::Module;
use crate::autograd::Tensor;
/// Layer normalization over the trailing `normalized_shape` dimensions of the input.
///
/// When `elementwise_affine` is set, `forward` delegates to
/// `crate::nn::functional::layer_norm` with the stored `weight`, `bias` and `eps`.
/// Otherwise each trailing slice is shifted to zero mean and divided by
/// `sqrt(biased_variance + eps)` directly.
#[derive(Debug)]
pub struct LayerNorm {
    /// Trailing input dimensions over which the statistics are computed.
    normalized_shape: Vec<usize>,
    /// Added to the variance before the square root to avoid division by zero.
    eps: f32,
    /// Scale parameter, stored flat with `normalized_shape.iter().product()` elements.
    weight: Tensor,
    /// Shift parameter, stored flat with the same length as `weight`.
    bias: Tensor,
    /// When `false`, `weight`/`bias` are neither applied nor reported as parameters.
    elementwise_affine: bool,
}
impl LayerNorm {
    /// Builds a layer norm with learnable affine parameters.
    ///
    /// `weight` starts at all ones and `bias` at all zeros. Both have
    /// `normalized_shape.iter().product()` elements and require gradients. `eps` is `1e-5`.
    #[must_use]
    pub fn new(normalized_shape: &[usize]) -> Self {
        let flat_len = normalized_shape.iter().product::<usize>();
        let weight = constant(&[flat_len], 1.0).requires_grad();
        let bias = zeros(&[flat_len]).requires_grad();
        Self {
            normalized_shape: Vec::from(normalized_shape),
            eps: 1e-5,
            weight,
            bias,
            elementwise_affine: true,
        }
    }

    /// Same as [`LayerNorm::new`] but with a caller-chosen `eps`.
    #[must_use]
    pub fn with_eps(normalized_shape: &[usize], eps: f32) -> Self {
        Self {
            eps,
            ..Self::new(normalized_shape)
        }
    }

    /// Builds a layer norm that only normalizes, without learnable scale or shift.
    #[must_use]
    pub fn without_affine(normalized_shape: &[usize]) -> Self {
        let flat_len = normalized_shape.iter().product::<usize>();
        Self {
            normalized_shape: Vec::from(normalized_shape),
            eps: 1e-5,
            // Placeholders: with `elementwise_affine == false` these are never applied
            // and are not returned from `parameters()`.
            weight: constant(&[flat_len], 1.0),
            bias: zeros(&[flat_len]),
            elementwise_affine: false,
        }
    }

    /// The trailing dimensions this layer normalizes over.
    #[must_use]
    pub fn normalized_shape(&self) -> &[usize] {
        self.normalized_shape.as_slice()
    }
}
impl Module for LayerNorm {
    /// Normalizes `input` over its trailing `normalized_shape` dimensions.
    ///
    /// With `elementwise_affine`, this delegates to `crate::nn::functional::layer_norm`
    /// using the stored `weight`, `bias` and `eps`.
    ///
    /// Otherwise each contiguous trailing slice of `normalized_shape.iter().product()`
    /// elements is shifted to zero mean and divided by `sqrt(biased_variance + eps)`.
    /// The result is a fresh tensor built with `Tensor::new`.
    ///
    /// # Panics
    ///
    /// Panics if `input` has fewer dimensions than `normalized_shape`. Also panics if any
    /// trailing dimension of `input` differs from the matching entry of `normalized_shape`.
    #[provable_contracts_macros::contract("layernorm-kernel-v1", equation = "layernorm")]
    fn forward(&self, input: &Tensor) -> Tensor {
        let shape = input.shape();
        let norm_size: usize = self.normalized_shape.iter().product();
        assert!(
            shape.len() >= self.normalized_shape.len(),
            "Input must have at least as many dimensions as normalized_shape"
        );
        let start_dim = shape.len() - self.normalized_shape.len();
        for (i, &ns) in self.normalized_shape.iter().enumerate() {
            assert_eq!(
                shape[start_dim + i],
                ns,
                "Input shape doesn't match normalized_shape at dim {i}"
            );
        }
        if self.elementwise_affine {
            return crate::nn::functional::layer_norm(input, &self.weight, &self.bias, self.eps);
        }

        let batch_dims: usize = shape[..start_dim].iter().product();
        let input_data = input.data();
        let mut output_data = vec![0.0; input_data.len()];
        // `chunks_exact` panics on a zero chunk size. When a normalized dimension is 0
        // there is nothing to normalize, and the zero-filled buffer is the result
        // (matching the previous index-loop behaviour).
        if norm_size > 0 {
            let n = norm_size as f32;
            let rows = input_data
                .chunks_exact(norm_size)
                .zip(output_data.chunks_exact_mut(norm_size))
                .take(batch_dims);
            for (src, dst) in rows {
                let mean: f32 = src.iter().sum::<f32>() / n;
                // Biased variance (divide by N), as in the original implementation.
                let var: f32 = src.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / n;
                let std_inv = 1.0 / (var + self.eps).sqrt();
                for (out, &x) in dst.iter_mut().zip(src) {
                    *out = (x - mean) * std_inv;
                }
            }
        }
        Tensor::new(&output_data, shape)
    }

    /// Returns `[weight, bias]` when affine, otherwise no parameters.
    fn parameters(&self) -> Vec<&Tensor> {
        if self.elementwise_affine {
            vec![&self.weight, &self.bias]
        } else {
            vec![]
        }
    }

    /// Mutable counterpart of [`Module::parameters`].
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        if self.elementwise_affine {
            vec![&mut self.weight, &mut self.bias]
        } else {
            vec![]
        }
    }
}
/// Batch normalization over channel dimension 1 of a 2-D `[batch, features]` or
/// 3-D `[batch, features, length]` input.
///
/// NOTE(review): `forward` takes `&self` and never writes `running_mean`/`running_var`
/// or reads `momentum`. Eval mode therefore uses whatever values the running stats hold
/// (zeros / ones after `new`). Confirm whether another code path updates them.
#[derive(Debug)]
pub struct BatchNorm1d {
    /// Expected size of input dimension 1; checked in `forward`.
    num_features: usize,
    /// Added to the variance before the square root to avoid division by zero.
    eps: f32,
    /// Set by `new` (0.1) or `with_momentum`; not read by `forward` in this file.
    momentum: f32,
    /// Per-channel scale, `[num_features]`, requires grad.
    weight: Tensor,
    /// Per-channel shift, `[num_features]`, requires grad.
    bias: Tensor,
    /// Per-channel mean used in eval mode.
    running_mean: Tensor,
    /// Per-channel variance used in eval mode.
    running_var: Tensor,
    /// `true`: use batch statistics; `false`: use the running statistics.
    training: bool,
}
impl BatchNorm1d {
    /// Builds a batch norm over `num_features` channels.
    ///
    /// The initial state is:
    /// - `weight` is all ones and `bias` all zeros, both learnable.
    /// - The running mean is 0 and the running variance is 1.
    /// - `eps` is `1e-5` and `momentum` is `0.1`.
    /// - The layer starts in training mode.
    #[must_use]
    pub fn new(num_features: usize) -> Self {
        let dims = [num_features];
        Self {
            num_features,
            eps: 1e-5,
            momentum: 0.1,
            weight: constant(&dims, 1.0).requires_grad(),
            bias: zeros(&dims).requires_grad(),
            running_mean: zeros(&dims),
            running_var: constant(&dims, 1.0),
            training: true,
        }
    }

    /// Builder: replaces the stored `momentum`.
    #[must_use]
    pub fn with_momentum(self, momentum: f32) -> Self {
        Self { momentum, ..self }
    }

    /// Builder: replaces the numerical-stability constant `eps`.
    #[must_use]
    pub fn with_eps(self, eps: f32) -> Self {
        Self { eps, ..self }
    }
}
impl BatchNorm1d {
    /// Flat, row-major indices of every element that belongs to channel `feature`.
    ///
    /// A 2-D `shape` (`[batch, features]`) yields one index per batch row. Any other
    /// rank is treated as `[batch, features, length]` and yields every position along
    /// `length`, batch by batch.
    fn feature_indices(shape: &[usize], feature: usize) -> Vec<usize> {
        let batch_size = shape[0];
        let features = shape[1];
        match shape.len() {
            2 => (0..batch_size).map(|b| b * features + feature).collect(),
            _ => {
                let length = shape[2];
                (0..batch_size)
                    .flat_map(|b| {
                        let start = b * features * length + feature * length;
                        start..start + length
                    })
                    .collect()
            }
        }
    }

    /// Writes `(x - mean) * std_inv * gamma + beta` into `output_data` at each of `indices`.
    fn normalize_feature(
        input_data: &[f32],
        output_data: &mut [f32],
        indices: &[usize],
        mean: f32,
        std_inv: f32,
        gamma: f32,
        beta: f32,
    ) {
        indices.iter().for_each(|&idx| {
            output_data[idx] = (input_data[idx] - mean) * std_inv * gamma + beta;
        });
    }
}
impl Module for BatchNorm1d {
    /// Normalizes each channel (dimension 1) of a 2-D `[batch, features]` or
    /// 3-D `[batch, features, length]` input.
    ///
    /// In training mode, the per-channel mean and biased (divide-by-N) variance come from
    /// `input`. In eval mode, the stored `running_mean` / `running_var` are used. Every
    /// element becomes `(x - mean) / sqrt(var + eps) * weight[f] + bias[f]`. The result is
    /// a fresh tensor built with `Tensor::new`.
    ///
    /// NOTE(review): this method takes `&self`, so it never updates the running statistics
    /// and never reads `momentum`. Confirm whether that is intended.
    ///
    /// # Panics
    ///
    /// Panics if `input` is not 2-D or 3-D, or if `input.shape()[1] != num_features`.
    #[provable_contracts_macros::contract("batchnorm-kernel-v1", equation = "batchnorm_train")]
    fn forward(&self, input: &Tensor) -> Tensor {
        assert!(
            input.ndim() == 2 || input.ndim() == 3,
            "BatchNorm1d expects 2D or 3D input, got {}D",
            input.ndim()
        );
        let shape = input.shape();
        let features = shape[1];
        assert_eq!(
            features, self.num_features,
            "Expected {} features, got {}",
            self.num_features, features
        );
        let input_data = input.data();
        // These buffers do not change across channels; fetch them once rather than
        // calling `.data()` again on every loop iteration.
        let gamma = self.weight.data();
        let beta = self.bias.data();
        let running_mean = self.running_mean.data();
        let running_var = self.running_var.data();
        let mut output_data = vec![0.0; input_data.len()];
        for f in 0..features {
            let indices = Self::feature_indices(shape, f);
            let (mean, var) = if self.training {
                let count = indices.len() as f32;
                let mean = indices.iter().map(|&i| input_data[i]).sum::<f32>() / count;
                let var = indices
                    .iter()
                    .map(|&i| (input_data[i] - mean).powi(2))
                    .sum::<f32>()
                    / count;
                (mean, var)
            } else {
                (running_mean[f], running_var[f])
            };
            let std_inv = 1.0 / (var + self.eps).sqrt();
            Self::normalize_feature(
                input_data,
                &mut output_data,
                &indices,
                mean,
                std_inv,
                gamma[f],
                beta[f],
            );
        }
        Tensor::new(&output_data, shape)
    }

    /// Returns the learnable `[weight, bias]`.
    fn parameters(&self) -> Vec<&Tensor> {
        vec![&self.weight, &self.bias]
    }

    /// Mutable counterpart of [`Module::parameters`].
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight, &mut self.bias]
    }

    /// Switches to batch statistics.
    fn train(&mut self) {
        self.training = true;
    }

    /// Switches to the stored running statistics.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Whether the layer is in training mode.
    fn training(&self) -> bool {
        self.training
    }
}
/// Group normalization layer.
///
/// NOTE(review): the implementation is not visible here; it presumably lives in the
/// `group_norm` submodule declared below. The field meanings below are inferred from
/// their names and should be confirmed there.
#[derive(Debug)]
pub struct GroupNorm {
    /// Presumably the number of groups the channels are split into — confirm.
    num_groups: usize,
    /// Presumably the expected channel count of the input — confirm.
    num_channels: usize,
    /// Presumably the variance stabilizer, as in the other norm layers — confirm.
    eps: f32,
    /// Presumably the per-channel scale — confirm.
    weight: Tensor,
    /// Presumably the per-channel shift — confirm.
    bias: Tensor,
    /// Presumably toggles use of `weight`/`bias` — confirm.
    affine: bool,
}
mod group_norm;
pub use group_norm::*;