concision_core/loss/traits/
entropy.rs

/*
    Appellation: loss <module>
    Contrib: @FL03
*/
5
/// A trait for computing the cross-entropy loss of a tensor or array.
///
/// The receiver holds the values the loss is computed from. How those values are
/// interpreted (e.g. probabilities vs. logits) and how they are reduced to
/// [`Output`](CrossEntropy::Output) is left to each implementation.
pub trait CrossEntropy {
    /// The type of the computed loss value.
    type Output;

    /// Computes the cross-entropy loss of `self`.
    fn cross_entropy(&self) -> Self::Output;
}
12
/*
 ************* Implementations *************
*/
16
17use ndarray::{ArrayBase, Data, Dimension, ScalarOperand};
18use num_traits::{Float, FromPrimitive};
19
20impl<A, S, D> CrossEntropy for ArrayBase<S, D>
21where
22    A: Float + FromPrimitive + ScalarOperand,
23    D: Dimension,
24    S: Data<Elem = A>,
25{
26    type Output = A;
27
28    fn cross_entropy(&self) -> Self::Output {
29        self.mapv(|x| -x.ln()).mean().unwrap()
30    }
31}