concision_core/traits/
entropy.rs

1/*
2    Appellation: loss <module>
3    Contrib: @FL03
4*/
5
/// A trait for computing the cross-entropy loss of a tensor or array
///
/// The trait only fixes the call signature: the loss is derived from the
/// receiver alone (no separate target argument is passed), and each
/// implementor chooses the type it reports the loss in.
pub trait CrossEntropy {
    /// The type of the computed loss value.
    type Output;

    /// Computes the cross-entropy loss of `self` without consuming it.
    fn cross_entropy(&self) -> Self::Output;
}
/// A trait for computing the mean absolute error of a tensor or array
///
/// The trait only fixes the call signature: the error is derived from the
/// receiver alone (no separate target argument is passed), and each
/// implementor chooses the type it reports the error in.
pub trait MeanAbsoluteError {
    /// The type of the computed error value.
    type Output;

    /// Computes the mean absolute error of `self` without consuming it.
    fn mae(&self) -> Self::Output;
}
/// A trait for computing the mean squared error of a tensor or array
///
/// The trait only fixes the call signature: the error is derived from the
/// receiver alone (no separate target argument is passed), and each
/// implementor chooses the type it reports the error in.
pub trait MeanSquaredError {
    /// The type of the computed error value.
    type Output;

    /// Computes the mean squared error of `self` without consuming it.
    fn mse(&self) -> Self::Output;
}
24
25/*
26 ************* Implementations *************
27*/
28
29use ndarray::{ArrayBase, Data, Dimension, ScalarOperand};
30use num_traits::{Float, FromPrimitive};
31
32impl<A, S, D> CrossEntropy for ArrayBase<S, D>
33where
34    A: Float + FromPrimitive + ScalarOperand,
35    D: Dimension,
36    S: Data<Elem = A>,
37{
38    type Output = A;
39
40    fn cross_entropy(&self) -> Self::Output {
41        self.mapv(|x| -x.ln()).mean().unwrap()
42    }
43}
44
45impl<A, S, D> MeanAbsoluteError for ArrayBase<S, D>
46where
47    A: Float + FromPrimitive + ScalarOperand,
48    D: Dimension,
49    S: Data<Elem = A>,
50{
51    type Output = A;
52
53    fn mae(&self) -> Self::Output {
54        self.abs().mean().unwrap()
55    }
56}
57
58impl<A, S, D> MeanSquaredError for ArrayBase<S, D>
59where
60    A: Float + FromPrimitive + ScalarOperand,
61    D: Dimension,
62    S: Data<Elem = A>,
63{
64    type Output = A;
65
66    fn mse(&self) -> Self::Output {
67        self.pow2().mean().unwrap()
68    }
69}