// custos_math/ops/nn/loss/cce.rs

1use crate::{AdditionalOps, BaseOps, ClipOp, FnsOps, Matrix, SumOps, SumOverOps};
2use custos::{number::Float, Device, Shape};
3
4pub trait CCE<T> {
5    fn cce(&self, targets: &Matrix<T>) -> (T, Matrix<T>);
6}
7
8impl<'a, T, D> Matrix<'a, T, D> where D: Device {}
9
10pub trait CCEOp<T, S: Shape = (), D = Self>: Device
11where
12    D: Device,
13{
14    #[inline]
15    fn cce<'a>(
16        &self,
17        preds: &Matrix<'a, T, D, S>,
18        targets: &Matrix<'a, T, D, S>,
19    ) -> (T, Matrix<'a, T, Self, S>) {
20        (self.cce_loss(preds, targets), self.cce_grad(preds, targets))
21    }
22    fn cce_loss(&self, preds: &Matrix<T, D, S>, targets: &Matrix<T, D, S>) -> T;
23    fn cce_grad<'a>(
24        &self,
25        preds: &Matrix<'a, T, D, S>,
26        targets: &Matrix<'a, T, D, S>,
27    ) -> Matrix<'a, T, Self, S>;
28}
29
30impl<'a, T, S: Shape, D: CCEOp<T, S>> Matrix<'a, T, D, S> {
31    #[inline]
32    pub fn cce(&self, targets: &Matrix<'a, T, D, S>) -> (T, Matrix<'a, T, D, S>) {
33        self.device().cce(self, targets)
34    }
35
36    #[inline]
37    pub fn cce_loss(&self, targets: &Matrix<T, D, S>) -> T {
38        self.device().cce_loss(self, targets)
39    }
40
41    #[inline]
42    pub fn cce_grad(&self, targets: &Matrix<'a, T, D, S>) -> Matrix<'a, T, D, S> {
43        self.device().cce_grad(self, targets)
44    }
45}
46
47impl<T, D, IS: Shape> CCEOp<T, IS> for D
48where
49    T: Float,
50    D: FnsOps<T>
51        + ClipOp<T, IS>
52        + BaseOps<T, IS>
53        + SumOps<T>
54        + SumOverOps<T, IS>
55        + AdditionalOps<T, IS>
56        + FnsOps<T, IS>,
57{
58    fn cce_loss(&self, preds: &Matrix<T, D, IS>, targets: &Matrix<T, D, IS>) -> T {
59        let preds = preds.clip(T::as_generic(1E-7), T::as_generic(1. - 1E-7));
60        let confidences = (&preds * targets).sum_cols::<()>();
61        confidences.ln().neg().mean()
62    }
63
64    fn cce_grad<'a>(
65        &self,
66        preds: &Matrix<'a, T, D, IS>,
67        targets: &Matrix<'a, T, D, IS>,
68    ) -> Matrix<'a, T, Self, IS> {
69        let grad = (targets / preds).neg();
70        grad / T::from_usize(preds.rows())
71    }
72}
73
74/*
75
76impl<T: Float + CDatatype> CCE<T> for Matrix<'_, T>
77where
78    Box<dyn CCEOp<T>>: CCEOp<T>,
79{
80    fn cce(&self, targets: &Matrix<T>) -> (T, Matrix<T>) {
81        let device = get_device!(self.device(), CCEOp<T>);
82        let loss = cce(device, self, targets);
83        let grad = cce_grad(device, self, targets);
84        (loss, grad)
85    }
86}
87
88pub trait CCEOp<T>: FnsOps<T> + ClipOp<T> + BaseOps<T> + SumOps<T> + AdditionalOps<T> {}
89impl<T: Float + CDatatype> CCEOp<T> for CPU {}
90#[cfg(feature = "opencl")]
91impl<T: Float + CDatatype> CCEOp<T> for OpenCL {}
92#[cfg(feature = "cuda")]
93impl<T: Float + CDatatype> CCEOp<T> for custos::CUDA {}
94
95pub fn cce<T: Float>(device: &dyn CCEOp<T>, preds: &Matrix<T>, targets: &Matrix<T>) -> T {
96    let preds = device.clip(preds, T::as_generic(1E-7), T::as_generic(1. - 1E-7));
97    let confidences = device.sum_cols(&device.mul(&preds, targets));
98    device.mean(&device.neg(&device.ln(&confidences)))
99}
100
101pub fn cce_grad<'a, T: Float>(
102    device: &'a dyn CCEOp<T>,
103    preds: &Matrix<T>,
104    targets: &Matrix<T>,
105) -> Matrix<'a, T> {
106    let grad = device.neg(&device.div(targets, preds));
107    device.divs(&grad, T::from_usize(preds.rows()))
108}
109*/