Function autograd::ops::batch_norm [−] [src]

pub fn batch_norm(x: &Tensor, scale: &Tensor, shift: &Tensor) -> Tensor

Applies batch normalization.

`scale` and `shift` should be shared variables. Since normalization is performed along the first axis of `x`, both of them should have shape `(1, x.shape[1])`, i.e. one value per column of `x`.

extern crate ndarray;
extern crate autograd as ag;

let mut ctx = ag::Context::new();
let ref x = ag::standard_normal(&[3, 4]);
let ref scale = ag::variable(ag::ndarray_ext::ones(&[1, 4]), &mut ctx);
let ref shift = ag::variable(ag::ndarray_ext::zeros(&[1, 4]), &mut ctx);
let ref norm = ag::batch_norm(x, scale, shift);

assert_eq!(norm.eval(&mut ctx).shape(), &[3, 4]);