// concision_core/traits/entropy.rs

/// A value that can report a cross-entropy figure computed from itself.
///
/// The method takes only `&self`: no separate target or label argument is
/// passed, so whatever the implementor needs must already be contained in the
/// value the method is called on.
pub trait CrossEntropy {
    /// The type of the computed cross-entropy value.
    type Output;

    /// Computes the cross-entropy of `self`.
    fn cross_entropy(&self) -> Self::Output;
}
/// A value that can report its mean absolute error (MAE).
///
/// The method takes only `&self`, so the value it is called on is expected to
/// already hold the quantities being averaged; no target is passed in.
pub trait MeanAbsoluteError {
    /// The type of the computed error value.
    type Output;

    /// Computes the mean absolute error of `self`.
    fn mae(&self) -> Self::Output;
}
/// A value that can report its mean squared error (MSE).
///
/// The method takes only `&self`, so the value it is called on is expected to
/// already hold the quantities being averaged; no target is passed in.
pub trait MeanSquaredError {
    /// The type of the computed error value.
    type Output;

    /// Computes the mean squared error of `self`.
    fn mse(&self) -> Self::Output;
}
24
25use ndarray::{ArrayBase, Data, Dimension, ScalarOperand};
30use num_traits::{Float, FromPrimitive};
31
32impl<A, S, D> CrossEntropy for ArrayBase<S, D>
33where
34    A: Float + FromPrimitive + ScalarOperand,
35    D: Dimension,
36    S: Data<Elem = A>,
37{
38    type Output = A;
39
40    fn cross_entropy(&self) -> Self::Output {
41        self.mapv(|x| -x.ln()).mean().unwrap()
42    }
43}
44
45impl<A, S, D> MeanAbsoluteError for ArrayBase<S, D>
46where
47    A: Float + FromPrimitive + ScalarOperand,
48    D: Dimension,
49    S: Data<Elem = A>,
50{
51    type Output = A;
52
53    fn mae(&self) -> Self::Output {
54        self.abs().mean().unwrap()
55    }
56}
57
58impl<A, S, D> MeanSquaredError for ArrayBase<S, D>
59where
60    A: Float + FromPrimitive + ScalarOperand,
61    D: Dimension,
62    S: Data<Elem = A>,
63{
64    type Output = A;
65
66    fn mse(&self) -> Self::Output {
67        self.pow2().mean().unwrap()
68    }
69}