// rnn/metrics/metrics.rs

/// Errors returned by the metric functions in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MetricError {
    /// The input slice was empty, so the metric is undefined.
    Empty,
    /// The two input slices have different lengths.
    ShapeMismatch,
    /// `eps` or a probability value was unusable (non-finite or not strictly positive).
    InvalidProbabilities,
}

// NOTE(review): `std::error::Error` is not implemented because this crate may be
// `no_std` (the module calls `crate::math::lnf` rather than `f32::ln`) — confirm,
// and add the impl if `std` is available.
impl core::fmt::Display for MetricError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match *self {
            MetricError::Empty => "input slice is empty",
            MetricError::ShapeMismatch => "input slices have different lengths",
            MetricError::InvalidProbabilities => "invalid probability or epsilon value",
        };
        f.write_str(msg)
    }
}

8pub fn mse(prediction: &[f32], target: &[f32]) -> Result<f32, MetricError> {
9    if prediction.is_empty() {
10        return Err(MetricError::Empty);
11    }
12    if prediction.len() != target.len() {
13        return Err(MetricError::ShapeMismatch);
14    }
15
16    let mut acc = 0.0f32;
17    for i in 0..prediction.len() {
18        let d = prediction[i] - target[i];
19        acc += d * d;
20    }
21    Ok(acc / prediction.len() as f32)
22}
23
24pub fn mae(prediction: &[f32], target: &[f32]) -> Result<f32, MetricError> {
25    if prediction.is_empty() {
26        return Err(MetricError::Empty);
27    }
28    if prediction.len() != target.len() {
29        return Err(MetricError::ShapeMismatch);
30    }
31
32    let mut acc = 0.0f32;
33    for i in 0..prediction.len() {
34        acc += (prediction[i] - target[i]).abs();
35    }
36    Ok(acc / prediction.len() as f32)
37}
38
/// Index of the largest value in `values`, or `None` if `values` is empty.
///
/// Ties resolve to the first (lowest) index. NaN entries are ignored, so a NaN
/// can neither win nor block the comparison; if every entry is NaN, index `0`
/// is returned.
pub fn argmax(values: &[f32]) -> Option<usize> {
    if values.is_empty() {
        return None;
    }

    // NaN compares false against everything, so it must be skipped explicitly.
    // Seeding the search with `values[0]` would let a leading NaN "win" every
    // comparison by default while a later NaN was silently ignored.
    let mut best: Option<(usize, f32)> = None;
    for (idx, &value) in values.iter().enumerate() {
        if value.is_nan() {
            continue;
        }
        // Strict `>` keeps the earliest index when values tie.
        let better = match best {
            Some((_, best_value)) => value > best_value,
            None => true,
        };
        if better {
            best = Some((idx, value));
        }
    }

    // All-NaN input: fall back to index 0 so non-empty input still yields `Some`.
    Some(best.map_or(0, |(idx, _)| idx))
}

54pub fn accuracy_top1_from_one_hot(prediction: &[f32], one_hot_target: &[f32]) -> Result<f32, MetricError> {
55    if prediction.is_empty() {
56        return Err(MetricError::Empty);
57    }
58    if prediction.len() != one_hot_target.len() {
59        return Err(MetricError::ShapeMismatch);
60    }
61
62    let pred_idx = argmax(prediction).ok_or(MetricError::Empty)?;
63    let target_idx = argmax(one_hot_target).ok_or(MetricError::Empty)?;
64    Ok(if pred_idx == target_idx { 1.0 } else { 0.0 })
65}
66
67pub fn cross_entropy_from_probabilities(probabilities: &[f32], one_hot_target: &[f32], eps: f32) -> Result<f32, MetricError> {
68    if probabilities.is_empty() {
69        return Err(MetricError::Empty);
70    }
71    if probabilities.len() != one_hot_target.len() {
72        return Err(MetricError::ShapeMismatch);
73    }
74    if !eps.is_finite() || eps <= 0.0 {
75        return Err(MetricError::InvalidProbabilities);
76    }
77
78    let mut loss = 0.0f32;
79    for i in 0..probabilities.len() {
80        let p = if probabilities[i] < eps { eps } else { probabilities[i] };
81        if !p.is_finite() || p <= 0.0 {
82            return Err(MetricError::InvalidProbabilities);
83        }
84        loss -= one_hot_target[i] * crate::math::lnf(p);
85    }
86    Ok(loss)
87}