use std::fmt::Debug;
use dyn_clone::DynClone;
use ndarray::{ArrayD, ArrayViewD};
use serde::{Deserialize, Serialize};
use crate::{
core::{MininnError, NNResult},
registers::REGISTER,
};
use super::NNUtil;
/// Numerical core of a cost (loss) function.
pub trait CostCore {
    /// Scalar cost of the predictions `y_p` measured against the targets `y`.
    fn function(&self, y_p: &ArrayViewD<f32>, y: &ArrayViewD<f32>) -> f32;

    /// Gradient signal of the cost for the predictions `y_p` given the
    /// targets `y`, returned as an owned array.
    fn derivate(&self, y_p: &ArrayViewD<f32>, y: &ArrayViewD<f32>) -> ArrayD<f32>;
}

/// A complete cost function usable as `Box<dyn CostFunction>`: it has a name
/// ([`NNUtil`]), a numerical implementation ([`CostCore`]), and is both
/// printable (`Debug`) and clonable through a trait object (`DynClone`).
pub trait CostFunction: CostCore + NNUtil + DynClone + Debug {}

// Generates `Clone for Box<dyn CostFunction>` on top of the `DynClone` supertrait.
dyn_clone::clone_trait_object!(CostFunction);
/// Boxed cost functions compare equal exactly when their `name()` strings match;
/// no other state is inspected.
impl PartialEq for Box<dyn CostFunction> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.name(), other.name());
        lhs == rhs
    }
}

// String equality is a full equivalence relation, so `Eq` is sound here.
impl Eq for Box<dyn CostFunction> {}
impl Serialize for Box<dyn CostFunction> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(self.name())
}
}
impl<'de> Deserialize<'de> for Box<dyn CostFunction> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let cost: String = Deserialize::deserialize(deserializer)?;
let cost = REGISTER.with(|register| {
register.borrow_mut().create_cost(&cost).map_err(|err| {
serde::de::Error::custom(format!(
"Failed to create cost function '{}': {}",
cost, err
))
})
});
cost
}
}
/// Built-in cost functions, each identified by the upper-case name returned
/// from [`NNUtil::name`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Cost {
    /// Mean squared error: mean of `(y - y_p)^2`.
    MSE,
    /// Mean absolute error: mean of `|y - y_p|`.
    MAE,
    /// Binary cross-entropy, summed (not averaged) over all elements.
    BCE,
    /// Categorical cross-entropy, summed (not averaged) over all elements.
    CCE,
}

// Marker impl: lets the built-in costs be stored as `Box<dyn CostFunction>`.
impl CostFunction for Cost {}
impl NNUtil for Cost {
    /// Returns the variant's canonical identifier (`"MSE"`, `"MAE"`, `"BCE"`
    /// or `"CCE"`).
    #[inline]
    fn name(&self) -> &str {
        match *self {
            Self::MSE => "MSE",
            Self::MAE => "MAE",
            Self::BCE => "BCE",
            Self::CCE => "CCE",
        }
    }

    /// Parses a cost from its identifier; the exact inverse of `name`.
    ///
    /// # Errors
    /// Returns `MininnError::CostError` when `name` matches no variant. The
    /// comparison is case-sensitive.
    #[inline]
    fn from_name(name: &str) -> NNResult<Box<Self>>
    where
        Self: Sized,
    {
        // Derive the lookup from `name()` so the two directions cannot drift apart.
        const VARIANTS: [Cost; 4] = [Cost::MSE, Cost::MAE, Cost::BCE, Cost::CCE];
        VARIANTS
            .iter()
            .copied()
            .find(|candidate| candidate.name() == name)
            .map(Box::new)
            .ok_or_else(|| {
                MininnError::CostError("The cost function is not supported".to_string())
            })
    }
}
/// Smallest probability fed to `ln` by the cross-entropy costs. Without it, a
/// saturated prediction (exactly `0`, or exactly `1` for BCE) yields
/// `0 * ln(0) = 0 * -inf = NaN` or an infinite cost.
const PROB_EPSILON: f32 = 1e-7;

impl CostCore for Cost {
    /// Computes the scalar cost of predictions `y_p` against targets `y`.
    ///
    /// * `MSE` — mean of `(y - y_p)^2`; `0` for empty input.
    /// * `MAE` — mean of `|y - y_p|`; `0` for empty input.
    /// * `BCE` — `-sum(y * ln(p) + (1 - y) * ln(1 - p))`, summed over all elements.
    /// * `CCE` — `-sum(y * ln(p))`, summed over all elements.
    ///
    /// For the cross-entropy variants `p` is `y_p` clamped to
    /// `[PROB_EPSILON, 1 - PROB_EPSILON]` (BCE) or bounded below by
    /// `PROB_EPSILON` (CCE). Predictions already inside that range are used
    /// unchanged. NaN inputs still propagate as NaN.
    ///
    /// Shapes must be broadcast-compatible; ndarray's arithmetic operators
    /// panic otherwise.
    #[inline]
    fn function(&self, y_p: &ArrayViewD<f32>, y: &ArrayViewD<f32>) -> f32 {
        match self {
            Cost::MSE => (y - y_p).pow2().mean().unwrap_or(0.),
            Cost::MAE => (y - y_p).abs().mean().unwrap_or(0.),
            Cost::BCE => {
                // Both ln(p) and ln(1 - p) appear, so bound p on both sides.
                let p = y_p.mapv(|v| v.clamp(PROB_EPSILON, 1. - PROB_EPSILON));
                -((y * p.ln() + (1. - y) * (1. - &p).ln()).sum())
            }
            Cost::CCE => {
                // Only ln(p) appears, so only a lower bound is needed. An explicit
                // comparison (rather than `f32::max`) keeps NaN visible.
                let p = y_p.mapv(|v| if v < PROB_EPSILON { PROB_EPSILON } else { v });
                -(y * p.ln()).sum()
            }
        }
    }

    /// Returns the gradient signal for `y_p`, element-wise, where `n` is `y.len()`.
    ///
    /// * `MSE` — `2 * (y_p - y) / n`.
    /// * `MAE` — `sign(y_p - y) / n`, with `sign(0) = 0` as the subgradient at
    ///   an exact match. Plain `f32::signum` would return `1.0` for `+0.0` and
    ///   keep pushing a correct prediction.
    /// * `BCE` / `CCE` — `y_p - y`. This is NOT the derivative of the
    ///   expressions in [`CostCore::function`] with respect to `y_p`. It is the
    ///   gradient with respect to the pre-activation when `y_p` is a sigmoid
    ///   (BCE) or softmax (CCE) output.
    ///   NOTE(review): confirm the output layer does not also apply its
    ///   activation derivative.
    #[inline]
    fn derivate(&self, y_p: &ArrayViewD<f32>, y: &ArrayViewD<f32>) -> ArrayD<f32> {
        match self {
            Cost::MSE => 2.0 * (y_p - y) / y.len() as f32,
            Cost::MAE => {
                (y_p - y).mapv(|d| if d == 0. { 0. } else { d.signum() }) / y.len() as f32
            }
            Cost::BCE | Cost::CCE => y_p - y,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mininn_derive::CostFunction;
    use ndarray::array;

    #[test]
    fn test_cost_name() {
        let cost = Cost::MSE;
        assert_eq!(cost.name(), "MSE");
    }

    #[test]
    fn test_cost_from_name_roundtrip() {
        // `from_name` must invert `name` for every built-in variant.
        for cost in [Cost::MSE, Cost::MAE, Cost::BCE, Cost::CCE] {
            let parsed = Cost::from_name(cost.name()).ok().map(|boxed| *boxed);
            assert_eq!(parsed, Some(cost));
        }
        // Lookup is case-sensitive and rejects unknown names.
        assert!(Cost::from_name("mse").is_err());
        assert!(Cost::from_name("").is_err());
    }

    #[test]
    fn test_mse_function() {
        let y_p = array![0.1, 0.4, 0.6].into_dyn();
        let y = array![0.0, 0.5, 1.0].into_dyn();
        let cost = Cost::MSE;
        let result = cost.function(&y_p.view(), &y.view());
        assert_eq!(result, 0.05999999);
    }

    #[test]
    fn test_mae_function() {
        let y_p = array![0.1, 0.4, 0.6].into_dyn();
        let y = array![0.0, 0.5, 1.0].into_dyn();
        let cost = Cost::MAE;
        let result = cost.function(&y_p.view(), &y.view());
        assert_eq!(result, 0.19999999);
    }

    #[test]
    fn test_bce_function() {
        let y_p = array![0.07, 0.91, 0.74, 0.23, 0.85, 0.17, 0.94].into_dyn();
        let y = array![0., 1., 1., 0., 0., 1., 1.].into_dyn();
        let cost = Cost::BCE;
        let result = cost.function(&y_p.view(), &y.view());
        assert_eq!(result, 4.460303459760249);
    }

    #[test]
    fn test_cce_function() {
        let y_p = array![0.7, 0.2, 0.1].into_dyn();
        let y = array![1., 0., 0.].into_dyn();
        let cost = Cost::CCE;
        let result = cost.function(&y_p.view(), &y.view());
        // Only the true class contributes: -ln(0.7).
        assert!((result - 0.356_674_94).abs() < 1e-6);
    }

    #[test]
    fn test_mse_derivate() {
        let y_p = array![0.1, 0.4, 0.6].into_dyn();
        let y = array![0.0, 0.5, 1.0].into_dyn();
        let cost = Cost::MSE;
        let result = cost.derivate(&y_p.view(), &y.view());
        let expected = array![0.066666667, -0.066666666, -0.266666665].into_dyn();
        assert_eq!(result.mapv(|v| v), expected.mapv(|v| v));
    }

    #[test]
    fn test_mae_derivate() {
        let y_p = array![0.1, 0.4, 0.6].into_dyn();
        let y = array![0.0, 0.5, 1.0].into_dyn();
        let cost = Cost::MAE;
        let result = cost.derivate(&y_p.view(), &y.view());
        let expected = array![0.33333334, -0.33333334, -0.33333334].into_dyn();
        assert_eq!(result.mapv(|v| v), expected.mapv(|v| v));
    }

    #[test]
    fn test_bce_derivate() {
        let y_p = array![0.9, 0.1, 0.8, 0.2].into_dyn();
        let y = array![1., 0., 1., 0.].into_dyn();
        let cost = Cost::BCE;
        let result = cost.derivate(&y_p.view(), &y.view());
        let expected = array![-0.100000024, 0.1, -0.19999999, 0.2].into_dyn();
        assert_eq!(result.mapv(|v| v), expected.mapv(|v| v));
    }

    #[test]
    fn test_cce_derivate() {
        let y_p = array![0.9, 0.1, 0.8, 0.2].into_dyn();
        let y = array![1., 0., 1., 0.].into_dyn();
        // Previously this exercised `Cost::BCE` by mistake.
        let cost = Cost::CCE;
        let result = cost.derivate(&y_p.view(), &y.view());
        let expected = array![-0.100000024, 0.1, -0.19999999, 0.2].into_dyn();
        assert_eq!(result.mapv(|v| v), expected.mapv(|v| v));
    }

    #[test]
    fn test_custom_cost() {
        #[derive(CostFunction, Debug, Clone)]
        struct CustomCost;

        impl CostCore for CustomCost {
            fn function(&self, y_p: &ArrayViewD<f32>, y: &ArrayViewD<f32>) -> f32 {
                (y - y_p).abs().mean().unwrap_or(0.)
            }

            fn derivate(&self, y_p: &ArrayViewD<f32>, y: &ArrayViewD<f32>) -> ArrayD<f32> {
                (y_p - y).signum() / y.len() as f32
            }
        }

        let y_p = array![0.1, 0.4, 0.6].into_dyn();
        let y = array![0.0, 0.5, 1.0].into_dyn();
        let cost = CustomCost;
        assert_eq!(cost.name(), "CustomCost");
        let result = cost.function(&y_p.view(), &y.view());
        assert_eq!(result, 0.19999999);
        let result = cost.derivate(&y_p.view(), &y.view());
        let expected = array![0.33333334, -0.33333334, -0.33333334];
        assert_eq!(result.mapv(|v| v), expected.into_dyn().mapv(|v| v));
    }
}