ice_nine/
loss.rs

1use crate::Loss;
2use ndarray::Array1;
3
/// Sum-of-squared-errors loss comparing an output vector with a target vector.
///
/// Stateless: all behaviour lives in its [`Loss`] implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct LeastSquares;
5
6impl Loss<Array1<f64>> for LeastSquares {
7    fn l(&self, output: &Array1<f64>, target: &Array1<f64>) -> f64 {
8        output
9            .iter()
10            .zip(target)
11            .map(|(a, b)| (a - b).powi(2))
12            .sum()
13    }
14
15    fn d_l(&self, output: &Array1<f64>, target: &Array1<f64>) -> Array1<f64> {
16        2.0 * (output.clone() - target)
17    }
18}
19
/// Softmax cross-entropy loss computed from raw logits.
///
/// Takes labels (class indices) as the target: its [`Loss`] implementation
/// uses `usize` targets, the index of the correct class.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CrossEntropy {
    /// Softmax temperature: logits are divided by this value before the
    /// softmax is taken.
    pub temperature: f64,
}

impl CrossEntropy {
    /// Lower bound applied to the target-class probability before taking its
    /// logarithm, keeping the loss finite (at most `-ln(1e-7)` ≈ 16.1).
    const EPSILON: f64 = 1e-7;
}
28
29impl Loss<usize> for CrossEntropy {
30    /// CE Loss, [output] is logits, [target] is the label of the answer
31    fn l(&self, output: &Array1<f64>, target: &usize) -> f64 {
32        assert!(*target < output.len());
33        let output = output / self.temperature;
34        let max_logit = output.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
35        let exps = output.mapv(|x| (x - max_logit).exp());
36        let prob = (exps[*target] / exps.sum()).clamp(Self::EPSILON, 1.0);
37        -prob.ln()
38    }
39
40    fn d_l(&self, output: &Array1<f64>, target: &usize) -> Array1<f64> {
41        assert!(*target < output.len());
42        let output = output / self.temperature;
43        let max_logit = output.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
44        let exps = output.mapv(|x| (x - max_logit).exp());
45        let exp_sum = exps.sum();
46        let raw_vec = (0..exps.len())
47            .map(|i| {
48                if i == *target {
49                    (exps[*target] - exp_sum) / (self.temperature * exp_sum)
50                } else {
51                    exps[i] / (self.temperature * exp_sum)
52                }
53            })
54            .collect();
55        Array1::from_vec(raw_vec)
56    }
57}
58
59pub fn logits_to_probs(logits: &Array1<f64>) -> Array1<f64> {
60    let max_logit = logits.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
61    let exps = logits.mapv(|x| (x - max_logit).exp());
62    let exp_sum = exps.sum();
63    exps / exp_sum
64}