use crate::{simd, CnnError, CnnResult, Tensor};
use super::Layer;
/// Alias that exposes [`BatchNorm`] under the name `BatchNorm2d`.
pub type BatchNorm2d = BatchNorm;
/// Batch-normalization layer state: per-channel scale/shift parameters plus
/// stored running statistics.
///
/// Every vector is created with `num_features` entries, and the setters on
/// this type reject replacements whose length differs from `num_features`.
#[derive(Debug, Clone)]
pub struct BatchNorm {
    /// Number of channels the layer is configured for.
    num_features: usize,
    /// Per-channel scale values (initialised to 1.0 by `new`).
    gamma: Vec<f32>,
    /// Per-channel shift values (initialised to 0.0 by `new`).
    beta: Vec<f32>,
    /// Per-channel running mean handed to the normalization kernel.
    running_mean: Vec<f32>,
    /// Per-channel running variance handed to the normalization kernel.
    running_var: Vec<f32>,
    /// Small constant passed to the kernel alongside the variance
    /// (presumably added to the variance for numerical stability — confirm in `simd`).
    epsilon: f32,
    /// Momentum value (0.1 by default).
    // Not read by any code visible here, hence the lint allowance below;
    // presumably kept for running-statistics updates during training — TODO confirm.
    #[allow(dead_code)]
    momentum: f32,
}
impl BatchNorm {
    /// Builds a layer for `num_features` channels with neutral defaults:
    /// gamma = 1, beta = 0, running mean = 0, running variance = 1,
    /// epsilon = 1e-5 and momentum = 0.1.
    pub fn new(num_features: usize) -> Self {
        // Every per-channel vector has the same length, only the fill differs.
        let filled = |value: f32| vec![value; num_features];
        Self {
            num_features,
            gamma: filled(1.0),
            beta: filled(0.0),
            running_mean: filled(0.0),
            running_var: filled(1.0),
            epsilon: 1e-5,
            momentum: 0.1,
        }
    }

    /// Same as [`BatchNorm::new`], but with a caller-supplied `epsilon`.
    pub fn with_epsilon(num_features: usize, epsilon: f32) -> Self {
        Self {
            epsilon,
            ..Self::new(num_features)
        }
    }

    /// Replaces the scale (`gamma`) and shift (`beta`) vectors.
    ///
    /// # Errors
    ///
    /// Returns an invalid-shape error, leaving the layer untouched, unless
    /// both vectors contain exactly `num_features` values.
    pub fn set_params(&mut self, gamma: Vec<f32>, beta: Vec<f32>) -> CnnResult<()> {
        let expected = self.num_features;
        match (gamma.len(), beta.len()) {
            (g, b) if g == expected && b == expected => {
                self.gamma = gamma;
                self.beta = beta;
                Ok(())
            }
            (g, b) => Err(CnnError::invalid_shape(
                format!("num_features={}", expected),
                format!("gamma={}, beta={}", g, b),
            )),
        }
    }

    /// Replaces the stored running mean and running variance.
    ///
    /// # Errors
    ///
    /// Returns an invalid-shape error, leaving the layer untouched, unless
    /// both vectors contain exactly `num_features` values.
    pub fn set_running_stats(&mut self, mean: Vec<f32>, var: Vec<f32>) -> CnnResult<()> {
        let expected = self.num_features;
        let (mean_len, var_len) = (mean.len(), var.len());
        if mean_len == expected && var_len == expected {
            self.running_mean = mean;
            self.running_var = var;
            Ok(())
        } else {
            Err(CnnError::invalid_shape(
                format!("num_features={}", expected),
                format!("mean={}, var={}", mean_len, var_len),
            ))
        }
    }

    /// Number of channels this layer was configured for.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Per-channel scale values.
    pub fn gamma(&self) -> &[f32] {
        self.gamma.as_slice()
    }

    /// Per-channel shift values.
    pub fn beta(&self) -> &[f32] {
        self.beta.as_slice()
    }

    /// Per-channel running mean.
    pub fn running_mean(&self) -> &[f32] {
        self.running_mean.as_slice()
    }

    /// Per-channel running variance.
    pub fn running_var(&self) -> &[f32] {
        self.running_var.as_slice()
    }
}
impl Layer for BatchNorm {
    /// Normalizes `input` with the stored running statistics.
    ///
    /// The input must be a rank-4 tensor whose last dimension equals
    /// `num_features`. A zero-initialised tensor of the same shape is
    /// allocated and filled by `simd::batch_norm_simd`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-shape error when the input is not 4-dimensional or
    /// when its last dimension differs from `num_features`.
    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
        let dims = input.shape();
        let rank = dims.len();
        if rank != 4 {
            return Err(CnnError::invalid_shape(
                "4D tensor (NHWC)",
                format!("{}D tensor", rank),
            ));
        }

        // The channel count is read from the last axis.
        let expected = self.num_features;
        let actual = dims[3];
        if actual != expected {
            return Err(CnnError::invalid_shape(
                format!("{} channels", expected),
                format!("{} channels", actual),
            ));
        }

        let mut normalized = Tensor::zeros(dims);
        simd::batch_norm_simd(
            input.data(),
            normalized.data_mut(),
            &self.gamma,
            &self.beta,
            &self.running_mean,
            &self.running_var,
            self.epsilon,
            expected,
        );
        Ok(normalized)
    }

    /// Static layer name: `"BatchNorm"`.
    fn name(&self) -> &'static str {
        "BatchNorm"
    }

    /// Parameter count: twice `num_features`, matching the combined lengths
    /// of `gamma` and `beta`.
    fn num_params(&self) -> usize {
        2 * self.num_features
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOLERANCE: f32 = 0.01;

    /// Panics unless `actual` lies within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// A fresh layer reports the requested size everywhere.
    #[test]
    fn test_batch_norm_creation() {
        let layer = BatchNorm::new(64);
        assert_eq!(layer.num_features(), 64);
        assert_eq!(layer.gamma().len(), 64);
        assert_eq!(layer.beta().len(), 64);
        assert_eq!(layer.num_params(), 128);
    }

    /// With default parameters an all-ones input stays (approximately) at one.
    #[test]
    fn test_batch_norm_forward() {
        let layer = BatchNorm::new(4);
        let ones = Tensor::ones(&[1, 8, 8, 4]);
        let result = layer.forward(&ones).unwrap();
        assert_eq!(result.shape(), ones.shape());
        for value in result.data() {
            assert_close(*value, 1.0);
        }
    }

    /// Custom running statistics are applied per channel.
    #[test]
    fn test_batch_norm_normalization() {
        let mut layer = BatchNorm::new(2);
        layer
            .set_running_stats(vec![1.0, 2.0], vec![1.0, 4.0])
            .unwrap();
        let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 1, 2]).unwrap();
        let result = layer.forward(&input).unwrap();

        // Expected values: (x - mean) / sqrt(var) for each channel.
        let expected = [0.0_f32, 0.0, 2.0, 1.0];
        for (index, &want) in expected.iter().enumerate() {
            assert_close(result.data()[index], want);
        }
    }

    /// A channel count that differs from `num_features` is rejected.
    #[test]
    fn test_batch_norm_invalid_shape() {
        let layer = BatchNorm::new(4);
        let mismatched = Tensor::ones(&[1, 8, 8, 8]);
        assert!(layer.forward(&mismatched).is_err());
    }
}