use std::cell::Cell;
use crate::autograd::Variable;
use crate::tensor::{Result, Tensor, TensorError, TensorOptions};
use super::buffer::Buffer;
use super::parameter::Parameter;
use super::Module;
/// Batch normalization over a `[batch, features]` input.
///
/// In training mode the layer normalizes with batch statistics and updates
/// the running statistics; in eval mode it normalizes with the stored running
/// statistics. `weight` (gamma) and `bias` (beta) are the learnable affine
/// parameters.
pub struct BatchNorm {
    pub weight: Parameter,
    pub bias: Parameter,
    running_mean: Buffer,
    running_var: Buffer,
#[allow(dead_code)]
num_features: i64,
eps: f64,
momentum: f64,
training: Cell<bool>,
}
impl BatchNorm {
    /// Creates a CPU `BatchNorm` with the defaults `eps = 1e-5` and
    /// `momentum = 0.1`.
    pub fn new(num_features: i64) -> Result<Self> {
Self::on_device(num_features, crate::tensor::Device::CPU)
}
    /// Creates a `BatchNorm` on `device`. `weight` and `running_var` are
    /// initialized to ones; `bias` and `running_mean` to zeros.
    pub fn on_device(num_features: i64, device: crate::tensor::Device) -> Result<Self> {
let opts = TensorOptions { dtype: crate::tensor::DType::Float32, device };
let weight = Variable::new(Tensor::ones(&[num_features], opts)?, true);
let bias = Variable::new(Tensor::zeros(&[num_features], opts)?, true);
Ok(BatchNorm {
weight: Parameter { variable: weight, name: "weight".into() },
bias: Parameter { variable: bias, name: "bias".into() },
running_mean: Buffer::new(Tensor::zeros(&[num_features], opts)?, "running_mean"),
running_var: Buffer::new(Tensor::ones(&[num_features], opts)?, "running_var"),
num_features,
eps: 1e-5,
momentum: 0.1,
training: Cell::new(true),
})
}
}
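
// Sketch of the intended train/eval flow, matching the tests at the bottom of
// this file (`?`-based error handling assumes a `Result`-returning caller):
//
//     let bn = BatchNorm::new(64)?;   // starts in training mode
//     let y = bn.forward(&x)?;        // batch statistics; updates the buffers
//     bn.set_training(false);
//     let y = bn.forward(&x)?;        // running statistics; buffers frozen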
/// Batch normalization for 4D inputs `[batch, channels, height, width]`.
/// Thin wrapper over `BatchNorm` that adds a rank check before delegating.
pub struct BatchNorm2d {
    inner: BatchNorm,
}
impl BatchNorm2d {
pub fn new(num_channels: i64) -> Result<Self> {
Ok(Self { inner: BatchNorm::new(num_channels)? })
}
pub fn on_device(num_channels: i64, device: crate::tensor::Device) -> Result<Self> {
Ok(Self { inner: BatchNorm::on_device(num_channels, device)? })
}
}
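
// Usage differs from `BatchNorm` only in the accepted input rank: `forward`
// rejects anything that is not `[B, C, H, W]`. For example (sketch):
//
//     let bn = BatchNorm2d::new(3)?;
//     let y = bn.forward(&x)?;   // ok for x: [8, 3, 32, 32]; Err for 2D x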
impl Module for BatchNorm2d {
fn name(&self) -> &str { "batchnorm2d" }
fn forward(&self, input: &Variable) -> Result<Variable> {
let shape = input.shape();
if shape.len() != 4 {
return Err(TensorError::new(
"BatchNorm2d: input must be 4D [B, C, H, W]"
));
}
self.inner.forward(input)
}
fn parameters(&self) -> Vec<Parameter> {
self.inner.parameters()
}
fn buffers(&self) -> Vec<Buffer> {
self.inner.buffers()
}
fn set_training(&self, training: bool) {
self.inner.set_training(training);
}
fn move_to_device(&self, device: crate::tensor::Device) {
self.inner.move_to_device(device);
}
}
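
// For reference, the computation delegated to `Tensor::batch_norm` below is
// the standard batch-norm recurrence. This is a sketch of the expected
// per-feature semantics, not a statement about the backend's exact kernels:
//
//   training: y = (x - mean_B) / sqrt(var_B + eps) * weight + bias
//             running_mean <- (1 - momentum) * running_mean + momentum * mean_B
//             running_var  <- (1 - momentum) * running_var  + momentum * var_B
//   eval:     y = (x - running_mean) / sqrt(running_var + eps) * weight + bias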
impl Module for BatchNorm {
fn name(&self) -> &str { "batchnorm" }
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let training = self.training.get();
        if training {
            // Batch variance uses Bessel's correction (divides by n - 1), so a
            // single-sample batch is degenerate in training mode.
            let batch_size = input.shape()[0];
            if batch_size < 2 {
                return Err(TensorError::new(
                    "BatchNorm requires batch_size >= 2 in training mode \
                     (Bessel's correction divides by batch_size - 1)"
                ));
            }
        }
        // In training mode the backend updates these buffers in place with the
        // momentum-weighted batch statistics (see
        // `test_batchnorm_running_stats_update`); in eval mode it reads them.
        let rm = self.running_mean.get();
        let rv = self.running_var.get();
        let result = input.data().batch_norm(
            Some(&self.weight.variable.data()),
            Some(&self.bias.variable.data()),
            Some(&rm),
            Some(&rv),
            training,
            self.momentum,
            self.eps,
        )?;
        Ok(Variable::wrap(result))
}
fn parameters(&self) -> Vec<Parameter> {
vec![self.weight.clone(), self.bias.clone()]
}
fn buffers(&self) -> Vec<Buffer> {
vec![self.running_mean.clone(), self.running_var.clone()]
}
fn set_training(&self, training: bool) {
self.training.set(training);
}
    fn move_to_device(&self, device: crate::tensor::Device) {
        // Only the running-stat buffers are moved here; transfer errors are
        // ignored (best effort), and parameters are assumed to be relocated by
        // the caller through `parameters()`.
        let _ = self.running_mean.to_device(device);
        let _ = self.running_var.to_device(device);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Tensor, test_device, test_opts};
#[test]
fn test_batchnorm_forward_training() {
let bn = BatchNorm::on_device(4, test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[8, 4], test_opts()).unwrap(), false,
);
let y = bn.forward(&x).unwrap();
assert_eq!(y.shape(), vec![8, 4]);
}
#[test]
fn test_batchnorm_eval_mode() {
let bn = BatchNorm::on_device(4, test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[8, 4], test_opts()).unwrap(), false,
);
let _ = bn.forward(&x).unwrap();
bn.set_training(false);
let x_single = Variable::new(
Tensor::randn(&[1, 4], test_opts()).unwrap(), false,
);
let y = bn.forward(&x_single).unwrap();
assert_eq!(y.shape(), vec![1, 4]);
}
#[test]
fn test_batchnorm_training_requires_batch_ge_2() {
let bn = BatchNorm::on_device(4, test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[1, 4], test_opts()).unwrap(), false,
);
assert!(bn.forward(&x).is_err());
}
#[test]
fn test_batchnorm_running_stats_update() {
let bn = BatchNorm::on_device(2, test_device()).unwrap();
let rm_before = bn.buffers()[0].get().to_f32_vec().unwrap();
assert!(rm_before.iter().all(|&v| v.abs() < 1e-6));
let x = Variable::new(
Tensor::from_f32(&[10.0, 20.0, 12.0, 22.0, 11.0, 21.0, 9.0, 19.0],
&[4, 2], test_device()).unwrap(),
false,
);
let _ = bn.forward(&x).unwrap();
        let rm_after = bn.buffers()[0].get().to_f32_vec().unwrap();
        // With momentum = 0.1, one step from zero gives
        // running_mean[0] = 0.1 * mean([10, 12, 11, 9]) = 1.05; the loose
        // bound keeps the test robust to backend rounding.
        assert!(rm_after[0].abs() > 0.5, "running_mean should have updated: {}", rm_after[0]);
}
#[test]
fn test_batchnorm_parameters() {
let bn = BatchNorm::on_device(8, test_device()).unwrap();
        assert_eq!(bn.parameters().len(), 2);
        assert_eq!(bn.buffers().len(), 2);
    }
#[test]
fn test_batchnorm_gradient() {
let bn = BatchNorm::on_device(3, test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[4, 3], test_opts()).unwrap(), true,
);
let y = bn.forward(&x).unwrap().sum().unwrap();
y.backward().unwrap();
assert!(x.grad().is_some());
}
#[test]
fn test_batchnorm2d_forward() {
let bn = BatchNorm2d::on_device(3, test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[2, 3, 4, 4], test_opts()).unwrap(), false,
);
let y = bn.forward(&x).unwrap();
assert_eq!(y.shape(), vec![2, 3, 4, 4]);
}
#[test]
fn test_batchnorm2d_rejects_non_4d() {
let bn = BatchNorm2d::on_device(3, test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[2, 3], test_opts()).unwrap(), false,
);
assert!(bn.forward(&x).is_err());
}
#[test]
fn test_batchnorm2d_training_eval_differ() {
let bn = BatchNorm2d::on_device(2, test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[4, 2, 3, 3], test_opts()).unwrap(), false,
);
let train_out = bn.forward(&x).unwrap().data().to_f32_vec().unwrap();
bn.set_training(false);
let eval_out = bn.forward(&x).unwrap().data().to_f32_vec().unwrap();
let diff: f32 = train_out.iter().zip(&eval_out)
.map(|(a, b)| (a - b).abs()).sum();
assert!(diff > 0.01, "train and eval outputs should differ");
}
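
    // A direct numerical check on the normalization itself (sketch): with the
    // default weight = 1 and bias = 0, each feature of the training-mode
    // output should be roughly zero-mean. Assumes `to_f32_vec` returns
    // row-major `[batch, features]` data, as the other tests here do.
    #[test]
    fn test_batchnorm_output_is_centered() {
        let bn = BatchNorm::on_device(2, test_device()).unwrap();
        let x = Variable::new(
            Tensor::from_f32(&[10.0, 20.0, 12.0, 22.0, 11.0, 21.0, 9.0, 19.0],
                &[4, 2], test_device()).unwrap(),
            false,
        );
        let y = bn.forward(&x).unwrap().data().to_f32_vec().unwrap();
        for c in 0..2 {
            // Feature c occupies every second element, starting at offset c.
            let col: Vec<f32> = y.iter().skip(c).step_by(2).copied().collect();
            let mean: f32 = col.iter().sum::<f32>() / col.len() as f32;
            assert!(mean.abs() < 1e-3, "feature {} mean should be ~0: {}", c, mean);
        }
    }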
}