// concision_core/loss/traits/standard.rs

/*
    Appellation: loss <module>
    Contrib: @FL03
*/

/// Compute the mean absolute error (MAE) of `self`, interpreted as a collection of residuals.
///
/// Given predicted values $`\hat{y}_i`$ and actual values $`y_i`$, the MAE is the average of
/// the absolute differences between them:
///
/// $$
/// \mathrm{MAE} = \frac{1}{n} \sum_{i=1}^{n} |y_i - \hat{y}_i|
/// $$
///
/// The method only receives `self`, so implementors are expected to hold the differences
/// $`y_i - \hat{y}_i`$ already (e.g. `targets - predictions`). The MAE then reduces to the mean
/// of the absolute value of every element.
pub trait MeanAbsoluteError {
    /// The type of the computed error.
    type Output;

    /// Return the mean absolute error of `self`.
    fn mae(&self) -> Self::Output;
}
/// The [`MeanSquaredError`] (MSE) is the average of the squared differences between the
/// predicted values ($`\hat{y}_i`$) and the actual values ($`y_i`$):
///
/// $$
/// \mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2
/// $$
///
/// The method only receives `self`, so implementors are expected to hold the differences
/// $`y_i - \hat{y}_i`$ already (e.g. `targets - predictions`). The MSE then reduces to the mean
/// of the square of every element.
pub trait MeanSquaredError {
    /// The type of the computed error.
    type Output;

    /// Return the mean squared error of `self`.
    fn mse(&self) -> Self::Output;
}

/*
 ************* Implementations *************
*/

use ndarray::{ArrayBase, Data, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

36impl<A, S, D> MeanAbsoluteError for ArrayBase<S, D>
37where
38    A: Float + FromPrimitive + ScalarOperand,
39    D: Dimension,
40    S: Data<Elem = A>,
41{
42    type Output = A;
43
44    fn mae(&self) -> Self::Output {
45        self.abs().mean().unwrap()
46    }
47}
48
49impl<A, S, D> MeanSquaredError for ArrayBase<S, D>
50where
51    A: Float + FromPrimitive + ScalarOperand,
52    D: Dimension,
53    S: Data<Elem = A>,
54{
55    type Output = A;
56
57    fn mse(&self) -> Self::Output {
58        self.pow2().mean().unwrap()
59    }
60}