use ndarray::{ArrayView2, Axis};
use serde_derive::{Deserialize, Serialize};
use std::f32::MIN;
/// Identifier for a cost (loss / metric) function.
///
/// Can be built from a case-insensitive name via `From<String>` and is
/// (de)serialisable with serde. The default is [`CostFunc::MSE`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CostFunc {
    /// Selected by the name `"mse"`.
    MSE,
    /// Selected by the name `"mae"`.
    MAE,
    /// Selected by the name `"accuracy"`.
    Accuracy,
    /// Selected by the name `"crossentropy"`.
    CrossEntropy,
}

impl From<String> for CostFunc {
    /// Maps a cost-function name to its variant, ignoring letter case.
    ///
    /// # Panics
    ///
    /// Panics with `Unknown cost function: <name>` if `name` (lower-cased) is
    /// not one of `mse`, `mae`, `accuracy` or `crossentropy`.
    /// NOTE(review): a non-panicking `TryFrom`/`FromStr` would suit config
    /// input better; kept as `From` so existing callers keep compiling.
    fn from(name: String) -> Self {
        match name.to_lowercase().as_str() {
            "mse" => CostFunc::MSE,
            "mae" => CostFunc::MAE,
            "accuracy" => CostFunc::Accuracy,
            "crossentropy" => CostFunc::CrossEntropy,
            _ => panic!("Unknown cost function: {}", name),
        }
    }
}

impl Default for CostFunc {
    /// Mean squared error (`CostFunc::MSE`) is the default.
    fn default() -> Self {
        CostFunc::MSE
    }
}
/// Mean categorical cross-entropy between `y_true` and `y_hat`.
///
/// Steps:
/// 1. Each prediction is floored at `EPSILON`, so `ln` never sees zero or a
///    negative value.
/// 2. Every row of `y_hat` is renormalised to sum to 1.
/// 3. The result is `-Σ y_true · ln(ŷ)` over all elements, divided by the
///    number of rows.
///
/// `y_true` is presumably one-hot or a per-row probability distribution
/// (confirm at the call site). Both inputs are expected to have the same
/// shape. Otherwise ndarray's broadcasting rules apply, and incompatible
/// shapes panic.
///
/// Returns NaN for an input with zero rows. A NaN prediction propagates to
/// the result rather than being masked.
pub fn cross_entropy(y_true: ArrayView2<f32>, y_hat: ArrayView2<f32>) -> f32 {
    // Smallest value fed to `ln`. Only a lower bound is needed: the row
    // renormalisation below already maps every value into (0, 1].
    const EPSILON: f32 = 1e-15;
    let n_rows = y_hat.len_of(Axis(0));

    // Explicit comparison (instead of `f32::max`) keeps NaN as NaN.
    let clipped = y_hat.mapv(|v| if v < EPSILON { EPSILON } else { v });
    // Row sums as an (n, 1) column so they broadcast across each row.
    let row_sums = clipped.sum_axis(Axis(1)).insert_axis(Axis(1));
    let log_probs = (&clipped / &row_sums).mapv_into(f32::ln);

    -(&y_true * &log_probs).sum() / n_rows as f32
}
/// Binary cross-entropy (log loss) for one target/prediction pair.
///
/// Computes `-(y_true · ln(p) + (1 - y_true) · ln(1 - p))`, where `p` is
/// `y_hat` clamped to `[EPSILON, 1 - EPSILON]` so that neither logarithm
/// sees zero.
///
/// `EPSILON` is `f32::EPSILON` rather than something like `1e-15` because
/// `1.0 - 1e-15` rounds to exactly `1.0` in `f32`. That would make
/// `ln(1 - p)` infinite, and `0 · -inf` yields NaN.
///
/// A NaN `y_hat` is not clamped and propagates to the result.
pub fn single_cross_entropy(y_true: f32, y_hat: f32) -> f32 {
    const EPSILON: f32 = std::f32::EPSILON;
    let p = if y_hat < EPSILON {
        EPSILON
    } else if y_hat > 1.0 - EPSILON {
        1.0 - EPSILON
    } else {
        y_hat
    };
    -(y_true * p.ln() + (1.0 - y_true) * (1.0 - p).ln())
}
pub fn accuracy_score(y_true: ArrayView2<f32>, y_hat: ArrayView2<f32>) -> f32 {
y_true
.outer_iter()
.zip(y_hat.outer_iter())
.map(|(yt, yh)| {
if yt.len() > 1 {
let (ytrue_argmax, _max) =
yt.iter()
.enumerate()
.fold((None, MIN), |(idx, acc), (i, x)| {
if x > &acc {
(Some(i), *x)
} else {
(idx, acc)
}
});
let (yhat_argmax, _max) =
yh.iter()
.enumerate()
.fold((None, MIN), |(idx, acc), (i, x)| {
if x > &acc {
(Some(i), *x)
} else {
(idx, acc)
}
});
accuracy(
ytrue_argmax.unwrap_or(0) as f32,
yhat_argmax.unwrap_or(0) as f32,
)
} else {
accuracy(yt[0], yh[0])
}
})
.sum::<f32>()
/ y_true.rows() as f32
}
/// Scores one prediction against its target.
///
/// Returns `1.0` when the values are exactly equal and `0.0` otherwise. A NaN
/// on either side therefore always scores `0.0`.
pub fn accuracy(y_true: f32, y_hat: f32) -> f32 {
    let is_hit = y_true == y_hat;
    f32::from(u8::from(is_hit))
}
pub fn mean_squared_error(y_true: ArrayView2<f32>, y_hat: ArrayView2<f32>) -> f32 {
y_true
.iter()
.zip(y_hat.iter())
.map(|(yt, yh)| squared_error(*yt, *yh))
.sum::<f32>()
/ y_true.rows() as f32
}
/// Squared difference `(y_true - y_hat)²` for one target/prediction pair.
pub fn squared_error(y_true: f32, y_hat: f32) -> f32 {
    let residual = y_true - y_hat;
    residual.powf(2.0)
}
pub fn mean_absolute_error(y_true: ArrayView2<f32>, y_hat: ArrayView2<f32>) -> f32 {
y_true
.iter()
.zip(y_hat.iter())
.map(|(yt, yh)| absolute_error(*yt, *yh))
.sum::<f32>()
/ y_true.rows() as f32
}
/// Absolute difference `|y_true - y_hat|` for one target/prediction pair.
pub fn absolute_error(y_true: f32, y_hat: f32) -> f32 {
    let gap = y_hat - y_true;
    gap.abs()
}