1use crate::Loss;
2use ndarray::Array1;
3
/// Sum-of-squared-errors loss between a predicted vector and a target vector.
///
/// Stateless, so it derives the usual cheap value-type traits.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LeastSquares;
5
6impl Loss<Array1<f64>> for LeastSquares {
7 fn l(&self, output: &Array1<f64>, target: &Array1<f64>) -> f64 {
8 output
9 .iter()
10 .zip(target)
11 .map(|(a, b)| (a - b).powi(2))
12 .sum()
13 }
14
15 fn d_l(&self, output: &Array1<f64>, target: &Array1<f64>) -> Array1<f64> {
16 2.0 * (output.clone() - target)
17 }
18}
19
/// Softmax cross-entropy loss over class logits, with temperature scaling.
///
/// The logits are divided by `temperature` before the softmax is applied.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CrossEntropy {
    /// Divisor applied to the logits before the softmax.
    pub temperature: f64,
}
23
24impl CrossEntropy {
26 const EPSILON: f64 = 1e-7;
27}
28
29impl Loss<usize> for CrossEntropy {
30 fn l(&self, output: &Array1<f64>, target: &usize) -> f64 {
32 assert!(*target < output.len());
33 let output = output / self.temperature;
34 let max_logit = output.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
35 let exps = output.mapv(|x| (x - max_logit).exp());
36 let prob = (exps[*target] / exps.sum()).clamp(Self::EPSILON, 1.0);
37 -prob.ln()
38 }
39
40 fn d_l(&self, output: &Array1<f64>, target: &usize) -> Array1<f64> {
41 assert!(*target < output.len());
42 let output = output / self.temperature;
43 let max_logit = output.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
44 let exps = output.mapv(|x| (x - max_logit).exp());
45 let exp_sum = exps.sum();
46 let raw_vec = (0..exps.len())
47 .map(|i| {
48 if i == *target {
49 (exps[*target] - exp_sum) / (self.temperature * exp_sum)
50 } else {
51 exps[i] / (self.temperature * exp_sum)
52 }
53 })
54 .collect();
55 Array1::from_vec(raw_vec)
56 }
57}
58
59pub fn logits_to_probs(logits: &Array1<f64>) -> Array1<f64> {
60 let max_logit = logits.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
61 let exps = logits.mapv(|x| (x - max_logit).exp());
62 let exp_sum = exps.sum();
63 exps / exp_sum
64}