/// Failure modes reported by the metric functions in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetricError {
    /// A required input slice contained no elements.
    Empty,
    /// Two input slices that must have equal length did not.
    ShapeMismatch,
    /// A probability input, or the `eps` clamp used with probabilities, was not acceptable.
    InvalidProbabilities,
}
7
8pub fn mse(prediction: &[f32], target: &[f32]) -> Result<f32, MetricError> {
9 if prediction.is_empty() {
10 return Err(MetricError::Empty);
11 }
12 if prediction.len() != target.len() {
13 return Err(MetricError::ShapeMismatch);
14 }
15
16 let mut acc = 0.0f32;
17 for i in 0..prediction.len() {
18 let d = prediction[i] - target[i];
19 acc += d * d;
20 }
21 Ok(acc / prediction.len() as f32)
22}
23
24pub fn mae(prediction: &[f32], target: &[f32]) -> Result<f32, MetricError> {
25 if prediction.is_empty() {
26 return Err(MetricError::Empty);
27 }
28 if prediction.len() != target.len() {
29 return Err(MetricError::ShapeMismatch);
30 }
31
32 let mut acc = 0.0f32;
33 for i in 0..prediction.len() {
34 acc += (prediction[i] - target[i]).abs();
35 }
36 Ok(acc / prediction.len() as f32)
37}
38
/// Returns the index of the largest element of `values`, or `None` if the slice is empty.
///
/// Elements are compared with `>`, which has these consequences:
/// * on ties the earliest index wins;
/// * a `NaN` element never replaces the current best;
/// * if `values[0]` is `NaN`, every later comparison is false, so `Some(0)` is returned.
pub fn argmax(values: &[f32]) -> Option<usize> {
    // `split_first` doubles as the emptiness check: `?` yields `None` for `[]`.
    let (&first, rest) = values.split_first()?;
    let (winner, _) = rest.iter().enumerate().fold(
        (0usize, first),
        |(best_index, best_value), (offset, &candidate)| {
            // `rest` begins at index 1 of `values`, hence `offset + 1`.
            if candidate > best_value {
                (offset + 1, candidate)
            } else {
                (best_index, best_value)
            }
        },
    );
    Some(winner)
}
53
54pub fn accuracy_top1_from_one_hot(prediction: &[f32], one_hot_target: &[f32]) -> Result<f32, MetricError> {
55 if prediction.is_empty() {
56 return Err(MetricError::Empty);
57 }
58 if prediction.len() != one_hot_target.len() {
59 return Err(MetricError::ShapeMismatch);
60 }
61
62 let pred_idx = argmax(prediction).ok_or(MetricError::Empty)?;
63 let target_idx = argmax(one_hot_target).ok_or(MetricError::Empty)?;
64 Ok(if pred_idx == target_idx { 1.0 } else { 0.0 })
65}
66
67pub fn cross_entropy_from_probabilities(probabilities: &[f32], one_hot_target: &[f32], eps: f32) -> Result<f32, MetricError> {
68 if probabilities.is_empty() {
69 return Err(MetricError::Empty);
70 }
71 if probabilities.len() != one_hot_target.len() {
72 return Err(MetricError::ShapeMismatch);
73 }
74 if !eps.is_finite() || eps <= 0.0 {
75 return Err(MetricError::InvalidProbabilities);
76 }
77
78 let mut loss = 0.0f32;
79 for i in 0..probabilities.len() {
80 let p = if probabilities[i] < eps { eps } else { probabilities[i] };
81 if !p.is_finite() || p <= 0.0 {
82 return Err(MetricError::InvalidProbabilities);
83 }
84 loss -= one_hot_target[i] * crate::math::lnf(p);
85 }
86 Ok(loss)
87}