use scivex_core::{Float, Tensor};
use crate::error::Result;
use crate::variable::Variable;
use super::Layer;
pub struct BatchNorm2d<T: Float> {
gamma: Variable<T>,
beta: Variable<T>,
running_mean: Vec<T>,
running_var: Vec<T>,
eps: T,
training: bool,
}
impl<T: Float> BatchNorm2d<T> {
pub fn new(num_channels: usize) -> Self {
let gamma = Variable::new(Tensor::ones(vec![num_channels]), true);
let beta = Variable::new(Tensor::zeros(vec![num_channels]), true);
Self {
gamma,
beta,
running_mean: vec![T::zero(); num_channels],
running_var: vec![T::one(); num_channels],
eps: T::from_f64(1e-5),
training: true,
}
}
}
fn channel_stats<T: Float>(
x_slice: &[T],
batch: usize,
channels: usize,
spatial: usize,
) -> (Vec<T>, Vec<T>) {
let count = T::from_usize(batch * spatial);
let mut mean = vec![T::zero(); channels];
let mut var = vec![T::zero(); channels];
for c in 0..channels {
let mut sum = T::zero();
for b_idx in 0..batch {
let base = b_idx * channels * spatial + c * spatial;
for s in 0..spatial {
sum += x_slice[base + s];
}
}
mean[c] = sum / count;
let mut sq_sum = T::zero();
for b_idx in 0..batch {
let base = b_idx * channels * spatial + c * spatial;
for s in 0..spatial {
let diff = x_slice[base + s] - mean[c];
sq_sum += diff * diff;
}
}
var[c] = sq_sum / count;
}
(mean, var)
}
impl<T: Float> Layer<T> for BatchNorm2d<T> {
fn forward(&self, x: &Variable<T>) -> Result<Variable<T>> {
let x_data = x.data();
let shape = x_data.shape().to_vec();
let batch = shape[0];
let channels = shape[1];
let height = shape[2];
let width = shape[3];
let spatial = height * width;
let x_slice = x_data.as_slice();
let gamma_data = self.gamma.data();
let beta_data = self.beta.data();
let g_slice = gamma_data.as_slice();
let b_slice = beta_data.as_slice();
let (mean, var) = if self.training {
channel_stats(x_slice, batch, channels, spatial)
} else {
(self.running_mean.clone(), self.running_var.clone())
};
let total = batch * channels * spatial;
let mut out_data = vec![T::zero(); total];
let mut x_norm_data = vec![T::zero(); total];
let mut inv_std = vec![T::zero(); channels];
for c in 0..channels {
inv_std[c] = T::one() / (var[c] + self.eps).sqrt();
}
for b_idx in 0..batch {
for c in 0..channels {
let base = b_idx * channels * spatial + c * spatial;
for s in 0..spatial {
let idx = base + s;
let xn = (x_slice[idx] - mean[c]) * inv_std[c];
x_norm_data[idx] = xn;
out_data[idx] = g_slice[c] * xn + b_slice[c];
}
}
}
let out = Tensor::from_vec(out_data, shape.clone())?;
let x_norm = Tensor::from_vec(x_norm_data, shape.clone())?;
let inv_std_t = Tensor::from_vec(inv_std, vec![channels])?;
Ok(Variable::from_op(
out,
vec![x.clone(), self.gamma.clone(), self.beta.clone()],
Box::new(move |g: &Tensor<T>| {
let g_s = g.as_slice();
let xn_s = x_norm.as_slice();
let inv_s = inv_std_t.as_slice();
let gs = gamma_data.as_slice();
let mut grad_gamma = vec![T::zero(); channels];
let mut grad_beta = vec![T::zero(); channels];
let mut grad_x = vec![T::zero(); batch * channels * spatial];
let count = T::from_usize(batch * spatial);
for c in 0..channels {
let mut g_sum = T::zero();
let mut gxn_sum = T::zero();
for b_idx in 0..batch {
let base = b_idx * channels * spatial + c * spatial;
for s in 0..spatial {
let idx = base + s;
grad_gamma[c] += g_s[idx] * xn_s[idx];
grad_beta[c] += g_s[idx];
g_sum += g_s[idx];
gxn_sum += g_s[idx] * xn_s[idx];
}
}
for b_idx in 0..batch {
let base = b_idx * channels * spatial + c * spatial;
for s in 0..spatial {
let idx = base + s;
grad_x[idx] = gs[c]
* inv_s[c]
* (g_s[idx] - g_sum / count - xn_s[idx] * gxn_sum / count);
}
}
}
vec![
Tensor::from_vec(grad_x, shape.clone())
.expect("grad shape matches forward pass"),
Tensor::from_vec(grad_gamma, vec![channels])
.expect("gamma grad length matches channels"),
Tensor::from_vec(grad_beta, vec![channels])
.expect("beta grad length matches channels"),
]
}),
))
}
fn parameters(&self) -> Vec<Variable<T>> {
vec![self.gamma.clone(), self.beta.clone()]
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
}
#[cfg(test)]
mod tests {
use super::*;
use scivex_core::Tensor;
#[test]
fn test_batchnorm2d_output_shape() {
let bn = BatchNorm2d::<f64>::new(3);
let data: Vec<f64> = (0..96).map(|i| f64::from(i) * 0.1).collect();
let x = Variable::new(Tensor::from_vec(data, vec![2, 3, 4, 4]).unwrap(), true);
let y = bn.forward(&x).unwrap();
assert_eq!(y.shape(), vec![2, 3, 4, 4]);
}
#[test]
fn test_batchnorm2d_normalized_output() {
let bn = BatchNorm2d::<f64>::new(2);
let data: Vec<f64> = (0..24).map(|i| (f64::from(i)) * 0.5 + 1.0).collect();
let x = Variable::new(Tensor::from_vec(data, vec![3, 2, 2, 2]).unwrap(), true);
let y = bn.forward(&x).unwrap();
let y_data = y.data();
let y_s = y_data.as_slice();
let spatial = 4; let channels = 2;
let batch = 3;
for c in 0..channels {
let mut sum = 0.0;
for b_idx in 0..batch {
let base = b_idx * channels * spatial + c * spatial;
for s in 0..spatial {
sum += y_s[base + s];
}
}
let mean = sum / (batch * spatial) as f64;
assert!(
mean.abs() < 1e-5,
"channel {c} mean was {mean}, expected ~0"
);
}
}
#[test]
fn test_batchnorm2d_parameters() {
let bn = BatchNorm2d::<f64>::new(5);
assert_eq!(bn.parameters().len(), 2);
}
#[test]
fn test_batchnorm2d_backward() {
let bn = BatchNorm2d::<f64>::new(2);
let data: Vec<f64> = (0..16).map(|i| (f64::from(i)) * 0.3 + 0.5).collect();
let x = Variable::new(Tensor::from_vec(data, vec![2, 2, 2, 2]).unwrap(), true);
let y = bn.forward(&x).unwrap();
let y_data = y.data();
let y_s = y_data.as_slice();
let total: f64 = y_s.iter().copied().sum();
let scalar = Variable::from_op(
Tensor::from_vec(vec![total], vec![1]).unwrap(),
vec![y.clone()],
Box::new(move |_g: &Tensor<f64>| {
vec![Tensor::ones(vec![2, 2, 2, 2])]
}),
);
scalar.backward();
let x_grad = x.grad();
assert!(x_grad.is_some(), "x should have gradients after backward");
let g = x_grad.unwrap();
assert_eq!(g.shape(), &[2, 2, 2, 2]);
let params = bn.parameters();
for (i, p) in params.iter().enumerate() {
let pg = p.grad();
assert!(
pg.is_some(),
"parameter {i} should have gradient after backward"
);
}
}
}