//! Parameter initialization routines for `Learnable` tensors: constant and
//! identity fills, Dirac, uniform, normal and Xavier (Glorot) schemes.

use super::Learnable;
use ndarray::{Dimension, Ix2};
use rand::thread_rng;
use rand_distr::{Distribution, Normal, Uniform};

/// Returns the recommended gain value for the given non-linearity, used to
/// rescale the standard deviation of the initialization schemes below.
/// `leaky_relu` assumes the default negative slope of 0.01.
///
/// Panics if the non-linearity is not supported.
pub fn calculate_gain(non_linearity: &str) -> f32 {
    match non_linearity {
        "linear" | "sigmoid" => 1.0,
        "tanh" => 5.0 / 3.0,
        "relu" => 2.0_f32.sqrt(),
        "leaky_relu" => (2.0 / (1.0 + 0.01_f32.powi(2))).sqrt(),
        _ => panic!("error: unsupported nonlinearity: {}", non_linearity),
    }
}

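// A minimal usage sketch for `calculate_gain` (the weight tensor here is only
// illustrative): the gain widens the sampling interval of the schemes below
// for the chosen non-linearity, e.g.
//
//     let gain = calculate_gain("relu");   // sqrt(2) ≈ 1.414
//     // xavier_uniform(&weight, gain);    // assuming `weight: Learnable<_>`
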
/// Computes the fan-in and fan-out of a parameter, i.e. the number of input
/// and output connections per unit. For parameters with more than two
/// dimensions, the trailing dimensions form the receptive field, whose size
/// is the product of their extents.
pub fn calculate_fan_in_fan_out<D: Dimension>(param: &Learnable<D>) -> (f32, f32) {
    let data = param.data();
    let shape = data.shape();
    assert!(
        shape.len() >= 2,
        "error: fan-in and fan-out cannot be computed for parameters with fewer than 2 dimensions."
    );
    let num_input_fmaps = shape[1];
    let num_output_fmaps = shape[0];
    let (fan_in, fan_out) = if shape.len() > 2 {
        // The receptive field size is the product, not the sum, of the remaining dimensions.
        let receptive_field_size = shape.iter().skip(2).product::<usize>();
        (
            num_input_fmaps * receptive_field_size,
            num_output_fmaps * receptive_field_size,
        )
    } else {
        (num_input_fmaps, num_output_fmaps)
    };
    (fan_in as f32, fan_out as f32)
}

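// Worked example for `calculate_fan_in_fan_out` (hypothetical shapes): a
// convolution weight of shape [out_channels = 16, in_channels = 8, 3, 3] has a
// receptive field of 3 * 3 = 9 elements, so fan_in = 8 * 9 = 72 and
// fan_out = 16 * 9 = 144.
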
/// Fills the parameter with the given constant value.
pub fn constant<D: Dimension>(param: &Learnable<D>, value: f32) {
    param.data_mut().map_inplace(|el| *el = value);
}

/// Fills the parameter with zeros.
pub fn zeros<D: Dimension>(param: &Learnable<D>) {
    param.data_mut().map_inplace(|el| *el = 0.0);
}

/// Fills the parameter with ones.
pub fn ones<D: Dimension>(param: &Learnable<D>) {
    param.data_mut().map_inplace(|el| *el = 1.0);
}

/// Fills the 2-dimensional parameter with the identity matrix: ones on the
/// main diagonal and zeros everywhere else.
pub fn eye(param: &Learnable<Ix2>) {
    for ((x, y), el) in param.data_mut().indexed_iter_mut() {
        *el = if x == y { 1.0 } else { 0.0 };
    }
}

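// E.g. applying `eye` to a (2, 3) parameter yields [[1, 0, 0], [0, 1, 0]]:
// ones on the main diagonal, zeros elsewhere, even for non-square shapes.
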
/// Fills the {3, 4, 5}-dimensional parameter with the Dirac delta function,
/// so that convolutional layers preserve the identity of their inputs. When
/// `groups > 1`, each group of output channels gets its own identity kernel.
pub fn dirac<D: Dimension>(param: &Learnable<D>, groups: usize) {
    let mut data = param.data_mut();
    let shape = data.shape().to_vec();
    let no_dim = shape.len();
    if !(3..=5).contains(&no_dim) {
        panic!("error: only 3, 4 and 5 dimensional parameters are supported.");
    }
    assert_eq!(
        shape[0].rem_euclid(groups),
        0,
        "error: output channels must be divisible by groups."
    );

    // Zero everything first so that only the Dirac entries end up non-zero.
    data.map_inplace(|el| *el = 0.0);

    let out_channels_per_group = shape[0] / groups;
    let min_dim = out_channels_per_group.min(shape[1]);
    for g in 0..groups {
        for d in 0..min_dim {
            // Set a single one at output channel `g * out_channels_per_group + d`,
            // input channel `d`, and the center of every spatial dimension.
            let mut index = D::zeros(no_dim);
            index[0] = g * out_channels_per_group + d;
            index[1] = d;
            index
                .slice_mut()
                .iter_mut()
                .skip(2)
                .zip(shape.iter().skip(2))
                .for_each(|(el, sh)| *el = sh / 2);
            data[index] = 1.0;
        }
    }
}

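// Worked example for `dirac` (hypothetical shape): for a (4, 3, 3) parameter
// with `groups = 2`, out_channels_per_group = 2 and the ones are placed at
// [0, 0, 1], [1, 1, 1], [2, 0, 1] and [3, 1, 1], i.e. the center tap of each
// matched (output, input) channel pair, one pair per group.
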
/// Fills the parameter with elements drawn from the uniform distribution
/// U(low, high). Panics if `low >= high`.
pub fn uniform<D: Dimension>(param: &Learnable<D>, low: f32, high: f32) {
    let unif_dstr = Uniform::new(low, high);
    let mut t_rng = thread_rng();
    param
        .data_mut()
        .map_inplace(|el| *el = unif_dstr.sample(&mut t_rng));
}

/// Fills the parameter with elements drawn from the normal distribution
/// N(mean, std^2). Panics if `std` is negative or NaN.
pub fn normal<D: Dimension>(param: &Learnable<D>, mean: f32, std: f32) {
    let norm_dstr = Normal::new(mean, std).unwrap();
    let mut t_rng = thread_rng();
    param
        .data_mut()
        .map_inplace(|el| *el = norm_dstr.sample(&mut t_rng));
}

/// Fills the parameter with elements sampled from the Xavier (Glorot) uniform
/// distribution U(-a, a) where a = gain * sqrt(6 / (fan_in + fan_out)).
pub fn xavier_uniform<D: Dimension>(param: &Learnable<D>, gain: f32) {
    let (fan_in, fan_out) = calculate_fan_in_fan_out(param);
    let std = gain * (2.0 / (fan_in + fan_out)).sqrt();
    let a = 3.0_f32.sqrt() * std;
    let unif_distr = Uniform::new(-a, a);
    let mut t_rng = thread_rng();
    param
        .data_mut()
        .map_inplace(|el| *el = unif_distr.sample(&mut t_rng));
}

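// Worked example for `xavier_uniform` (hypothetical fans): with gain = 1,
// fan_in = 300 and fan_out = 100, std = sqrt(2 / 400) ≈ 0.0707 and the bound
// is a = sqrt(3) * std ≈ 0.1225, i.e. values are drawn from U(-0.1225, 0.1225).
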
/// Fills the parameter with elements sampled from the Xavier (Glorot) normal
/// distribution N(0, std^2) where std = gain * sqrt(2 / (fan_in + fan_out)).
pub fn xavier_normal<D: Dimension>(param: &Learnable<D>, gain: f32) {
    let (fan_in, fan_out) = calculate_fan_in_fan_out(param);
    let std = gain * (2.0 / (fan_in + fan_out)).sqrt();
    let norm_distr = Normal::new(0.0, std).unwrap();
    let mut t_rng = thread_rng();
    param
        .data_mut()
        .map_inplace(|el| *el = norm_distr.sample(&mut t_rng));
}

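#[cfg(test)]
mod tests {
    // A small sanity check over the pure helper above; the `Learnable`-based
    // initializers are not exercised here since constructing a `Learnable`
    // depends on the surrounding crate.
    use super::calculate_gain;

    #[test]
    fn gain_matches_documented_values() {
        assert!((calculate_gain("linear") - 1.0).abs() < f32::EPSILON);
        assert!((calculate_gain("sigmoid") - 1.0).abs() < f32::EPSILON);
        assert!((calculate_gain("tanh") - 5.0 / 3.0).abs() < f32::EPSILON);
        assert!((calculate_gain("relu") - 2.0_f32.sqrt()).abs() < f32::EPSILON);
    }
}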