use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;
/// 2d batch normalization layer.
///
/// Normalizes a 4d input per channel. In training mode the statistics come
/// from the current batch; in eval mode the tracked running statistics are
/// used (see `forward`).
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct BatchNorm {
    // Small constant added to the variance before rsqrt for numerical stability.
    pub eps: f32,
    // EMA factor for the running statistics update (weight of the new batch stats).
    pub momentum: f32,
    // When true, `forward` updates running_mean/running_var/num_batches_tracked
    // during training.
    pub track_running_stats: bool,
    // Learnable per-channel scale (gamma); skipped when None.
    pub weight: Option<Tensor>,
    // Learnable per-channel shift (beta); skipped when None.
    pub bias: Option<Tensor>,
    // EMA of per-channel batch means, used for normalization in eval mode.
    pub running_mean: Tensor,
    // EMA of per-channel (unbiased-corrected) batch variances, used in eval mode.
    pub running_var: Tensor,
    // Counter incremented once per tracked training batch.
    pub num_batches_tracked: Tensor,
}
impl BatchNorm {
    /// Creates a `BatchNorm` layer for `num_features` channels.
    ///
    /// Defaults: `eps = 1e-5`, `momentum = 0.1`, running stats tracked.
    /// Weight starts at ones, bias at zeros, running mean at zeros and
    /// running variance at ones.
    ///
    /// NOTE(review): the previous signature took an unused `self`, which made
    /// the constructor require an already-existing instance; the dead
    /// parameter is removed so this is a plain associated function.
    pub fn new(num_features: u64, dtype: DType) -> BatchNorm {
        BatchNorm {
            eps: 1e-5,
            momentum: 0.1,
            track_running_stats: true,
            weight: Some(Tensor::ones(num_features, dtype)),
            bias: Some(Tensor::zeros(num_features, dtype)),
            running_mean: Tensor::zeros(num_features, dtype),
            running_var: Tensor::ones(num_features, dtype),
            num_batches_tracked: Tensor::zeros(1, dtype),
        }
    }

    /// Applies batch normalization to `x` — assumed (N, C, H, W), TODO confirm
    /// against callers.
    ///
    /// Training mode: normalizes with the current batch's per-channel
    /// mean/variance and, when `track_running_stats` is set, updates the
    /// running statistics via an exponential moving average.
    /// Eval mode: normalizes with the stored running statistics.
    ///
    /// # Errors
    /// Propagates any `ZyxError` from the underlying tensor ops
    /// (mean/reshape/expand).
    pub fn forward(&mut self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let batch_mean;
        let batch_invstd;
        let x = x.into();
        if Tensor::training() {
            // Per-channel statistics over the batch and spatial axes.
            batch_mean = x.mean([0, 2, 3])?;
            let y = &x - batch_mean.reshape([1, batch_mean.numel(), 1, 1])?;
            let batch_var = (&y * &y).mean([0, 2, 3])?;
            // BUGFIX: training must normalize with the *batch* variance, not
            // `self.running_var` (running stats are only read in eval mode).
            // This stays rank 1; the rank check below broadcasts it.
            batch_invstd = (&batch_var + self.eps).rsqrt();
            if self.track_running_stats {
                // EMA update of the running statistics.
                self.running_mean =
                    &self.running_mean * (1.0 - self.momentum) + &batch_mean * self.momentum;
                // n / (n - C) converts the biased batch variance into an
                // unbiased estimate before folding it into the EMA.
                self.running_var = &self.running_var * (1.0 - self.momentum)
                    + batch_var * self.momentum * y.numel() as f32
                        / (y.numel() - y.shape()[1]) as f32;
                self.num_batches_tracked = &self.num_batches_tracked + 1;
            }
        } else {
            // Eval mode: use the tracked statistics, pre-expanded to x's shape.
            batch_mean = self.running_mean.clone();
            batch_invstd = (self
                .running_var
                .reshape([1, self.running_var.numel(), 1, 1])?
                .expand(x.shape())?
                + self.eps)
                .rsqrt()
        }
        // Broadcastable per-channel shape (1, C, 1, 1).
        let shape = [1, batch_mean.numel(), 1, 1];
        let mut x = x - batch_mean.reshape(shape)?;
        if let Some(weight) = &self.weight {
            x = weight.reshape(shape)? * x;
        }
        // The training path produces a rank-1 invstd that still needs
        // broadcasting; the eval path already expanded it to x's shape.
        x = x * if batch_invstd.rank() == 1 {
            batch_invstd.reshape(shape)?
        } else {
            batch_invstd
        };
        if let Some(bias) = &self.bias {
            Ok(x + bias.reshape(shape)?)
        } else {
            Ok(x)
        }
    }
}