/// Selects which regression loss `loss_and_gradient` computes.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LossKind {
    /// Mean squared error: mean of `(prediction - target)^2`.
    Mse,
    /// Mean absolute error: mean of `|prediction - target|`.
    Mae,
    /// Huber loss: quadratic within `delta` of the target, linear beyond it.
    /// `delta` must be finite and positive.
    Huber { delta: f32 },
}
7
/// Error cases reported by `loss_and_gradient`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LossError {
    /// `prediction` was empty.
    Empty,
    /// `prediction` and `target` lengths differ, or `grad_out` is shorter
    /// than `prediction`.
    ShapeMismatch,
    /// The accumulated loss came out NaN/infinite, or a Huber `delta`
    /// was non-finite or non-positive.
    NonFinite,
}
14
15pub fn loss_and_gradient(
16 kind: LossKind,
17 prediction: &[f32],
18 target: &[f32],
19 grad_out: &mut [f32],
20) -> Result<f32, LossError> {
21 if prediction.is_empty() {
22 return Err(LossError::Empty);
23 }
24 if prediction.len() != target.len() || grad_out.len() < prediction.len() {
25 return Err(LossError::ShapeMismatch);
26 }
27
28 let n = prediction.len() as f32;
29 let mut loss = 0.0f32;
30
31 match kind {
32 LossKind::Mse => {
33 let scale = 2.0 / n;
34 for i in 0..prediction.len() {
35 let diff = prediction[i] - target[i];
36 loss += diff * diff;
37 grad_out[i] = scale * diff;
38 }
39 loss /= n;
40 }
41 LossKind::Mae => {
42 let inv_n = 1.0 / n;
43 for i in 0..prediction.len() {
44 let diff = prediction[i] - target[i];
45 loss += diff.abs();
46 grad_out[i] = if diff > 0.0 {
47 inv_n
48 } else if diff < 0.0 {
49 -inv_n
50 } else {
51 0.0
52 };
53 }
54 loss *= inv_n;
55 }
56 LossKind::Huber { delta } => {
57 if !delta.is_finite() || delta <= 0.0 {
58 return Err(LossError::NonFinite);
59 }
60 let inv_n = 1.0 / n;
61 for i in 0..prediction.len() {
62 let diff = prediction[i] - target[i];
63 let ad = diff.abs();
64 if ad <= delta {
65 loss += 0.5 * diff * diff;
66 grad_out[i] = diff * inv_n;
67 } else {
68 loss += delta * (ad - 0.5 * delta);
69 grad_out[i] = if diff > 0.0 {
70 delta * inv_n
71 } else {
72 -delta * inv_n
73 };
74 }
75 }
76 loss *= inv_n;
77 }
78 }
79
80 if !loss.is_finite() {
81 return Err(LossError::NonFinite);
82 }
83
84 Ok(loss)
85}