concision_core/loss/traits/entropy.rs
1/*
2 Appellation: loss <module>
3 Contrib: @FL03
4*/
5
/// Computes a cross-entropy style loss from the values held by `self`.
///
/// No separate target is passed to [`CrossEntropy::cross_entropy`]; each
/// implementor decides how its own values are interpreted and reduced to
/// the loss it reports.
pub trait CrossEntropy {
    /// Type of the computed loss (for example, a scalar element type).
    type Output;

    /// Reduces `self` to its cross-entropy loss.
    fn cross_entropy(&self) -> Self::Output;
}
12
13/*
14 ************* Implementations *************
15*/
16
17use ndarray::{ArrayBase, Data, Dimension, ScalarOperand};
18use num_traits::{Float, FromPrimitive};
19
20impl<A, S, D> CrossEntropy for ArrayBase<S, D>
21where
22 A: Float + FromPrimitive + ScalarOperand,
23 D: Dimension,
24 S: Data<Elem = A>,
25{
26 type Output = A;
27
28 fn cross_entropy(&self) -> Self::Output {
29 self.mapv(|x| -x.ln()).mean().unwrap()
30 }
31}