use af;
use af::{Array};
use num::Complex;
use utils;
use error::HALError;
/// Hyperbolic tangent activation: thin wrapper over `af::tanh`.
pub fn tanh(x: &Array) -> Array {
af::tanh(x)
}
/// Logistic sigmoid activation: sigmoid(x) = 1 / (1 + e^(-x)).
///
/// All constants are broadcast to `x`'s dims/type via `utils::constant`,
/// and the output type is checked against the input with `assert_types`.
pub fn sigmoid(x: &Array) -> Array {
  let ones = utils::constant(x.dims(), x.get_type(), 1.0f32);
  let minus_ones = utils::constant(x.dims(), x.get_type(), -1.0f32);
  // e^(-x), computed as exp(-1 * x)
  let exp_neg_x = af::exp(&af::mul(&minus_ones, x, false));
  let activated = af::div(&ones, &af::add(&ones, &exp_neg_x, false), false);
  utils::assert_types(vec![x, &activated]);
  activated
}
/// Softmax activation over the entire array: exp(x_i) / sum_j exp(x_j).
///
/// Fixed: the previous version exponentiated `x` directly, which overflows
/// to +inf for even moderately large inputs. We now shift by the maximum
/// first — softmax(x) == softmax(x - max(x)) mathematically, but the shifted
/// form keeps every exponent <= 0 and thus numerically stable.
///
/// NOTE(review): this normalizes over the WHOLE array (sum_all), not
/// per-sample/per-column — preserved from the original; confirm callers
/// pass a single sample at a time.
pub fn softmax(x: &Array) -> Array {
  // Shift by the global max for numerical stability.
  let max_x = af::max_all(x).0 as f32;
  let shift = utils::constant(x.dims(), x.get_type(), max_x);
  let exponentiated = af::exp(&af::sub(x, &shift, false));
  // Normalize by the sum of all exponentials (broadcast as a constant array).
  let sum_exp_x = af::sum_all(&exponentiated).0 as f32;
  let sum_exp_x_vec = utils::constant(x.dims(), x.get_type(), sum_exp_x);
  let a = af::div(&exponentiated, &sum_exp_x_vec, false);
  utils::assert_types(vec![x, &a]);
  a
}
/// Leaky ReLU activation: x for x >= 0, 0.01 * x for x < 0.
pub fn lrelu(x: &Array) -> Array {
  // Pre-scale every element by the leak slope; select() picks which branch
  // (scaled vs. identity) survives per element.
  let slope = utils::constant(x.dims(), x.get_type(), 0.01f32);
  let negative_branch = af::mul(x, &slope, false);
  let is_negative = af::lt(x, &0.0f32, false);
  let activated = af::select(&negative_branch, &is_negative, x);
  utils::assert_types(vec![x, &activated]);
  activated
}
/// Gradient of leaky ReLU: 1 where x > 0, 0.01 otherwise (including x == 0).
pub fn lrelu_derivative(x: &Array) -> Array {
  let leak_slope = utils::constant(x.dims(), x.get_type(), 0.01f32);
  let ones = utils::constant(x.dims(), x.get_type(), 1.0f32);
  let is_positive = af::gt(x, &0.0f32, false);
  let grad = af::select(&ones, &is_positive, &leak_slope);
  utils::assert_types(vec![x, &grad]);
  grad
}
/// ReLU activation: max(0, x), implemented as an element-wise select.
pub fn relu(x: &Array) -> Array {
  let zeros = utils::constant(x.dims(), x.get_type(), 0.0f32);
  let is_negative = af::lt(x, &0.0, false);
  // Zero where negative, pass-through otherwise.
  let activated = af::select(&zeros, &is_negative, x);
  utils::assert_types(vec![x, &activated]);
  activated
}
/// Gradient of ReLU: 1 where x > 0, 0 otherwise (including x == 0).
pub fn relu_derivative(x: &Array) -> Array {
  let ones = utils::constant(x.dims(), x.get_type(), 1.0f32);
  let zeros = utils::constant(x.dims(), x.get_type(), 0.0f32);
  let is_positive = af::gt(x, &0.0f32, false);
  let grad = af::select(&ones, &is_positive, &zeros);
  utils::assert_types(vec![x, &grad]);
  grad
}
/// Gradient of tanh expressed in terms of the activation output:
/// d/dz tanh(z) = 1 - tanh(z)^2, so `x` here is assumed to already be the
/// tanh OUTPUT, not the pre-activation (matches `sigmoid_derivative` below
/// — TODO(review): confirm callers pass activations).
pub fn tanh_derivative(x: &Array) -> Array {
  let ones = utils::constant(x.dims(), x.get_type(), 1.0f32);
  let squared = af::mul(x, x, false);
  let grad = af::sub(&ones, &squared, false);
  utils::assert_types(vec![x, &grad]);
  grad
}
/// Gradient of sigmoid expressed in terms of the activation output:
/// d/dz sigmoid(z) = sigmoid(z) * (1 - sigmoid(z)), so `x` is assumed to
/// already be the sigmoid OUTPUT — TODO(review): confirm callers pass
/// activations rather than pre-activations.
pub fn sigmoid_derivative(x: &Array) -> Array {
  let ones = utils::constant(x.dims(), x.get_type(), 1.0f32);
  let one_minus_x = af::sub(&ones, x, false);
  let grad = af::mul(x, &one_minus_x, false);
  utils::assert_types(vec![x, &grad]);
  grad
}
/// Gradient of softmax, approximated by the diagonal of its Jacobian:
/// s_i * (1 - s_i), which is exactly `sigmoid_derivative`'s formula.
/// NOTE(review): this ignores the off-diagonal Jacobian terms — a common
/// simplification when softmax is paired with cross-entropy loss, but
/// confirm that assumption holds for callers here.
pub fn softmax_derivative(x: &Array) -> Array {
sigmoid_derivative(x)
}
/// Linear (identity) activation: returns a copy of the input unchanged.
pub fn ones(x: &Array) -> Array {
x.clone()
}
/// Gradient of the linear/identity activation: 1 everywhere.
pub fn ones_derivative(x: &Array) -> Array {
  let unit_grad = utils::constant(x.dims(), x.get_type(), 1.0f32);
  utils::assert_types(vec![x, &unit_grad]);
  unit_grad
}
/// Returns true when the named activation is smooth (differentiable
/// everywhere); ReLU variants are not smooth at 0.
///
/// # Panics
/// Panics on an unrecognized activation name.
pub fn is_smooth(name: &str) -> bool {
  match name {
    "softmax" | "sigmoid" | "tanh" | "ones" | "linear" => true,
    "relu" | "lrelu" => false,
    _ => panic!("unknown function name provided"),
  }
}
pub fn get_activation(name: &str, x: &Array) -> Result<Array, HALError> {
match name {
"softmax" => Ok(softmax(x)),
"sigmoid" => Ok(sigmoid(x)),
"relu" => Ok(relu(x)),
"lrelu" => Ok(lrelu(x)),
"tanh" => Ok(tanh(x)),
"ones" => Ok(ones(x)),
"linear" => Ok(ones(x)),
_ => Err(HALError::UNKNOWN),
}
}
pub fn get_derivative(name: &str, x: &Array) -> Result<Array, HALError> {
match name {
"softmax" => Ok(softmax_derivative(x)),
"sigmoid" => Ok(sigmoid_derivative(x)),
"relu" => Ok(relu_derivative(x)),
"lrelu" => Ok(lrelu_derivative(x)),
"tanh" => Ok(tanh_derivative(x)),
"ones" => Ok(ones_derivative(x)),
"linear" => Ok(ones_derivative(x)),
_ => Err(HALError::UNKNOWN),
}
}