use crate::autograd::Variable;
use crate::nn::Module;
use crate::tensor::Tensor;
use ndarray::ScalarOperand;
use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
use std::fmt::Debug;
use std::iter::Sum;
use std::sync::{Arc, RwLock};
#[derive(Debug)]
pub struct BatchNorm1d<
T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
> {
num_features: usize,
weight: Variable<T>,
bias: Variable<T>,
running_mean: Arc<RwLock<Tensor<T>>>,
running_var: Arc<RwLock<Tensor<T>>>,
momentum: T,
eps: T,
training: Arc<RwLock<bool>>,
}
impl<T> BatchNorm1d<T>
where
T: Float
+ Debug
+ Default
+ FromPrimitive
+ ToPrimitive
+ Zero
+ One
+ From<f32>
+ 'static
+ Send
+ Sync
+ Copy
+ ScalarOperand
+ Sum
+ num_traits::FromPrimitive,
{
pub fn new(
num_features: usize,
eps: Option<T>,
momentum: Option<T>,
affine: Option<bool>,
) -> Self {
let eps = eps.unwrap_or_else(|| <T as From<f32>>::from(1e-5f32));
let momentum = momentum.unwrap_or_else(|| <T as From<f32>>::from(0.1f32));
let affine = affine.unwrap_or(true);
let weight = if affine {
Variable::new(Tensor::ones(&[num_features]), true)
} else {
Variable::new(Tensor::ones(&[num_features]), false)
};
let bias = if affine {
Variable::new(Tensor::zeros(&[num_features]), true)
} else {
Variable::new(Tensor::zeros(&[num_features]), false)
};
let running_mean = Arc::new(RwLock::new(Tensor::zeros(&[num_features])));
let running_var = Arc::new(RwLock::new(Tensor::ones(&[num_features])));
BatchNorm1d {
num_features,
weight,
bias,
running_mean,
running_var,
momentum,
eps,
training: Arc::new(RwLock::new(true)),
}
}
pub fn train(&self) {
if let Ok(mut training) = self.training.write() {
*training = true;
}
}
pub fn eval(&self) {
if let Ok(mut training) = self.training.write() {
*training = false;
}
}
pub fn is_training(&self) -> bool {
self.training
.read()
.unwrap_or_else(|_| panic!("Failed to read training mode"))
.clone()
}
pub fn forward(&self, input: &Variable<T>) -> Variable<T> {
let input_binding = input.data();
let input_data = input_binding.read().unwrap();
let input_shape = input_data.shape();
if input_shape.len() != 2 || input_shape[1] != self.num_features {
panic!(
"Expected 2D input with {} features, got shape {:?}",
self.num_features, input_shape
);
}
let normalized_tensor = if self.is_training() {
self.normalize_training(&input_data)
} else {
self.normalize_eval(&input_data)
};
let requires_grad =
input.requires_grad() || self.weight.requires_grad() || self.bias.requires_grad();
Variable::new(normalized_tensor, requires_grad)
}
fn normalize_training(&self, input: &Tensor<T>) -> Tensor<T> {
let input_shape = input.shape();
let batch_size = input_shape[0];
let num_features = input_shape[1];
let (batch_mean, batch_var) =
self.compute_batch_statistics(input, batch_size, num_features);
self.update_running_statistics(&batch_mean, &batch_var);
self.apply_normalization(input, &batch_mean, &batch_var)
}
fn normalize_eval(&self, input: &Tensor<T>) -> Tensor<T> {
let running_mean_lock = self.running_mean.read().unwrap();
let running_var_lock = self.running_var.read().unwrap();
self.apply_normalization(input, &running_mean_lock, &running_var_lock)
}
fn apply_normalization(
&self,
input: &Tensor<T>,
mean: &Tensor<T>,
var: &Tensor<T>,
) -> Tensor<T> {
let weight_binding = self.weight.data();
let weight_data = weight_binding.read().unwrap();
let bias_binding = self.bias.data();
let bias_data = bias_binding.read().unwrap();
let input_array = input.as_array();
let mean_array = mean.as_array();
let var_array = var.as_array();
let weight_array = weight_data.as_array();
let bias_array = bias_data.as_array();
let input_shape = input.shape();
let batch_size = input_shape[0];
let num_features = input_shape[1];
let mut output_data = Vec::with_capacity(batch_size * num_features);
for b in 0..batch_size {
for f in 0..num_features {
let x = input_array[[b, f]];
let mu = mean_array[f];
let sigma2 = var_array[f];
let ten = T::from_f32(10.0).unwrap();
let eps_adjusted = if sigma2 < self.eps * ten {
self.eps * ten
} else {
self.eps
};
let normalized = (x - mu) / (sigma2 + eps_adjusted).sqrt();
let output_val = weight_array[f] * normalized + bias_array[f];
output_data.push(output_val);
}
}
Tensor::from_vec(output_data, input_shape.to_vec())
}
fn compute_batch_statistics(
&self,
input: &Tensor<T>,
batch_size: usize,
num_features: usize,
) -> (Tensor<T>, Tensor<T>) {
let input_array = input.as_array();
let mut mean_vec = vec![T::zero(); num_features];
let mut var_vec = vec![T::zero(); num_features];
for f in 0..num_features {
let mut mean = T::zero();
let mut m2 = T::zero();
for b in 0..batch_size {
let x = input_array[[b, f]];
let delta = x - mean;
mean = mean + delta / T::from_usize(b + 1).unwrap();
let delta2 = x - mean;
m2 = m2 + delta * delta2;
}
mean_vec[f] = mean;
let variance = if batch_size > 1 {
m2 / T::from_usize(batch_size).unwrap()
} else {
T::one() };
let bias_corrected_var = if batch_size > 1 {
variance * T::from_usize(batch_size).unwrap()
/ T::from_usize(batch_size - 1).unwrap()
} else {
variance
};
let min_var_threshold = self.eps * T::from_f32(0.1).unwrap();
var_vec[f] = bias_corrected_var.max(min_var_threshold);
}
let mean_tensor = Tensor::from_vec(mean_vec, vec![num_features]);
let var_tensor = Tensor::from_vec(var_vec, vec![num_features]);
(mean_tensor, var_tensor)
}
fn update_running_statistics(&self, batch_mean: &Tensor<T>, batch_var: &Tensor<T>) {
if let (Ok(mut running_mean), Ok(mut running_var)) =
(self.running_mean.write(), self.running_var.write())
{
let batch_mean_array = batch_mean.as_array();
let batch_var_array = batch_var.as_array();
let running_mean_array = running_mean.as_array_mut();
let running_var_array = running_var.as_array_mut();
let momentum = self.momentum;
let one_minus_momentum = T::one() - momentum;
for i in 0..self.num_features {
running_mean_array[i] =
one_minus_momentum * running_mean_array[i] + momentum * batch_mean_array[i];
running_var_array[i] =
one_minus_momentum * running_var_array[i] + momentum * batch_var_array[i];
}
}
}
pub fn parameters(&self) -> Vec<Variable<T>> {
vec![self.weight.clone(), self.bias.clone()]
}
pub fn num_features(&self) -> usize {
self.num_features
}
pub fn eps(&self) -> T {
self.eps
}
pub fn momentum(&self) -> T {
self.momentum
}
pub fn running_mean(&self) -> Tensor<T> {
self.running_mean.read().unwrap().clone()
}
pub fn running_var(&self) -> Tensor<T> {
self.running_var.read().unwrap().clone()
}
}
impl<T> Module<T> for BatchNorm1d<T>
where
T: Float
+ Debug
+ Default
+ FromPrimitive
+ ToPrimitive
+ Zero
+ One
+ From<f32>
+ 'static
+ Send
+ Sync
+ Copy
+ ScalarOperand
+ Sum
+ num_traits::FromPrimitive,
{
fn forward(&self, input: &Variable<T>) -> Variable<T> {
self.forward(input)
}
fn parameters(&self) -> Vec<Variable<T>> {
self.parameters()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[derive(Debug)]
pub struct BatchNorm2d<
T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
> {
num_features: usize,
weight: Variable<T>,
bias: Variable<T>,
running_mean: Arc<RwLock<Tensor<T>>>,
running_var: Arc<RwLock<Tensor<T>>>,
momentum: T,
eps: T,
training: Arc<RwLock<bool>>,
}
impl<T> BatchNorm2d<T>
where
T: Float
+ Debug
+ Default
+ FromPrimitive
+ ToPrimitive
+ Zero
+ One
+ From<f32>
+ 'static
+ Send
+ Sync
+ Copy
+ ScalarOperand
+ Sum
+ num_traits::FromPrimitive,
{
pub fn new(
num_features: usize,
eps: Option<T>,
momentum: Option<T>,
affine: Option<bool>,
) -> Self {
let eps = eps.unwrap_or_else(|| <T as From<f32>>::from(1e-5f32));
let momentum = momentum.unwrap_or_else(|| <T as From<f32>>::from(0.1f32));
let affine = affine.unwrap_or(true);
let weight = if affine {
Variable::new(Tensor::ones(&[num_features]), true)
} else {
Variable::new(Tensor::ones(&[num_features]), false)
};
let bias = if affine {
Variable::new(Tensor::zeros(&[num_features]), true)
} else {
Variable::new(Tensor::zeros(&[num_features]), false)
};
let running_mean = Arc::new(RwLock::new(Tensor::zeros(&[num_features])));
let running_var = Arc::new(RwLock::new(Tensor::ones(&[num_features])));
BatchNorm2d {
num_features,
weight,
bias,
running_mean,
running_var,
momentum,
eps,
training: Arc::new(RwLock::new(true)),
}
}
pub fn train(&self) {
if let Ok(mut training) = self.training.write() {
*training = true;
}
}
pub fn eval(&self) {
if let Ok(mut training) = self.training.write() {
*training = false;
}
}
pub fn is_training(&self) -> bool {
self.training
.read()
.unwrap_or_else(|_| panic!("Failed to read training mode"))
.clone()
}
pub fn forward(&self, input: &Variable<T>) -> Variable<T> {
let input_binding = input.data();
let input_data = input_binding.read().unwrap();
let input_shape = input_data.shape();
if input_shape.len() != 4 || input_shape[1] != self.num_features {
panic!(
"Expected 4D input with {} channels, got shape {:?}",
self.num_features, input_shape
);
}
let normalized_tensor = if self.is_training() {
self.normalize_training_2d(&input_data)
} else {
self.normalize_eval_2d(&input_data)
};
let requires_grad =
input.requires_grad() || self.weight.requires_grad() || self.bias.requires_grad();
Variable::new(normalized_tensor, requires_grad)
}
fn normalize_training_2d(&self, input: &Tensor<T>) -> Tensor<T> {
let input_shape = input.shape();
let batch_size = input_shape[0];
let channels = input_shape[1];
let height = input_shape[2];
let width = input_shape[3];
let (channel_mean, channel_var) =
self.compute_channel_statistics(input, batch_size, channels, height, width);
self.update_running_statistics_2d(&channel_mean, &channel_var);
self.apply_channel_normalization_2d(input, &channel_mean, &channel_var)
}
fn normalize_eval_2d(&self, input: &Tensor<T>) -> Tensor<T> {
let running_mean_lock = self.running_mean.read().unwrap();
let running_var_lock = self.running_var.read().unwrap();
self.apply_channel_normalization_2d(input, &running_mean_lock, &running_var_lock)
}
fn compute_channel_statistics(
&self,
input: &Tensor<T>,
batch_size: usize,
channels: usize,
height: usize,
width: usize,
) -> (Tensor<T>, Tensor<T>) {
let input_array = input.as_array();
let spatial_size = height * width;
let total_elements_per_channel = batch_size * spatial_size;
let mut mean_vec = vec![T::zero(); channels];
let mut var_vec = vec![T::zero(); channels];
for c in 0..channels {
let mut mean = T::zero();
let mut m2 = T::zero();
let mut count = 0;
for b in 0..batch_size {
for h in 0..height {
for w in 0..width {
let x = input_array[[b, c, h, w]];
count += 1;
let delta = x - mean;
mean = mean + delta / T::from_usize(count).unwrap();
let delta2 = x - mean;
m2 = m2 + delta * delta2;
}
}
}
mean_vec[c] = mean;
let variance = if total_elements_per_channel > 1 {
m2 / T::from_usize(total_elements_per_channel).unwrap()
} else {
T::one()
};
let bias_corrected_var = if total_elements_per_channel > 1 {
variance * T::from_usize(total_elements_per_channel).unwrap()
/ T::from_usize(total_elements_per_channel - 1).unwrap()
} else {
variance
};
let min_var_threshold = self.eps * T::from_f32(0.1).unwrap();
var_vec[c] = bias_corrected_var.max(min_var_threshold);
}
let mean_tensor = Tensor::from_vec(mean_vec, vec![channels]);
let var_tensor = Tensor::from_vec(var_vec, vec![channels]);
(mean_tensor, var_tensor)
}
fn apply_channel_normalization_2d(
&self,
input: &Tensor<T>,
mean: &Tensor<T>,
var: &Tensor<T>,
) -> Tensor<T> {
let weight_binding = self.weight.data();
let weight_data = weight_binding.read().unwrap();
let bias_binding = self.bias.data();
let bias_data = bias_binding.read().unwrap();
let input_array = input.as_array();
let mean_array = mean.as_array();
let var_array = var.as_array();
let weight_array = weight_data.as_array();
let bias_array = bias_data.as_array();
let input_shape = input.shape();
let batch_size = input_shape[0];
let channels = input_shape[1];
let height = input_shape[2];
let width = input_shape[3];
let mut output_data = Vec::with_capacity(batch_size * channels * height * width);
for b in 0..batch_size {
for c in 0..channels {
let mu = mean_array[c];
let sigma2 = var_array[c];
let gamma = weight_array[c];
let beta = bias_array[c];
let ten = T::from_f32(10.0).unwrap();
let eps_adjusted = if sigma2 < self.eps * ten {
self.eps * ten
} else {
self.eps
};
let inv_std = T::one() / (sigma2 + eps_adjusted).sqrt();
for h in 0..height {
for w in 0..width {
let x = input_array[[b, c, h, w]];
let normalized = (x - mu) * inv_std;
let output_val = gamma * normalized + beta;
output_data.push(output_val);
}
}
}
}
Tensor::from_vec(output_data, input_shape.to_vec())
}
fn update_running_statistics_2d(&self, batch_mean: &Tensor<T>, batch_var: &Tensor<T>) {
if let (Ok(mut running_mean), Ok(mut running_var)) =
(self.running_mean.write(), self.running_var.write())
{
let batch_mean_array = batch_mean.as_array();
let batch_var_array = batch_var.as_array();
let running_mean_array = running_mean.as_array_mut();
let running_var_array = running_var.as_array_mut();
let momentum = self.momentum;
let one_minus_momentum = T::one() - momentum;
for i in 0..self.num_features {
running_mean_array[i] =
one_minus_momentum * running_mean_array[i] + momentum * batch_mean_array[i];
running_var_array[i] =
one_minus_momentum * running_var_array[i] + momentum * batch_var_array[i];
}
}
}
pub fn parameters(&self) -> Vec<Variable<T>> {
vec![self.weight.clone(), self.bias.clone()]
}
pub fn num_features(&self) -> usize {
self.num_features
}
pub fn eps(&self) -> T {
self.eps
}
pub fn momentum(&self) -> T {
self.momentum
}
pub fn running_mean(&self) -> Tensor<T> {
self.running_mean.read().unwrap().clone()
}
pub fn running_var(&self) -> Tensor<T> {
self.running_var.read().unwrap().clone()
}
}
impl<T> Module<T> for BatchNorm2d<T>
where
T: Float
+ Debug
+ Default
+ FromPrimitive
+ ToPrimitive
+ Zero
+ One
+ From<f32>
+ 'static
+ Send
+ Sync
+ Copy
+ ScalarOperand
+ Sum
+ num_traits::FromPrimitive,
{
fn forward(&self, input: &Variable<T>) -> Variable<T> {
self.forward(input)
}
fn parameters(&self) -> Vec<Variable<T>> {
self.parameters()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for `f32` comparisons.
    const TOL: f32 = 1e-4;

    /// Asserts that two `f32` values agree within `TOL`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// A `[2, 3]` input whose rows are `[1, 2, 3]` and `[4, 5, 6]`.
    fn sample_input_1d() -> Variable<f32> {
        let input_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        Variable::new(Tensor::from_vec(input_data, vec![2, 3]), false)
    }

    #[test]
    fn test_batchnorm1d_creation() {
        let bn = BatchNorm1d::<f32>::new(10, None, None, None);
        assert_eq!(bn.num_features(), 10);
        assert!(bn.is_training());
        assert_close(bn.eps(), 1e-5);
        assert_close(bn.momentum(), 0.1);
    }

    #[test]
    fn test_batchnorm1d_mode_toggle() {
        let bn = BatchNorm1d::<f32>::new(3, None, None, None);
        bn.eval();
        assert!(!bn.is_training());
        bn.train();
        assert!(bn.is_training());
    }

    #[test]
    fn test_batchnorm1d_forward() {
        let bn = BatchNorm1d::<f32>::new(3, None, None, None);
        let input = sample_input_1d();
        let output = bn.forward(&input);
        let output_binding = output.data();
        let output_data = output_binding.read().unwrap();
        let output_shape = output_data.shape();
        assert_eq!(output_shape, &[2, 3]);

        // With unit weight and zero bias, every feature column is centered on
        // the batch mean: the smaller sample maps below zero, the larger above.
        let output_array = output_data.as_array();
        for f in 0..3 {
            assert_close(output_array[[0, f]] + output_array[[1, f]], 0.0);
            assert!(output_array[[0, f]] < 0.0);
            assert!(output_array[[1, f]] > 0.0);
        }
    }

    #[test]
    fn test_batchnorm1d_running_statistics() {
        let bn = BatchNorm1d::<f32>::new(3, None, None, None);
        let _ = bn.forward(&sample_input_1d());
        let running_mean = bn.running_mean();
        let running_var = bn.running_var();
        let mean_array = running_mean.as_array();
        let var_array = running_var.as_array();
        // mean: 0.9 * 0 + 0.1 * batch_mean, with batch means 2.5, 3.5, 4.5.
        // var:  0.9 * 1 + 0.1 * 4.5 (unbiased variance of two samples 3 apart).
        let expected_means = [0.25f32, 0.35, 0.45];
        for f in 0..3 {
            assert_close(mean_array[f], expected_means[f]);
            assert_close(var_array[f], 1.35);
        }
    }

    #[test]
    fn test_batchnorm1d_eval_mode() {
        let bn = BatchNorm1d::<f32>::new(3, None, None, None);
        bn.eval();
        assert!(!bn.is_training());
        let input = sample_input_1d();
        let output = bn.forward(&input);
        let output_binding = output.data();
        let output_data = output_binding.read().unwrap();
        assert_eq!(output_data.shape(), &[2, 3]);

        // Fresh running statistics are mean 0 / variance 1, so the output is
        // the input divided by sqrt(1 + eps).
        let output_array = output_data.as_array();
        let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let scale = (1.0f32 + 1e-5).sqrt();
        for b in 0..2 {
            for f in 0..3 {
                assert_close(output_array[[b, f]], values[b * 3 + f] / scale);
            }
        }

        // Evaluation must leave the running statistics untouched.
        let running_mean = bn.running_mean();
        let running_var = bn.running_var();
        let mean_array = running_mean.as_array();
        let var_array = running_var.as_array();
        for f in 0..3 {
            assert_close(mean_array[f], 0.0);
            assert_close(var_array[f], 1.0);
        }
    }

    #[test]
    #[should_panic(expected = "Expected 2D input")]
    fn test_batchnorm1d_rejects_wrong_feature_count() {
        let bn = BatchNorm1d::<f32>::new(4, None, None, None);
        let _ = bn.forward(&sample_input_1d());
    }

    #[test]
    fn test_batchnorm2d_creation() {
        let bn = BatchNorm2d::<f32>::new(16, None, None, None);
        assert_eq!(bn.num_features(), 16);
        assert!(bn.is_training());
    }

    #[test]
    fn test_batchnorm2d_forward() {
        let bn = BatchNorm2d::<f32>::new(2, None, None, None);
        let input_data: Vec<f32> = (0..36).map(|i| i as f32).collect();
        let input = Variable::new(Tensor::from_vec(input_data, vec![2, 2, 3, 3]), false);
        let output = bn.forward(&input);
        let output_binding = output.data();
        let output_data = output_binding.read().unwrap();
        let output_shape = output_data.shape();
        assert_eq!(output_shape, &[2, 2, 3, 3]);

        // Each channel is centered on its own mean, so its outputs sum to ~0.
        let output_array = output_data.as_array();
        for c in 0..2 {
            let mut channel_sum = 0.0f32;
            for b in 0..2 {
                for h in 0..3 {
                    for w in 0..3 {
                        channel_sum += output_array[[b, c, h, w]];
                    }
                }
            }
            assert!(
                channel_sum.abs() < 1e-3,
                "channel {} sums to {}",
                c,
                channel_sum
            );
        }
    }

    #[test]
    #[should_panic(expected = "Expected 4D input")]
    fn test_batchnorm2d_rejects_non_4d_input() {
        let bn = BatchNorm2d::<f32>::new(3, None, None, None);
        let _ = bn.forward(&sample_input_1d());
    }

    #[test]
    fn test_batchnorm_parameters() {
        let bn1d = BatchNorm1d::<f32>::new(5, None, None, None);
        let params = bn1d.parameters();
        assert_eq!(params.len(), 2);
        let weight_binding = params[0].data();
        let weight_data = weight_binding.read().unwrap();
        let weight_shape = weight_data.shape();
        let bias_binding = params[1].data();
        let bias_data = bias_binding.read().unwrap();
        let bias_shape = bias_data.shape();
        assert_eq!(weight_shape, &[5]);
        assert_eq!(bias_shape, &[5]);
        // Affine defaults to true, so both parameters track gradients.
        assert!(params[0].requires_grad());
        assert!(params[1].requires_grad());
    }

    #[test]
    fn test_batchnorm_non_affine_parameters_do_not_require_grad() {
        let bn1d = BatchNorm1d::<f32>::new(5, None, None, Some(false));
        let params = bn1d.parameters();
        assert_eq!(params.len(), 2);
        assert!(!params[0].requires_grad());
        assert!(!params[1].requires_grad());
    }
}