use af;
use rand;
use rand::Rng;
use af::{Dim4, Array, HasAfEnum};
use utils;
use error::HALError;
/// Computes the `(fan_in, fan_out)` pair for a weight tensor of shape `dims`.
///
/// * For a rank-2 shape, `fan_in` is `dims[0]` and `fan_out` is `dims[1]`.
/// * For any other rank, `fan_in` is the product of the extents at indices
///   `1..ndims` and `fan_out` is `dims[0]`.
///
/// Both values are returned as `f32` so they can be used directly in the
/// scale formulas of the initializers.
pub fn get_fans(dims: Dim4) -> (f32, f32) {
    let ndims = dims.ndims();
    let fan_in = match ndims {
        2 => dims[0],
        // `take`/`skip` rather than slicing `[1..ndims]`, so a rank below 1
        // cannot trigger an out-of-range slice panic; an empty range
        // multiplies out to 1.
        _ => dims.get().iter().take(ndims).skip(1).product::<u64>(),
    };
    // Select on the rank, not on the extent `dims[1]`: a rank-2 shape maps
    // its second extent to fan_out, every other rank uses the first extent.
    let fan_out = match ndims {
        2 => dims[1],
        _ => dims[0],
    };
    (fan_in as f32, fan_out as f32)
}
/// Draws an array of shape `dims` from `af::randn::<T>` and multiplies every
/// element by `scale`.
///
/// The ArrayFire generator is reseeded on each call with a fresh `u64` taken
/// from the thread-local `rand` generator.
///
/// # Panics
///
/// Panics if the dtype of the scaled result differs from the dtype of `T`.
pub fn normal<T: HasAfEnum>(dims: Dim4, scale: f32) -> Array {
    af::set_seed(rand::thread_rng().gen::<u64>());

    let expected = T::get_af_dtype();
    let factor = utils::constant(dims, expected, scale);
    let samples = af::randn::<T>(dims);
    let scaled = af::mul(&samples, &factor, false);

    // Guard against the multiplication silently promoting the element type.
    let actual = scaled.get_type();
    assert!(
        expected == actual,
        "type mismatch detected in normal, {:?} vs {:?}",
        expected,
        actual
    );
    scaled
}
/// Draws an array of shape `dims` uniformly from the symmetric interval
/// `[-scale, scale)`.
///
/// The samples come from `af::randu::<T>`, are stretched by `2 * scale` and
/// then shifted down by `scale`. The ArrayFire generator is reseeded on each
/// call with a fresh `u64` taken from the thread-local `rand` generator.
///
/// # Panics
///
/// Panics if the dtype of the result differs from the dtype of `T`.
pub fn uniform<T: HasAfEnum>(dims: Dim4, scale: f32) -> Array {
    let mut rng = rand::thread_rng();
    af::set_seed(rng.gen::<u64>());
    let src_type = T::get_af_dtype();

    // The interval width is 2 * scale. Multiplying by `scale` alone and then
    // subtracting `scale` would confine the samples to [-scale, 0), leaving
    // every weight negative.
    let width_vec = utils::constant(dims, src_type, 2.0f32 * scale);
    let stretched = af::mul(&af::randu::<T>(dims), &width_vec, false);
    let u = af::sub(&stretched, &scale, false);

    // Guard against the arithmetic silently promoting the element type.
    let dst_type = u.get_type();
    assert!(
        src_type == dst_type,
        "type mismatch detected in uniform, {:?} vs {:?}",
        src_type,
        dst_type
    );
    u
}
/// Builds an array of shape `dims` with the dtype of `T`, filled with `0.0`.
pub fn zeros<T: HasAfEnum>(dims: Dim4) -> Array {
    let dtype = T::get_af_dtype();
    utils::constant(dims, dtype, 0.0f32)
}
/// Builds an array of shape `dims` with the dtype of `T`, filled with `1.0`.
pub fn ones<T: HasAfEnum>(dims: Dim4) -> Array {
    let dtype = T::get_af_dtype();
    utils::constant(dims, dtype, 1.0f32)
}
/// Glorot-style uniform initializer.
///
/// Delegates to `uniform` with a scale of `sqrt(6 / (fan_in + fan_out))`,
/// where the fans are obtained from `get_fans(dims)`.
pub fn glorot_uniform<T: HasAfEnum>(dims: Dim4) -> Array {
    let (fan_in, fan_out) = get_fans(dims);
    let fan_total = fan_in + fan_out;
    let limit = f32::sqrt(6.0f32 / fan_total);
    uniform::<T>(dims, limit)
}
/// Glorot-style normal initializer.
///
/// Delegates to `normal` with a scale of `sqrt(2 / (fan_in + fan_out))`,
/// where the fans are obtained from `get_fans(dims)`.
pub fn glorot_normal<T: HasAfEnum>(dims: Dim4) -> Array {
    let (fan_in, fan_out) = get_fans(dims);
    let fan_total = fan_in + fan_out;
    let std_scale = f32::sqrt(2.0f32 / fan_total);
    normal::<T>(dims, std_scale)
}
/// LeCun-style uniform initializer.
///
/// Delegates to `uniform` with a scale of `sqrt(3 / fan_in)`, where `fan_in`
/// is obtained from `get_fans(dims)`; `fan_out` is not used.
pub fn lecun_uniform<T: HasAfEnum>(dims: Dim4) -> Array {
    let (fan_in, _) = get_fans(dims);
    // The limit is the square root of 3 / fan_in; without the root the scale
    // shrinks far too quickly as fan_in grows.
    let s = (3.0f32 / fan_in).sqrt();
    uniform::<T>(dims, s)
}
/// Looks up an initializer by `name` and runs it for shape `dims`.
///
/// Recognised names are `"glorot_uniform"`, `"glorot_normal"`,
/// `"lecun_uniform"`, `"normal"`, `"uniform"`, `"zeros"` and `"ones"`.
/// `"normal"` and `"uniform"` are invoked with a fixed scale of `0.05`.
///
/// # Errors
///
/// Returns `HALError::UNKNOWN` when `name` matches none of the above.
pub fn get_initialization<T: HasAfEnum>(name: &str, dims: Dim4) -> Result<Array, HALError> {
    let initialized = match name {
        "glorot_uniform" => glorot_uniform::<T>(dims),
        "glorot_normal" => glorot_normal::<T>(dims),
        "lecun_uniform" => lecun_uniform::<T>(dims),
        "normal" => normal::<T>(dims, 0.05f32),
        "uniform" => uniform::<T>(dims, 0.05f32),
        "zeros" => zeros::<T>(dims),
        "ones" => ones::<T>(dims),
        _ => return Err(HALError::UNKNOWN),
    };
    Ok(initialized)
}