// concision_core/loss/standard.rs

/*
    Appellation: loss <module>
    Contrib: @FL03
*/
5
/// Compute the mean absolute error (MAE) of the object.
///
/// `self` is interpreted as the residuals `e_1, ..., e_n` (no separate target
/// is taken, so callers pass e.g. `prediction - target`), and the result is
/// their mean absolute value:
///
/// ```text
/// MAE = (1 / n) * Σ |e_i|
/// ```
pub trait MeanAbsoluteError {
    /// The type of the computed error value.
    type Output;

    /// Returns the mean absolute error of `self`.
    fn mae(&self) -> Self::Output;
}
/// Compute the mean squared error (MSE) of the object.
///
/// `self` is interpreted as the residuals `e_1, ..., e_n` (no separate target
/// is taken, so callers pass e.g. `prediction - target`), and the result is
/// their mean square:
///
/// ```text
/// MSE = (1 / n) * Σ e_i²
/// ```
pub trait MeanSquaredError {
    /// The type of the computed error value.
    type Output;

    /// Returns the mean squared error of `self`.
    fn mse(&self) -> Self::Output;
}
18
/*
 ************* Implementations *************
*/
22
23use ndarray::{ArrayBase, Data, Dimension, ScalarOperand};
24use num_traits::{Float, FromPrimitive};
25
26impl<A, S, D> MeanAbsoluteError for ArrayBase<S, D>
27where
28    A: Float + FromPrimitive + ScalarOperand,
29    D: Dimension,
30    S: Data<Elem = A>,
31{
32    type Output = A;
33
34    fn mae(&self) -> Self::Output {
35        self.abs().mean().unwrap()
36    }
37}
38
39impl<A, S, D> MeanSquaredError for ArrayBase<S, D>
40where
41    A: Float + FromPrimitive + ScalarOperand,
42    D: Dimension,
43    S: Data<Elem = A>,
44{
45    type Output = A;
46
47    fn mse(&self) -> Self::Output {
48        self.pow2().mean().unwrap()
49    }
50}