// meuron/cost/cross_entropy.rs

1use crate::backend::Backend;
2use crate::backend::unary_ops;
3use crate::cost::Cost;
4use ndarray::Dimension;
5use serde::{Deserialize, Serialize};
6
7#[derive(Clone, Copy, Serialize, Deserialize)]
8pub struct CrossEntropy;
9
10impl<B: Backend> Cost<B> for CrossEntropy {
11    fn loss<D: Dimension>(&self, predicted: &B::Tensor<D>, target: &B::Tensor<D>) -> f32 {
12        let eps = 1e-15_f32;
13        let clipped = B::clamp(predicted, eps, 1.0 - eps);
14        let ln_clip = B::unary(&clipped, unary_ops::LN);
15        -B::mean(&B::mul(target, &ln_clip)).unwrap_or(0.0)
16    }
17
18    fn gradient<D: Dimension>(
19        &self,
20        predicted: &B::Tensor<D>,
21        target: &B::Tensor<D>,
22    ) -> B::Tensor<D> {
23        let eps = 1e-15_f32;
24        let clipped = B::clamp(predicted, eps, 1.0 - eps);
25        B::div(&B::scale(target, -1.0), &clipped)
26    }
27}