use crate::autograd::Variable;
use crate::nn::Module;
use crate::tensor::Tensor;
use ndarray::ScalarOperand;
use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
use std::fmt::Debug;
use std::iter::Sum;
/// Layer normalization over the trailing `normalized_shape` dimensions of an input.
///
/// Each slice spanning the trailing dimensions is shifted by its mean and divided by
/// `sqrt(variance + eps)`; when `elementwise_affine` is set the result is additionally
/// scaled by `weight` and shifted by `bias`.
#[derive(Debug)]
pub struct LayerNorm<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Scale parameter with shape `normalized_shape`, initialised to ones.
    weight: Variable<T>,
    /// Shift parameter with shape `normalized_shape`, initialised to zeros.
    bias: Variable<T>,
    /// Shape of the trailing input dimensions that are normalized together.
    normalized_shape: Vec<usize>,
    /// Constant added to the variance before taking the square root.
    eps: T,
    /// Whether `weight` and `bias` are applied to the normalized values.
    elementwise_affine: bool,
}
impl<T> LayerNorm<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ndarray::ScalarOperand,
{
    /// Creates a layer-normalization module.
    ///
    /// * `normalized_shape` - shape of the trailing input dimensions to normalize over.
    /// * `eps` - value added to the variance for numerical stability; defaults to `1e-5`.
    /// * `elementwise_affine` - whether the scale/shift parameters are applied;
    ///   defaults to `true`.
    ///
    /// The weight is initialised to ones and the bias to zeros, both with shape
    /// `normalized_shape`. Their `requires_grad` flag mirrors `elementwise_affine`.
    ///
    /// # Panics
    ///
    /// Panics if `normalized_shape` is empty, or if the default epsilon cannot be
    /// represented in `T`.
    pub fn new(
        normalized_shape: Vec<usize>,
        eps: Option<T>,
        elementwise_affine: Option<bool>,
    ) -> Self {
        assert!(
            !normalized_shape.is_empty(),
            "normalized_shape cannot be empty"
        );
        let eps = eps.unwrap_or_else(|| T::from_f32(1e-5).unwrap());
        let elementwise_affine = elementwise_affine.unwrap_or(true);
        let num_features: usize = normalized_shape.iter().product();
        let weight = Variable::new(
            Tensor::from_vec(vec![T::one(); num_features], normalized_shape.clone()),
            elementwise_affine,
        );
        let bias = Variable::new(
            Tensor::from_vec(vec![T::zero(); num_features], normalized_shape.clone()),
            elementwise_affine,
        );
        LayerNorm {
            weight,
            bias,
            normalized_shape,
            eps,
            elementwise_affine,
        }
    }

    /// Normalizes `input` over its trailing `normalized_shape` dimensions.
    ///
    /// The result is wrapped in a new `Variable` whose `requires_grad` flag is set when
    /// the input requires gradients, or when the affine parameters are in use and
    /// either of them requires gradients.
    ///
    /// # Panics
    ///
    /// Panics if the trailing dimensions of `input` do not equal `normalized_shape`,
    /// or if a tensor lock is poisoned.
    pub fn forward(&self, input: &Variable<T>) -> Variable<T> {
        let input_binding = input.data();
        let input_data = input_binding.read().unwrap();
        let input_shape = input_data.shape();
        self.verify_input_shape(input_shape);
        let normalized_data = self.layer_normalize(&input_data);
        let requires_grad = input.requires_grad()
            || (self.elementwise_affine
                && (self.weight.requires_grad() || self.bias.requires_grad()));
        Variable::new(normalized_data, requires_grad)
    }

    /// Panics unless the last `normalized_shape.len()` dimensions of `input_shape`
    /// are exactly `normalized_shape`.
    fn verify_input_shape(&self, input_shape: &[usize]) {
        let norm_dims = self.normalized_shape.len();
        let input_dims = input_shape.len();
        if input_dims < norm_dims {
            panic!(
                "Input has {} dimensions but normalized_shape has {} dimensions",
                input_dims, norm_dims
            );
        }
        let input_suffix = &input_shape[input_dims - norm_dims..];
        if input_suffix != self.normalized_shape.as_slice() {
            panic!(
                "Input shape suffix {:?} doesn't match normalized_shape {:?}",
                input_suffix, self.normalized_shape
            );
        }
    }

    /// Computes the normalized tensor for an input whose shape was already verified.
    ///
    /// The leading (non-normalized) dimensions are flattened into independent rows of
    /// `feature_size` elements; each row is normalized with its own mean and variance.
    fn layer_normalize(&self, input: &Tensor<T>) -> Tensor<T> {
        let input_array = input.as_array();
        let input_shape = input.shape();
        let batch_dims = input_shape.len() - self.normalized_shape.len();
        let batch_size: usize = input_shape[..batch_dims].iter().product();
        let feature_size: usize = self.normalized_shape.iter().product();
        let total = batch_size * feature_size;

        // Obtain the elements in logical (row-major) order exactly once: borrow the
        // backing slice when the array exposes one, otherwise gather by index.
        let gathered: Vec<T>;
        let flat: &[T] = match input_array.as_slice() {
            Some(contiguous) => contiguous,
            None => {
                gathered = (0..total)
                    .map(|linear_idx| {
                        let indices = self.unravel_index(linear_idx, input_shape);
                        input_array[indices.as_slice()]
                    })
                    .collect();
                gathered.as_slice()
            }
        };

        // The affine parameters do not change during the call, so take the read locks
        // and flatten them once instead of once per output element.
        let affine: Option<(Vec<T>, Vec<T>)> = if self.elementwise_affine {
            let weight_binding = self.weight.data();
            let weight_data = weight_binding.read().unwrap();
            let bias_binding = self.bias.data();
            let bias_data = bias_binding.read().unwrap();
            let weight_array = weight_data.as_array();
            let bias_array = bias_data.as_array();
            let mut gammas = Vec::with_capacity(feature_size);
            let mut betas = Vec::with_capacity(feature_size);
            for feat_idx in 0..feature_size {
                let indices = self.unravel_index(feat_idx, &self.normalized_shape);
                gammas.push(weight_array[indices.as_slice()]);
                betas.push(bias_array[indices.as_slice()]);
            }
            Some((gammas, betas))
        } else {
            None
        };

        let mut output_data = Vec::with_capacity(total);
        for batch_idx in 0..batch_size {
            let start = batch_idx * feature_size;
            let features = &flat[start..start + feature_size];
            let mean = self.calculate_mean(features);
            let variance = self.calculate_variance(features, mean);
            let std = (variance + self.eps).sqrt();
            match &affine {
                Some((gammas, betas)) => output_data.extend(
                    features
                        .iter()
                        .zip(gammas.iter().zip(betas.iter()))
                        .map(|(&value, (&gamma, &beta))| gamma * ((value - mean) / std) + beta),
                ),
                None => output_data.extend(features.iter().map(|&value| (value - mean) / std)),
            }
        }
        Tensor::from_vec(output_data, input_shape.to_vec())
    }

    /// Arithmetic mean of `features` (sum divided by the element count).
    fn calculate_mean(&self, features: &[T]) -> T {
        let sum: T = features.iter().fold(T::zero(), |acc, &x| acc + x);
        sum / T::from_usize(features.len()).unwrap()
    }

    /// Mean of squared deviations from `mean` (divides by `n`, not `n - 1`).
    fn calculate_variance(&self, features: &[T], mean: T) -> T {
        let sum_sq_diff: T = features
            .iter()
            .fold(T::zero(), |acc, &x| acc + (x - mean).powi(2));
        sum_sq_diff / T::from_usize(features.len()).unwrap()
    }

    /// Converts a row-major linear `index` into per-dimension indices for `shape`.
    fn unravel_index(&self, mut index: usize, shape: &[usize]) -> Vec<usize> {
        let mut indices = vec![0; shape.len()];
        // Walk from the fastest-varying (last) dimension to the slowest.
        for i in (0..shape.len()).rev() {
            indices[i] = index % shape[i];
            index /= shape[i];
        }
        indices
    }

    /// Shape of the trailing dimensions this layer normalizes over.
    pub fn normalized_shape(&self) -> &[usize] {
        &self.normalized_shape
    }

    /// Constant added to the variance before the square root.
    pub fn eps(&self) -> T {
        self.eps
    }

    /// Whether the scale/shift parameters are applied.
    pub fn elementwise_affine(&self) -> bool {
        self.elementwise_affine
    }

    /// Returns clones of `[weight, bias]` when the affine transform is enabled,
    /// otherwise an empty vector.
    pub fn parameters(&self) -> Vec<Variable<T>> {
        if self.elementwise_affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }
}
/// `Module` implementation that forwards to the inherent `LayerNorm` methods.
impl<T> Module<T> for LayerNorm<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ndarray::ScalarOperand,
{
    /// Delegates to the inherent `LayerNorm::forward`.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        // Inherent associated functions take precedence over trait methods in path
        // lookup, so this calls the inherent `forward`, not this trait method.
        Self::forward(self, input)
    }

    /// Delegates to the inherent `LayerNorm::parameters`.
    fn parameters(&self) -> Vec<Variable<T>> {
        Self::parameters(self)
    }

    /// Exposes the layer as `&dyn Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Group normalization for inputs laid out as `(N, C, ...)`.
///
/// The `num_channels` channels are split into `num_groups` equally sized groups; the
/// values of each group (per batch entry) are normalized together, and when `affine`
/// is set each channel is then scaled by `weight` and shifted by `bias`.
#[derive(Debug)]
pub struct GroupNorm<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Per-channel scale with shape `[num_channels]`, initialised to ones.
    weight: Variable<T>,
    /// Per-channel shift with shape `[num_channels]`, initialised to zeros.
    bias: Variable<T>,
    /// Number of groups the channels are divided into.
    num_groups: usize,
    /// Expected size of the input's channel dimension (dimension 1).
    num_channels: usize,
    /// Constant added to the variance before taking the square root.
    eps: T,
    /// Whether `weight` and `bias` are applied to the normalized values.
    affine: bool,
}
impl<T> GroupNorm<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ndarray::ScalarOperand,
{
    /// Creates a group-normalization module.
    ///
    /// * `num_groups` - number of groups the channels are split into.
    /// * `num_channels` - expected size of the input's channel dimension.
    /// * `eps` - value added to the variance for numerical stability; defaults to `1e-5`.
    /// * `affine` - whether per-channel scale/shift parameters are applied;
    ///   defaults to `true`.
    ///
    /// # Panics
    ///
    /// Panics if `num_groups` or `num_channels` is zero, if `num_channels` is not
    /// divisible by `num_groups`, or if the default epsilon cannot be represented in `T`.
    pub fn new(
        num_groups: usize,
        num_channels: usize,
        eps: Option<T>,
        affine: Option<bool>,
    ) -> Self {
        assert!(num_groups > 0, "num_groups must be greater than 0");
        assert!(num_channels > 0, "num_channels must be greater than 0");
        assert!(
            num_channels % num_groups == 0,
            "num_channels ({}) must be divisible by num_groups ({})",
            num_channels,
            num_groups
        );
        let eps = eps.unwrap_or_else(|| T::from_f32(1e-5).unwrap());
        let affine = affine.unwrap_or(true);
        let weight = Variable::new(
            Tensor::from_vec(vec![T::one(); num_channels], vec![num_channels]),
            affine,
        );
        let bias = Variable::new(
            Tensor::from_vec(vec![T::zero(); num_channels], vec![num_channels]),
            affine,
        );
        GroupNorm {
            weight,
            bias,
            num_groups,
            num_channels,
            eps,
            affine,
        }
    }

    /// Applies group normalization to an `(N, C, ...)` input.
    ///
    /// The result is wrapped in a new `Variable` whose `requires_grad` flag is set when
    /// the input requires gradients, or when the affine parameters are in use and
    /// either of them requires gradients.
    ///
    /// # Panics
    ///
    /// Panics if the input has fewer than three dimensions, if its channel dimension
    /// differs from `num_channels`, or if a tensor lock is poisoned.
    pub fn forward(&self, input: &Variable<T>) -> Variable<T> {
        let input_binding = input.data();
        let input_data = input_binding.read().unwrap();
        let input_shape = input_data.shape();
        if input_shape.len() < 3 {
            panic!(
                "GroupNorm expects at least 3D input (N, C, ...), got {:?}",
                input_shape
            );
        }
        let channels = input_shape[1];
        if channels != self.num_channels {
            panic!(
                "Input channels {} doesn't match layer channels {}",
                channels, self.num_channels
            );
        }
        let normalized_data = self.group_normalize(&input_data);
        let requires_grad = input.requires_grad()
            || (self.affine && (self.weight.requires_grad() || self.bias.requires_grad()));
        Variable::new(normalized_data, requires_grad)
    }

    /// Computes the normalized tensor for an input already validated by `forward`.
    ///
    /// For every batch entry and every group, the mean and variance are taken over
    /// all channels of the group and all spatial positions.
    fn group_normalize(&self, input: &Tensor<T>) -> Tensor<T> {
        let input_array = input.as_array();
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let channels = input_shape[1];
        let spatial_shape = &input_shape[2..];
        let spatial_size: usize = spatial_shape.iter().product();
        let channels_per_group = channels / self.num_groups;
        let group_size = channels_per_group * spatial_size;

        // The affine parameters do not change during the call, so take the read locks
        // and copy the per-channel values once instead of once per output element.
        let affine_params: Option<(Vec<T>, Vec<T>)> = if self.affine {
            let weight_binding = self.weight.data();
            let weight_data = weight_binding.read().unwrap();
            let bias_binding = self.bias.data();
            let bias_data = bias_binding.read().unwrap();
            let weight_array = weight_data.as_array();
            let bias_array = bias_data.as_array();
            let gammas: Vec<T> = (0..channels).map(|c| weight_array[[c]]).collect();
            let betas: Vec<T> = (0..channels).map(|c| bias_array[[c]]).collect();
            Some((gammas, betas))
        } else {
            None
        };

        // When the array exposes a backing slice, a group's values form one
        // consecutive run and can be copied without per-element index vectors.
        let contiguous = input_array.as_slice();
        let mut output_data = Vec::with_capacity(input_array.len());
        let mut group_values: Vec<T> = Vec::with_capacity(group_size);
        for b in 0..batch_size {
            for g in 0..self.num_groups {
                let first_channel = g * channels_per_group;
                group_values.clear();
                match contiguous {
                    Some(flat) => {
                        let start = (b * channels + first_channel) * spatial_size;
                        group_values.extend_from_slice(&flat[start..start + group_size]);
                    }
                    None => {
                        for c in first_channel..first_channel + channels_per_group {
                            for s in 0..spatial_size {
                                let mut indices = vec![b, c];
                                indices.extend(self.unravel_spatial_index(s, spatial_shape));
                                group_values.push(input_array[indices.as_slice()]);
                            }
                        }
                    }
                }

                let mean = self.calculate_mean(&group_values);
                let variance = self.calculate_variance(&group_values, mean);
                let std = (variance + self.eps).sqrt();

                // Groups cover consecutive channels, so emitting them in (b, g) order
                // yields the output in the same row-major order as the input shape.
                for (offset, &value) in group_values.iter().enumerate() {
                    let normalized = (value - mean) / std;
                    let final_val = match &affine_params {
                        Some((gammas, betas)) => {
                            // `group_values` is non-empty here, hence `spatial_size > 0`.
                            let c = first_channel + offset / spatial_size;
                            gammas[c] * normalized + betas[c]
                        }
                        None => normalized,
                    };
                    output_data.push(final_val);
                }
            }
        }
        Tensor::from_vec(output_data, input_shape.to_vec())
    }

    /// Converts a row-major linear `index` into per-dimension indices for
    /// `spatial_shape`.
    fn unravel_spatial_index(&self, mut index: usize, spatial_shape: &[usize]) -> Vec<usize> {
        let mut indices = vec![0; spatial_shape.len()];
        // Walk from the fastest-varying (last) dimension to the slowest.
        for i in (0..spatial_shape.len()).rev() {
            indices[i] = index % spatial_shape[i];
            index /= spatial_shape[i];
        }
        indices
    }

    /// Arithmetic mean of `values` (sum divided by the element count).
    fn calculate_mean(&self, values: &[T]) -> T {
        let sum: T = values.iter().fold(T::zero(), |acc, &x| acc + x);
        sum / T::from_usize(values.len()).unwrap()
    }

    /// Mean of squared deviations from `mean` (divides by `n`, not `n - 1`).
    fn calculate_variance(&self, values: &[T], mean: T) -> T {
        let sum_sq_diff: T = values
            .iter()
            .fold(T::zero(), |acc, &x| acc + (x - mean).powi(2));
        sum_sq_diff / T::from_usize(values.len()).unwrap()
    }

    /// Number of groups the channels are divided into.
    pub fn num_groups(&self) -> usize {
        self.num_groups
    }

    /// Expected size of the input's channel dimension.
    pub fn num_channels(&self) -> usize {
        self.num_channels
    }

    /// Constant added to the variance before the square root.
    pub fn eps(&self) -> T {
        self.eps
    }

    /// Whether the per-channel scale/shift parameters are applied.
    pub fn affine(&self) -> bool {
        self.affine
    }

    /// Returns clones of `[weight, bias]` when the affine transform is enabled,
    /// otherwise an empty vector.
    pub fn parameters(&self) -> Vec<Variable<T>> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }
}
/// Root-mean-square normalization over the trailing `normalized_shape` dimensions.
///
/// Each slice spanning the trailing dimensions is divided by
/// `sqrt(mean(x^2) + eps)` and then scaled element-wise by `weight`.
#[derive(Debug)]
pub struct RMSNorm<T>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Scale parameter with shape `normalized_shape`, initialised to ones.
    weight: Variable<T>,
    /// Shape of the trailing input dimensions that are normalized together.
    normalized_shape: Vec<usize>,
    /// Constant added to the mean square before taking the square root.
    eps: T,
}
impl<T> RMSNorm<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + std::fmt::Display,
{
    /// Creates an RMS-normalization module.
    ///
    /// * `normalized_shape` - shape of the trailing input dimensions to normalize over.
    /// * `eps` - value added to the mean square for numerical stability;
    ///   defaults to `1e-8`.
    ///
    /// The weight is initialised to ones with shape `normalized_shape` and is created
    /// with `requires_grad` set to `true`.
    ///
    /// # Panics
    ///
    /// Panics if `normalized_shape` is empty, or if the default epsilon cannot be
    /// represented in `T`.
    pub fn new(normalized_shape: Vec<usize>, eps: Option<T>) -> Self {
        assert!(
            !normalized_shape.is_empty(),
            "normalized_shape cannot be empty"
        );
        let eps = eps.unwrap_or_else(|| T::from(1e-8).unwrap());
        let num_features: usize = normalized_shape.iter().product();
        let weight = Variable::new(
            Tensor::from_vec(vec![T::one(); num_features], normalized_shape.clone()),
            true,
        );
        RMSNorm {
            weight,
            normalized_shape,
            eps,
        }
    }

    /// Normalizes `input` over its trailing `normalized_shape` dimensions.
    ///
    /// The result is wrapped in a new `Variable` whose `requires_grad` flag is set when
    /// either the input or the weight requires gradients.
    ///
    /// # Panics
    ///
    /// Panics if the trailing dimensions of `input` do not equal `normalized_shape`,
    /// or if a tensor lock is poisoned.
    pub fn forward(&self, input: &Variable<T>) -> Variable<T> {
        let input_binding = input.data();
        let input_data = input_binding.read().unwrap();
        let input_shape = input_data.shape();
        self.verify_input_shape(input_shape);
        let normalized_data = self.rms_normalize(&input_data);
        let requires_grad = input.requires_grad() || self.weight.requires_grad();
        Variable::new(normalized_data, requires_grad)
    }

    /// Panics unless the last `normalized_shape.len()` dimensions of `input_shape`
    /// are exactly `normalized_shape`.
    fn verify_input_shape(&self, input_shape: &[usize]) {
        let norm_dims = self.normalized_shape.len();
        let input_dims = input_shape.len();
        if input_dims < norm_dims {
            panic!(
                "Input has {} dimensions but normalized_shape has {} dimensions",
                input_dims, norm_dims
            );
        }
        let input_suffix = &input_shape[input_dims - norm_dims..];
        if input_suffix != self.normalized_shape.as_slice() {
            panic!(
                "Input shape suffix {:?} doesn't match normalized_shape {:?}",
                input_suffix, self.normalized_shape
            );
        }
    }

    /// Computes the normalized tensor for an input whose shape was already verified.
    ///
    /// The leading (non-normalized) dimensions are flattened into independent rows of
    /// `feature_size` elements; each row is divided by its own root mean square.
    fn rms_normalize(&self, input: &Tensor<T>) -> Tensor<T> {
        let input_array = input.as_array();
        let input_shape = input.shape();
        let batch_dims = input_shape.len() - self.normalized_shape.len();
        let batch_size: usize = input_shape[..batch_dims].iter().product();
        let feature_size: usize = self.normalized_shape.iter().product();
        let total = batch_size * feature_size;

        // Obtain the elements in logical (row-major) order exactly once: borrow the
        // backing slice when the array exposes one, otherwise gather by index.
        let gathered: Vec<T>;
        let flat: &[T] = match input_array.as_slice() {
            Some(contiguous) => contiguous,
            None => {
                gathered = (0..total)
                    .map(|linear_idx| {
                        let indices = self.unravel_index(linear_idx, input_shape);
                        input_array[indices.as_slice()]
                    })
                    .collect();
                gathered.as_slice()
            }
        };

        // The weight does not change during the call, so take the read lock and
        // flatten it once instead of once per output element.
        let gammas: Vec<T> = {
            let weight_binding = self.weight.data();
            let weight_data = weight_binding.read().unwrap();
            let weight_array = weight_data.as_array();
            let mut flattened = Vec::with_capacity(feature_size);
            for feat_idx in 0..feature_size {
                let indices = self.unravel_index(feat_idx, &self.normalized_shape);
                flattened.push(weight_array[indices.as_slice()]);
            }
            flattened
        };

        let mut output_data = Vec::with_capacity(total);
        for batch_idx in 0..batch_size {
            let start = batch_idx * feature_size;
            let features = &flat[start..start + feature_size];
            let mean_square: T = features.iter().fold(T::zero(), |acc, &x| acc + x.powi(2))
                / T::from_usize(features.len()).unwrap();
            let rms = (mean_square + self.eps).sqrt();
            output_data.extend(
                features
                    .iter()
                    .zip(gammas.iter())
                    .map(|(&value, &gamma)| gamma * (value / rms)),
            );
        }
        Tensor::from_vec(output_data, input_shape.to_vec())
    }

    /// Converts a row-major linear `index` into per-dimension indices for `shape`.
    fn unravel_index(&self, mut index: usize, shape: &[usize]) -> Vec<usize> {
        let mut indices = vec![0; shape.len()];
        // Walk from the fastest-varying (last) dimension to the slowest.
        for i in (0..shape.len()).rev() {
            indices[i] = index % shape[i];
            index /= shape[i];
        }
        indices
    }

    /// Shape of the trailing dimensions this layer normalizes over.
    pub fn normalized_shape(&self) -> &[usize] {
        &self.normalized_shape
    }

    /// Constant added to the mean square before the square root.
    pub fn eps(&self) -> T {
        self.eps
    }

    /// Returns a single-element vector holding a clone of the weight.
    pub fn parameters(&self) -> Vec<Variable<T>> {
        vec![self.weight.clone()]
    }
}
/// `Module` implementation that forwards to the inherent `GroupNorm` methods.
///
/// The bounds are exactly those required by the inherent `GroupNorm` impl; nothing
/// here needs `Sum` or `Display`, so they are not demanded of `T`.
impl<T> Module<T> for GroupNorm<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ScalarOperand,
{
    /// Delegates to the inherent `GroupNorm::forward`.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        // Inherent associated functions take precedence over trait methods in path
        // lookup, so this calls the inherent `forward`, not this trait method.
        Self::forward(self, input)
    }

    /// Delegates to the inherent `GroupNorm::parameters`.
    fn parameters(&self) -> Vec<Variable<T>> {
        Self::parameters(self)
    }

    /// Exposes the layer as `&dyn Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// `Module` implementation for `RMSNorm`.
impl<T> Module<T> for RMSNorm<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + std::fmt::Display,
{
    /// Delegates to the inherent `RMSNorm::forward`.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        // Path lookup prefers the inherent associated function over this trait method.
        Self::forward(self, input)
    }

    /// Returns a single-element vector holding a clone of the weight.
    fn parameters(&self) -> Vec<Variable<T>> {
        let weight = self.weight.clone();
        vec![weight]
    }

    /// Exposes the layer as `&dyn Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default construction keeps the shape, enables the affine transform and
    /// exposes both weight and bias.
    #[test]
    fn test_layer_norm_creation() {
        let norm = LayerNorm::<f32>::new(vec![128], None, None);
        assert_eq!(norm.normalized_shape(), &[128]);
        assert!(norm.elementwise_affine());
        assert_eq!(norm.parameters().len(), 2);
    }

    /// The forward pass preserves the input shape.
    #[test]
    fn test_layer_norm_forward() {
        let norm = LayerNorm::<f32>::new(vec![4], None, None);
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let input = Variable::new(Tensor::from_vec(values, vec![2, 4]), false);
        let output = norm.forward(&input);
        let binding = output.data();
        let tensor = binding.read().unwrap();
        assert_eq!(tensor.shape(), &[2, 4]);
    }

    /// Default construction records the group/channel counts and exposes both
    /// weight and bias.
    #[test]
    fn test_group_norm_creation() {
        let norm = GroupNorm::<f32>::new(2, 8, None, None);
        assert_eq!(norm.num_groups(), 2);
        assert_eq!(norm.num_channels(), 8);
        assert!(norm.affine());
        assert_eq!(norm.parameters().len(), 2);
    }

    /// Construction records the shape and exposes the single weight parameter.
    #[test]
    fn test_rms_norm_creation() {
        let norm = RMSNorm::<f32>::new(vec![64], None);
        assert_eq!(norm.normalized_shape, vec![64]);
        assert_eq!(norm.parameters().len(), 1);
    }
}