#[allow(clippy::wildcard_imports)]
use super::*;
impl GroupNorm {
    /// Shared constructor: validates the group/channel relationship and
    /// builds the layer with the default `eps = 1e-5`.
    ///
    /// `weight`/`bias` are created for non-affine layers too (so the struct
    /// layout is uniform); gradient tracking is only enabled when `affine`.
    ///
    /// # Panics
    /// Panics when `num_channels` is not divisible by `num_groups`.
    fn build(num_groups: usize, num_channels: usize, affine: bool) -> Self {
        assert!(
            num_channels.is_multiple_of(num_groups),
            "num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
        );
        let mut weight = constant(&[num_channels], 1.0);
        let mut bias = zeros(&[num_channels]);
        if affine {
            weight = weight.requires_grad();
            bias = bias.requires_grad();
        }
        Self {
            num_groups,
            num_channels,
            eps: 1e-5,
            weight,
            bias,
            affine,
        }
    }

    /// Creates a group-normalization layer with learnable per-channel
    /// scale (`weight`, initialized to 1) and shift (`bias`, initialized to 0).
    ///
    /// # Panics
    /// Panics when `num_channels` is not divisible by `num_groups`.
    #[must_use]
    pub fn new(num_groups: usize, num_channels: usize) -> Self {
        Self::build(num_groups, num_channels, true)
    }

    /// Like [`GroupNorm::new`] but with a caller-chosen stability epsilon.
    #[must_use]
    pub fn with_eps(num_groups: usize, num_channels: usize, eps: f32) -> Self {
        let mut layer = Self::new(num_groups, num_channels);
        layer.eps = eps;
        layer
    }

    /// Creates a layer without learnable parameters; `forward` then performs
    /// pure normalization and `parameters()` is empty.
    ///
    /// # Panics
    /// Panics when `num_channels` is not divisible by `num_groups`.
    #[must_use]
    pub fn without_affine(num_groups: usize, num_channels: usize) -> Self {
        Self::build(num_groups, num_channels, false)
    }

    /// Number of channel groups statistics are computed over.
    #[must_use]
    pub fn num_groups(&self) -> usize {
        self.num_groups
    }

    /// Expected size of the input's channel dimension (dim 1).
    #[must_use]
    pub fn num_channels(&self) -> usize {
        self.num_channels
    }
}
impl Module for GroupNorm {
    /// Applies group normalization over a `(N, C, ...)` input: each of the
    /// `num_groups` channel groups is normalized to zero mean / unit variance
    /// per sample, then (when `affine`) scaled and shifted per channel.
    ///
    /// # Panics
    /// Panics if `input` has fewer than 2 dimensions or if its channel
    /// dimension (dim 1) differs from `num_channels`.
    fn forward(&self, input: &Tensor) -> Tensor {
        let shape = input.shape();
        assert!(
            shape.len() >= 2,
            "GroupNorm expects at least 2D input, got {}D",
            shape.len()
        );
        let (batch_size, channels) = (shape[0], shape[1]);
        assert_eq!(
            channels, self.num_channels,
            "Expected {} channels, got {}",
            self.num_channels, channels
        );
        let channels_per_group = channels / self.num_groups;
        let spatial_size: usize = shape[2..].iter().product();
        let group_size = channels_per_group * spatial_size;
        let input_data = input.data();
        // Fetch parameter buffers once; the previous version called
        // `self.weight.data()` / `self.bias.data()` per element.
        let weight_data = self.weight.data();
        let bias_data = self.bias.data();
        let mut output_data = vec![0.0; input_data.len()];
        for n in 0..batch_size {
            for g in 0..self.num_groups {
                // In row-major layout a group's channels are contiguous, so
                // the whole (n, g) group is a single slice; summing it
                // preserves the original accumulation order exactly.
                let start = (n * channels + g * channels_per_group) * spatial_size;
                let group = &input_data[start..start + group_size];
                // Two-pass mean/variance (biased estimator, as before).
                let mean = group.iter().sum::<f32>() / group_size as f32;
                let var =
                    group.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / group_size as f32;
                let std_inv = 1.0 / (var + self.eps).sqrt();
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    let base = start + c * spatial_size;
                    if self.affine {
                        // Per-channel scale/shift hoisted out of the spatial loop.
                        let scale = weight_data[channel_idx];
                        let shift = bias_data[channel_idx];
                        for s in 0..spatial_size {
                            output_data[base + s] =
                                (input_data[base + s] - mean) * std_inv * scale + shift;
                        }
                    } else {
                        for s in 0..spatial_size {
                            output_data[base + s] = (input_data[base + s] - mean) * std_inv;
                        }
                    }
                }
            }
        }
        Tensor::new(&output_data, shape)
    }

    /// Learnable parameters: `[weight, bias]` when affine, otherwise none.
    fn parameters(&self) -> Vec<&Tensor> {
        if self.affine {
            vec![&self.weight, &self.bias]
        } else {
            vec![]
        }
    }

    /// Mutable access to the learnable parameters (e.g. for optimizers).
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        if self.affine {
            vec![&mut self.weight, &mut self.bias]
        } else {
            vec![]
        }
    }
}
/// Root-mean-square layer normalization: divides activations over the
/// trailing `normalized_shape` dimensions by their RMS (no mean subtraction
/// and no bias term), optionally followed by a learned per-element gain.
#[derive(Debug)]
pub struct RMSNorm {
    /// Trailing input dimensions that form one normalization group.
    normalized_shape: Vec<usize>,
    /// Small constant added to the mean square for numerical stability.
    eps: f32,
    /// Per-element gain, flattened to `normalized_shape.iter().product()` elements.
    weight: Tensor,
    /// When false, `weight` is ignored and no parameters are exposed.
    elementwise_affine: bool,
}
impl RMSNorm {
    /// Builds an RMS normalization layer with a learnable per-element gain
    /// (initialized to 1) and the default `eps = 1e-6`.
    #[must_use]
    pub fn new(normalized_shape: &[usize]) -> Self {
        let element_count: usize = normalized_shape.iter().product();
        Self {
            normalized_shape: normalized_shape.to_vec(),
            eps: 1e-6,
            weight: constant(&[element_count], 1.0).requires_grad(),
            elementwise_affine: true,
        }
    }

    /// Like [`RMSNorm::new`], but with a caller-chosen stability epsilon.
    #[must_use]
    pub fn with_eps(normalized_shape: &[usize], eps: f32) -> Self {
        Self {
            eps,
            ..Self::new(normalized_shape)
        }
    }

    /// Builds a layer without a learnable gain; `forward` then performs pure
    /// RMS scaling and no parameters are exposed.
    #[must_use]
    pub fn without_affine(normalized_shape: &[usize]) -> Self {
        let element_count: usize = normalized_shape.iter().product();
        Self {
            normalized_shape: normalized_shape.to_vec(),
            eps: 1e-6,
            weight: constant(&[element_count], 1.0),
            elementwise_affine: false,
        }
    }

    /// Trailing dimensions normalized together by this layer.
    #[must_use]
    pub fn normalized_shape(&self) -> &[usize] {
        self.normalized_shape.as_slice()
    }

    /// Stability epsilon added to the mean square before the square root.
    #[must_use]
    pub fn eps(&self) -> f32 {
        self.eps
    }

    /// Replaces the gain tensor (e.g. when loading pretrained weights).
    pub fn set_weight(&mut self, weight: Tensor) {
        self.weight = weight;
    }

    /// Current gain tensor.
    #[must_use]
    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    /// Builds an affine layer with a dummy single-element gain, intended to
    /// be replaced later via [`RMSNorm::set_weight`].
    #[must_use]
    pub fn placeholder(normalized_shape: &[usize]) -> Self {
        Self {
            normalized_shape: normalized_shape.to_vec(),
            eps: 1e-6,
            weight: Tensor::new(&[1.0], &[1]),
            elementwise_affine: true,
        }
    }
}
impl Module for RMSNorm {
    /// Scales the trailing `normalized_shape` dims of `input` by the
    /// reciprocal of their root mean square (stabilized by `eps`), applying
    /// the learned gain when `elementwise_affine` is set.
    ///
    /// # Panics
    /// Panics when `input` has fewer dims than `normalized_shape`, or when
    /// its trailing dims disagree with `normalized_shape`.
    #[provable_contracts_macros::contract("rmsnorm-kernel-v1", equation = "rmsnorm")]
    fn forward(&self, input: &Tensor) -> Tensor {
        let shape = input.shape();
        let norm_size: usize = self.normalized_shape.iter().product();
        assert!(
            shape.len() >= self.normalized_shape.len(),
            "Input must have at least as many dimensions as normalized_shape"
        );
        let start_dim = shape.len() - self.normalized_shape.len();
        let trailing = &shape[start_dim..];
        for (i, (&actual, &expected)) in trailing.iter().zip(&self.normalized_shape).enumerate() {
            assert_eq!(
                actual, expected,
                "Input shape doesn't match normalized_shape at dim {i}"
            );
        }
        if self.elementwise_affine {
            // Affine path delegates to the shared functional kernel.
            crate::nn::functional::rms_norm(input, &self.weight, self.eps)
        } else {
            // Gain-free path: scale every normalization group by 1/RMS.
            let batch_dims: usize = shape[..start_dim].iter().product();
            let input_data = input.data();
            let mut output_data = Vec::with_capacity(input_data.len());
            for b in 0..batch_dims {
                let begin = b * norm_size;
                let group = &input_data[begin..begin + norm_size];
                let mean_sq = group.iter().map(|&x| x * x).sum::<f32>() / norm_size as f32;
                let rms_inv = (mean_sq + self.eps).sqrt().recip();
                output_data.extend(group.iter().map(|&x| x * rms_inv));
            }
            Tensor::new(&output_data, shape)
        }
    }

    /// Learnable parameters: `[weight]` when elementwise-affine, else none.
    fn parameters(&self) -> Vec<&Tensor> {
        if self.elementwise_affine {
            vec![&self.weight]
        } else {
            Vec::new()
        }
    }

    /// Mutable access to the learnable parameters (e.g. for optimizers).
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        if self.elementwise_affine {
            vec![&mut self.weight]
        } else {
            Vec::new()
        }
    }
}
/// Instance normalization, implemented as `GroupNorm` with one group per
/// channel: statistics are computed per sample and per channel over the
/// spatial dimensions.
#[derive(Debug)]
pub struct InstanceNorm {
    /// Underlying `GroupNorm` configured with `num_groups == num_channels`.
    inner: GroupNorm,
}
impl InstanceNorm {
    /// Creates an affine instance-norm layer: one normalization group per
    /// channel, with learnable per-channel scale and shift.
    ///
    /// # Panics
    /// Panics when `num_channels` is 0 only if the underlying `GroupNorm`
    /// rejects it; divisibility always holds since groups == channels.
    #[must_use]
    pub fn new(num_channels: usize) -> Self {
        Self {
            inner: GroupNorm::new(num_channels, num_channels),
        }
    }

    /// Like [`InstanceNorm::new`] but with a caller-chosen epsilon, for
    /// consistency with `GroupNorm::with_eps` and `RMSNorm::with_eps`.
    #[must_use]
    pub fn with_eps(num_channels: usize, eps: f32) -> Self {
        Self {
            inner: GroupNorm::with_eps(num_channels, num_channels, eps),
        }
    }

    /// Creates an instance-norm layer without learnable parameters.
    #[must_use]
    pub fn without_affine(num_channels: usize) -> Self {
        Self {
            inner: GroupNorm::without_affine(num_channels, num_channels),
        }
    }
}
impl Module for InstanceNorm {
    /// Thin delegation: all behavior comes from the inner `GroupNorm`.
    fn forward(&self, input: &Tensor) -> Tensor {
        self.inner.forward(input)
    }

    /// Delegates to the inner `GroupNorm` (empty when non-affine).
    fn parameters(&self) -> Vec<&Tensor> {
        self.inner.parameters()
    }

    /// Mutable counterpart of [`Module::parameters`], also delegated.
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        self.inner.parameters_mut()
    }
}
#[cfg(test)]
#[path = "tests.rs"]
mod tests;