concision_traits/loss.rs
/*
    Appellation: loss <module>
    Contrib: @FL03
*/
5
/// The [`Loss`] trait defines a common interface for custom loss function implementations.
///
/// An implementor defines how the loss between two values, `lhs` and `rhs`, is calculated.
/// The operands may be of different types, `X` and `Y` respectively, which keeps the trait
/// usable with tensors, scalars, or other data structures while still fixing the "order" in
/// which the operands are supplied. It is most common for `lhs` to be the predicted output
/// and `rhs` to be the actual output, but this is not a strict requirement.
///
/// Both operands are only ever taken by reference, so `X` and `Y` are allowed to be unsized
/// (`?Sized`); this lets implementors target slices such as `[f64]` directly.
///
/// The associated type [`Output`](Loss::Output) is the type of the value returned by
/// [`loss`](Loss::loss), so different loss functions may return scalars, tensors, or any
/// other type that suits the implementation.
pub trait Loss<X: ?Sized, Y: ?Sized> {
    /// The type of the computed loss value.
    type Output;
    /// Compute the loss between two values, `lhs` and `rhs`.
    fn loss(&self, lhs: &X, rhs: &Y) -> Self::Output;
}
22
/// Reduces a collection of values, such as a tensor or array, to its mean absolute error.
///
/// The receiver is the only input: the values being averaged are taken from `self`.
pub trait MeanAbsoluteError {
    /// The type of the value produced by [`mae`](MeanAbsoluteError::mae).
    type Output;

    /// Compute the mean absolute error of `self`.
    fn mae(&self) -> Self::Output;
}
/// Reduces a collection of values, such as a tensor or array, to its mean squared error.
///
/// The receiver is the only input: the values being averaged are taken from `self`.
pub trait MeanSquaredError {
    /// The type of the value produced by [`mse`](MeanSquaredError::mse).
    type Output;

    /// Compute the mean squared error of `self`.
    fn mse(&self) -> Self::Output;
}
35
/*
 ************* Implementations *************
*/
39
40use ndarray::{ArrayBase, Data, Dimension};
41use num_traits::{Float, FromPrimitive};
42
43impl<A, S, D> MeanAbsoluteError for ArrayBase<S, D, A>
44where
45 A: 'static + Float + FromPrimitive,
46 D: Dimension,
47 S: Data<Elem = A>,
48{
49 type Output = A;
50
51 fn mae(&self) -> Self::Output {
52 self.abs().mean().unwrap()
53 }
54}
55
56impl<A, S, D> MeanSquaredError for ArrayBase<S, D, A>
57where
58 A: 'static + Float + FromPrimitive,
59 D: Dimension,
60 S: Data<Elem = A>,
61{
62 type Output = A;
63
64 fn mse(&self) -> Self::Output {
65 self.pow2().mean().unwrap()
66 }
67}