
rnn/losses/losses.rs

/// Supported regression losses. `Huber { delta }` behaves like MSE for
/// residuals within `delta` and like MAE outside it.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LossKind {
    Mse,
    Mae,
    Huber { delta: f32 },
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LossError {
    /// `prediction` was empty.
    Empty,
    /// `prediction` and `target` differ in length, or `grad_out` is too short.
    ShapeMismatch,
    /// The computed loss was NaN or infinite, or `delta` was not a positive
    /// finite number.
    NonFinite,
}

/// Computes the mean loss over `prediction` vs. `target` and writes the
/// gradient of that mean with respect to each prediction into `grad_out`.
/// `grad_out` may be longer than `prediction`; only the first
/// `prediction.len()` entries are written.
pub fn loss_and_gradient(
    kind: LossKind,
    prediction: &[f32],
    target: &[f32],
    grad_out: &mut [f32],
) -> Result<f32, LossError> {
    if prediction.is_empty() {
        return Err(LossError::Empty);
    }
    if prediction.len() != target.len() || grad_out.len() < prediction.len() {
        return Err(LossError::ShapeMismatch);
    }

    let n = prediction.len() as f32;
    let mut loss = 0.0f32;

    match kind {
        LossKind::Mse => {
            // L = (1/n) * sum(diff^2), so dL/dp_i = 2 * diff_i / n.
            let scale = 2.0 / n;
            for i in 0..prediction.len() {
                let diff = prediction[i] - target[i];
                loss += diff * diff;
                grad_out[i] = scale * diff;
            }
            loss /= n;
        }
        LossKind::Mae => {
            // L = (1/n) * sum(|diff|); the subgradient at diff == 0 is taken as 0.
            let inv_n = 1.0 / n;
            for i in 0..prediction.len() {
                let diff = prediction[i] - target[i];
                loss += diff.abs();
                grad_out[i] = if diff > 0.0 {
                    inv_n
                } else if diff < 0.0 {
                    -inv_n
                } else {
                    0.0
                };
            }
            loss *= inv_n;
        }
        LossKind::Huber { delta } => {
            // Quadratic for |diff| <= delta, linear beyond; both the loss and
            // its gradient agree at the |diff| == delta seam.
            if !delta.is_finite() || delta <= 0.0 {
                return Err(LossError::NonFinite);
            }
            let inv_n = 1.0 / n;
            for i in 0..prediction.len() {
                let diff = prediction[i] - target[i];
                let ad = diff.abs();
                if ad <= delta {
                    loss += 0.5 * diff * diff;
                    grad_out[i] = diff * inv_n;
                } else {
                    loss += delta * (ad - 0.5 * delta);
                    grad_out[i] = if diff > 0.0 {
                        delta * inv_n
                    } else {
                        -delta * inv_n
                    };
                }
            }
            loss *= inv_n;
        }
    }

    // Catch NaN/inf introduced by non-finite inputs.
    if !loss.is_finite() {
        return Err(LossError::NonFinite);
    }

    Ok(loss)
}
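
// A minimal sketch of unit tests for `loss_and_gradient`. The expected values
// are computed by hand from the formulas above, and the central-difference
// tolerance is an assumption about f32 noise, not part of the original code.
#[cfg(test)]
mod tests {
    use super::*;

    fn assert_close(a: f32, b: f32, tol: f32) {
        assert!((a - b).abs() <= tol, "{a} vs {b}");
    }

    #[test]
    fn mse_matches_hand_computed_values() {
        let pred = [1.0f32, 2.0];
        let target = [0.0f32, 0.0];
        let mut grad = [0.0f32; 2];
        let loss = loss_and_gradient(LossKind::Mse, &pred, &target, &mut grad).unwrap();
        assert_close(loss, 2.5, 1e-6); // (1 + 4) / 2
        assert_close(grad[0], 1.0, 1e-6); // 2 * 1 / 2
        assert_close(grad[1], 2.0, 1e-6); // 2 * 2 / 2
    }

    #[test]
    fn huber_gradient_matches_central_difference() {
        let kind = LossKind::Huber { delta: 1.0 };
        // Residuals chosen to hit both the quadratic and linear regions,
        // and to stay away from the |diff| == delta kink under perturbation.
        let pred = [0.5f32, -3.0, 1.0];
        let target = [0.0f32, 0.0, 0.25];
        let mut grad = [0.0f32; 3];
        loss_and_gradient(kind, &pred, &target, &mut grad).unwrap();

        let eps = 1e-3f32;
        let mut scratch = [0.0f32; 3];
        for i in 0..pred.len() {
            let mut plus = pred;
            let mut minus = pred;
            plus[i] += eps;
            minus[i] -= eps;
            let lp = loss_and_gradient(kind, &plus, &target, &mut scratch).unwrap();
            let lm = loss_and_gradient(kind, &minus, &target, &mut scratch).unwrap();
            let numeric = (lp - lm) / (2.0 * eps);
            assert_close(grad[i], numeric, 1e-2); // f32 central differences are noisy
        }
    }

    #[test]
    fn mismatched_shapes_are_rejected() {
        let mut grad = [0.0f32; 1];
        let err = loss_and_gradient(LossKind::Mae, &[1.0, 2.0], &[0.0, 0.0], &mut grad);
        assert_eq!(err, Err(LossError::ShapeMismatch));
    }
}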