use af;
use af::Array;
use error::HALError;
/// Element-wise squared error: `(pred - target)^2` for every element.
///
/// The residual is computed once and multiplied by itself, rather than
/// going through a general power function.
pub fn l2_vec(pred: &Array, target: &Array) -> Array {
    let residual = af::sub(pred, target, false);
    af::mul(&residual, &residual, false)
}
/// Element-wise half squared error: `0.5 * (pred - target)^2`.
///
/// The `0.5` factor makes the per-element derivative simply
/// `pred - target` (see `mse_derivative`).
pub fn mse_vec(pred: &Array, target: &Array) -> Array {
    let squared_error = l2_vec(pred, target);
    af::mul(&squared_error, &0.5f32, false)
}
/// Element-wise binary cross-entropy between `pred` and `target`.
///
/// For every element this computes
/// `-target * ln(pred) - (1 - target) * ln(1 - pred)`.
///
/// `pred` is first clamped to `[EPSILON, 1 - EPSILON]`. Without the clamp,
/// a saturated prediction (exactly `0.0` or `1.0`) turns one of the terms
/// into `0 * ln(0) = 0 * -inf = NaN` (IEEE `ln(0) = -inf`). That happens
/// even when the prediction matches the target, and the NaN then poisons
/// any reduction over the result, such as `cross_entropy`. For predictions
/// already inside the clamp range the result is unchanged.
pub fn cross_entropy_vec(pred: &Array, target: &Array) -> Array {
    // Small enough not to disturb ordinary probabilities, yet large enough
    // that `1 - EPSILON` is still strictly below 1.0 in f32.
    const EPSILON: f32 = 1e-7;
    let clamped = af::minof(&af::maxof(pred, &EPSILON, false), &(1.0f32 - EPSILON), false);

    // -target * ln(pred)
    let pos = af::mul(&af::mul(&-1.0, target, false), &af::log(&clamped), false);
    // (1 - target) * ln(1 - pred)
    let neg = af::mul(
        &af::sub(&1.0, target, false),
        &af::log(&af::sub(&1.0, &clamped, false)),
        false,
    );
    af::sub(&pos, &neg, false)
}
/// Total squared error: the sum over all elements of `(pred - target)^2`.
///
/// Only the first (real) component of `af::sum_all` is used, narrowed to `f32`.
pub fn l2(pred: &Array, target: &Array) -> f32 {
    let total = af::sum_all(&l2_vec(pred, target)).0;
    total as f32
}
/// Half mean squared error: `0.5 * mean((pred - target)^2)` over all elements.
///
/// The mean is narrowed to `f32` before it is halved.
pub fn mse(pred: &Array, target: &Array) -> f32 {
    let mean_squared = af::mean_all(&l2_vec(pred, target)).0 as f32;
    0.5f32 * mean_squared
}
/// Total binary cross-entropy: the sum of `cross_entropy_vec` over all elements.
///
/// Only the first (real) component of `af::sum_all` is used, narrowed to `f32`.
pub fn cross_entropy(pred: &Array, target: &Array) -> f32 {
    let total = af::sum_all(&cross_entropy_vec(pred, target)).0;
    total as f32
}
/// Element-wise derivative of `mse_vec` with respect to `pred`: `pred - target`.
///
/// This is the per-element gradient of `0.5 * (pred - target)^2`. It does
/// not include the `1 / N` factor that the averaging in `mse` would add.
pub fn mse_derivative(pred: &Array, target: &Array) -> Array {
    af::sub(pred, target, false)
}
/// Element-wise derivative of `l2_vec` with respect to `pred`: `2 * (pred - target)`.
pub fn l2_derivative(pred: &Array, target: &Array) -> Array {
    let residual = af::sub(pred, target, false);
    af::mul(&residual, &2.0f32, false)
}
pub fn cross_entropy_derivative(pred: &Array, target: &Array) -> Array {
mse_derivative(pred, target)
}
pub fn get_loss(name: &str, pred: &Array, target: &Array) -> Result<f32, HALError> {
match name {
"l2" => Ok(l2(pred, target)),
"mse" => Ok(mse(pred, target)),
"cross_entropy" => Ok(cross_entropy(pred, target)),
_ => Err(HALError::UNKNOWN),
}
}
pub fn get_loss_vec(name: &str, pred: &Array, target: &Array) -> Result<Array, HALError> {
match name {
"l2" => Ok(l2_vec(pred, target)),
"mse" => Ok(mse_vec(pred, target)),
"cross_entropy" => Ok(cross_entropy_vec(pred, target)),
_ => Err(HALError::UNKNOWN),
}
}
pub fn get_loss_derivative(name: &str, pred: &Array, target: &Array) -> Result<Array, HALError> {
match name {
"l2" => Ok(l2_derivative(pred, target)),
"mse" => Ok(mse_derivative(pred, target)),
"cross_entropy" => Ok(cross_entropy_derivative(pred, target)),
_ => Err(HALError::UNKNOWN),
}
}