1use crate::{AdditionalOps, BaseOps, ClipOp, FnsOps, Matrix, SumOps, SumOverOps};
2use custos::{number::Float, Device, Shape};
3
4pub trait CCE<T> {
5 fn cce(&self, targets: &Matrix<T>) -> (T, Matrix<T>);
6}
7
9
10pub trait CCEOp<T, S: Shape = (), D = Self>: Device
11where
12 D: Device,
13{
14 #[inline]
15 fn cce<'a>(
16 &self,
17 preds: &Matrix<'a, T, D, S>,
18 targets: &Matrix<'a, T, D, S>,
19 ) -> (T, Matrix<'a, T, Self, S>) {
20 (self.cce_loss(preds, targets), self.cce_grad(preds, targets))
21 }
22 fn cce_loss(&self, preds: &Matrix<T, D, S>, targets: &Matrix<T, D, S>) -> T;
23 fn cce_grad<'a>(
24 &self,
25 preds: &Matrix<'a, T, D, S>,
26 targets: &Matrix<'a, T, D, S>,
27 ) -> Matrix<'a, T, Self, S>;
28}
29
30impl<'a, T, S: Shape, D: CCEOp<T, S>> Matrix<'a, T, D, S> {
31 #[inline]
32 pub fn cce(&self, targets: &Matrix<'a, T, D, S>) -> (T, Matrix<'a, T, D, S>) {
33 self.device().cce(self, targets)
34 }
35
36 #[inline]
37 pub fn cce_loss(&self, targets: &Matrix<T, D, S>) -> T {
38 self.device().cce_loss(self, targets)
39 }
40
41 #[inline]
42 pub fn cce_grad(&self, targets: &Matrix<'a, T, D, S>) -> Matrix<'a, T, D, S> {
43 self.device().cce_grad(self, targets)
44 }
45}
46
47impl<T, D, IS: Shape> CCEOp<T, IS> for D
48where
49 T: Float,
50 D: FnsOps<T>
51 + ClipOp<T, IS>
52 + BaseOps<T, IS>
53 + SumOps<T>
54 + SumOverOps<T, IS>
55 + AdditionalOps<T, IS>
56 + FnsOps<T, IS>,
57{
58 fn cce_loss(&self, preds: &Matrix<T, D, IS>, targets: &Matrix<T, D, IS>) -> T {
59 let preds = preds.clip(T::as_generic(1E-7), T::as_generic(1. - 1E-7));
60 let confidences = (&preds * targets).sum_cols::<()>();
61 confidences.ln().neg().mean()
62 }
63
64 fn cce_grad<'a>(
65 &self,
66 preds: &Matrix<'a, T, D, IS>,
67 targets: &Matrix<'a, T, D, IS>,
68 ) -> Matrix<'a, T, Self, IS> {
69 let grad = (targets / preds).neg();
70 grad / T::from_usize(preds.rows())
71 }
72}
73
74