use crate::Tensor;
use std::borrow::Borrow;
/// Configuration for a batch-normalization layer.
///
/// The values are stored on the layer at construction time; `ws_init` and
/// `bs_init` are consumed when the layer's variables are created, while the
/// remaining fields are forwarded to `Tensor::batch_norm` on every forward
/// pass.
#[derive(Debug, Clone, Copy)]
pub struct BatchNormConfig {
/// Forwarded as the `cudnn_enabled` argument of `Tensor::batch_norm`.
pub cudnn_enabled: bool,
/// Forwarded as the `eps` argument of `Tensor::batch_norm`.
// NOTE(review): presumably the small constant added to the variance for
// numerical stability — confirm against the backend documentation.
pub eps: f64,
/// Forwarded as the `momentum` argument of `Tensor::batch_norm`.
// NOTE(review): presumably the update factor for the running statistics —
// confirm against the backend documentation.
pub momentum: f64,
/// Initialization used for the variable registered as `"weight"`.
pub ws_init: super::Init,
/// Initialization used for the variable registered as `"bias"`.
pub bs_init: super::Init,
}
impl Default for BatchNormConfig {
    /// Returns the default configuration: cuDNN flag set, `eps = 1e-5`,
    /// `momentum = 0.1`, weights drawn uniformly from `[0, 1]` and biases
    /// set to the constant `0`.
    fn default() -> Self {
        // Build the two initializers first so the struct literal stays flat.
        let ws_init = super::Init::Uniform { lo: 0., up: 1. };
        let bs_init = super::Init::Const(0.);
        Self {
            cudnn_enabled: true,
            eps: 1e-5,
            momentum: 0.1,
            ws_init,
            bs_init,
        }
    }
}
/// A batch-normalization layer.
///
/// Holds the configuration, the running statistics and the affine
/// parameters, all of which are handed to `Tensor::batch_norm` in the forward
/// pass.
#[derive(Debug)]
pub struct BatchNorm {
// Configuration captured at construction time (momentum, eps, cudnn flag
// and the initializers that were used for `ws` / `bs`).
config: BatchNormConfig,
/// Tensor registered as `"running_mean"`, created with `zeros_no_train`.
pub running_mean: Tensor,
/// Tensor registered as `"running_var"`, created with `ones_no_train`.
pub running_var: Tensor,
/// Variable registered as `"weight"`, initialized with `config.ws_init`.
pub ws: Tensor,
/// Variable registered as `"bias"`, initialized with `config.bs_init`.
pub bs: Tensor,
/// Dimensionality tag of the layer (1, 2 or 3 for the public constructors);
/// the forward pass uses it to validate the number of input dimensions.
pub nd: usize,
}
/// Shared constructor behind `batch_norm1d`, `batch_norm2d` and
/// `batch_norm3d`.
///
/// Registers four one-dimensional tensors of length `out_dim` under the
/// given variable-store path: `"running_mean"` (via `zeros_no_train`),
/// `"running_var"` (via `ones_no_train`), `"weight"` (initialized with
/// `config.ws_init`) and `"bias"` (initialized with `config.bs_init`).
/// `nd` is stored on the layer and later used to validate input ranks.
fn batch_norm<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    nd: usize,
    out_dim: i64,
    config: BatchNormConfig,
) -> BatchNorm {
    let path = vs.borrow();
    // Every tensor owned by the layer shares this single-axis shape.
    let shape = [out_dim];

    // Creation order is kept: running statistics first, then weight and bias.
    let running_mean = path.zeros_no_train("running_mean", &shape);
    let running_var = path.ones_no_train("running_var", &shape);
    let ws = path.var("weight", &shape, config.ws_init);
    let bs = path.var("bias", &shape, config.bs_init);

    BatchNorm {
        config,
        running_mean,
        running_var,
        ws,
        bs,
        nd,
    }
}
/// Creates a batch-normalization layer tagged with `nd = 1`.
///
/// `out_dim` is the length of the layer's weight, bias and running-statistic
/// tensors. With `nd = 1` the forward pass accepts input tensors that have
/// either 2 or 3 dimensions and panics otherwise.
pub fn batch_norm1d<'a, T: Borrow<super::Path<'a>>>(
vs: T,
out_dim: i64,
config: BatchNormConfig,
) -> BatchNorm {
batch_norm(vs, 1, out_dim, config)
}
/// Creates a batch-normalization layer tagged with `nd = 2`.
///
/// `out_dim` is the length of the layer's weight, bias and running-statistic
/// tensors. With `nd = 2` the forward pass requires input tensors with
/// exactly 4 dimensions (`nd + 2`) and panics otherwise.
pub fn batch_norm2d<'a, T: Borrow<super::Path<'a>>>(
vs: T,
out_dim: i64,
config: BatchNormConfig,
) -> BatchNorm {
batch_norm(vs, 2, out_dim, config)
}
/// Creates a batch-normalization layer tagged with `nd = 3`.
///
/// `out_dim` is the length of the layer's weight, bias and running-statistic
/// tensors. With `nd = 3` the forward pass requires input tensors with
/// exactly 5 dimensions (`nd + 2`) and panics otherwise.
pub fn batch_norm3d<'a, T: Borrow<super::Path<'a>>>(
vs: T,
out_dim: i64,
config: BatchNormConfig,
) -> BatchNorm {
batch_norm(vs, 3, out_dim, config)
}
impl super::module::ModuleT for BatchNorm {
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
let dim = xs.dim();
if self.nd == 1 && dim != 2 && dim != 3 {
panic!(
"expected an input tensor with 2 or 3 dims, got {:?}",
xs.size()
)
}
if self.nd > 1 && xs.dim() != self.nd + 2 {
panic!(
"expected an input tensor with {} dims, got {:?}",
self.nd + 2,
xs.size()
)
};
Tensor::batch_norm(
xs,
Some(&self.ws),
Some(&self.bs),
Some(&self.running_mean),
Some(&self.running_var),
train,
self.config.momentum,
self.config.eps,
self.config.cudnn_enabled,
)
}
}