concision_core/func/loss/reg/
avg.rs

1/*
2    Appellation: avg <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::math::{Abs, Squared};
6use nd::prelude::*;
7use nd::{Data, ScalarOperand};
8use num::traits::{FromPrimitive, Num, Pow, Signed};
9
/// Mean absolute error (MAE) of `self` measured against a target value.
///
/// The target type defaults to `Self`, so the common case compares two values of
/// the same type. Implementors choose the result type through [`Self::Output`]
/// (for example an `Option` when the mean may be undefined).
pub trait MeanAbsoluteError<Other = Self> {
    /// Computes the mean absolute error of `self` against `target`.
    fn mae(&self, target: &Other) -> Self::Output;

    /// The value produced by [`mae`](MeanAbsoluteError::mae).
    type Output;
}
15
/// Mean squared error (MSE) of `self` measured against a target value.
///
/// The target type defaults to `Self`, so the common case compares two values of
/// the same type. Implementors choose the result type through [`Self::Output`]
/// (for example an `Option` when the mean may be undefined).
pub trait MeanSquaredError<Other = Self> {
    /// Computes the mean squared error of `self` against `target`.
    fn mse(&self, target: &Other) -> Self::Output;

    /// The value produced by [`mse`](MeanSquaredError::mse).
    type Output;
}
21
// Mean-squared-error loss registration.
//
// NOTE(review): `losses!` is defined elsewhere in the crate. From this invocation
// it appears to declare an `MSE` loss over pairs of `ArrayBase<S, D>` producing
// `Option<A>` and delegating to `MeanSquaredError::mse`. Confirm against the
// macro definition.
// The `where` bounds should stay in sync with the bounds on the
// `MeanSquaredError` impl for `ArrayBase`, which this invocation delegates to.
losses! {
    impl<A, S, D> MSE::<ArrayBase<S, D>, ArrayBase<S, D>, Output = Option<A>>(MeanSquaredError::mse)
    where
        A: FromPrimitive + Num + Pow<i32, Output = A> + ScalarOperand,
        D: Dimension,
        S: Data<Elem = A>,
}
29
// Mean-absolute-error loss registration.
//
// NOTE(review): `losses!` is defined elsewhere in the crate. From this invocation
// it appears to declare an `MAE` loss over pairs of `ArrayBase<S, D>` producing
// `Option<A>` and delegating to `MeanAbsoluteError::mae`. Confirm against the
// macro definition.
// The `where` bounds should stay in sync with the bounds on the
// `MeanAbsoluteError` impl for `ArrayBase`, which this invocation delegates to.
losses! {
    impl<A, S, D> MAE::<ArrayBase<S, D>, ArrayBase<S, D>, Output = Option<A>>(MeanAbsoluteError::mae)
    where
        A: FromPrimitive + Num + ScalarOperand + Signed,
        D: Dimension,
        S: Data<Elem = A>,
}
37
38/*
39 ************* Implementations *************
40*/
41impl<A, S, D> MeanAbsoluteError<ArrayBase<S, D>> for ArrayBase<S, D>
42where
43    A: FromPrimitive + Num + ScalarOperand + Signed,
44    D: Dimension,
45    S: Data<Elem = A>,
46{
47    type Output = Option<A>;
48
49    fn mae(&self, target: &ArrayBase<S, D>) -> Self::Output {
50        (target - self).abs().mean()
51    }
52}
53
54impl<A, S, D> MeanSquaredError<ArrayBase<S, D>> for ArrayBase<S, D>
55where
56    A: FromPrimitive + Num + Pow<i32, Output = A> + ScalarOperand,
57    D: Dimension,
58    S: Data<Elem = A>,
59{
60    type Output = Option<A>;
61
62    fn mse(&self, target: &ArrayBase<S, D>) -> Self::Output {
63        (target - self).sqrd().mean()
64    }
65}