/// Failure modes reported by the metric functions (`mse`, `mae`,
/// `accuracy_top1_from_one_hot`, `cross_entropy_from_probabilities`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetricError {
    /// A required input slice contained no elements.
    Empty,
    /// Two input slices that must be the same length were not.
    ShapeMismatch,
    /// A probability (after clamping) or the clamping epsilon was not a
    /// finite, strictly positive number.
    InvalidProbabilities,
}
/// Mean squared error: `(1/n) * Σ (prediction[i] - target[i])²`.
///
/// Each per-element difference is computed in `f32`, but the squares are
/// summed in an `f64` accumulator. In `f32`, adding a small term to a large
/// running sum is silently rounded away; for example, `1e8 + 1.0 == 1e8`.
/// The mean is rounded back to `f32` once at the end.
///
/// # Errors
///
/// * [`MetricError::Empty`] if `prediction` is empty. This is checked first,
///   so it takes precedence over a length mismatch.
/// * [`MetricError::ShapeMismatch`] if `prediction` and `target` differ in length.
pub fn mse(prediction: &[f32], target: &[f32]) -> Result<f32, MetricError> {
    if prediction.is_empty() {
        return Err(MetricError::Empty);
    }
    if prediction.len() != target.len() {
        return Err(MetricError::ShapeMismatch);
    }
    let sum: f64 = prediction
        .iter()
        .zip(target)
        .map(|(&p, &t)| {
            let d = f64::from(p - t);
            d * d
        })
        .sum();
    Ok((sum / prediction.len() as f64) as f32)
}
/// Mean absolute error: `(1/n) * Σ |prediction[i] - target[i]|`.
///
/// Each per-element absolute difference is computed in `f32`, then summed in
/// an `f64` accumulator. This prevents small errors from being rounded away
/// when they follow large ones. The mean is rounded back to `f32` once at
/// the end.
///
/// # Errors
///
/// * [`MetricError::Empty`] if `prediction` is empty. This is checked first,
///   so it takes precedence over a length mismatch.
/// * [`MetricError::ShapeMismatch`] if `prediction` and `target` differ in length.
pub fn mae(prediction: &[f32], target: &[f32]) -> Result<f32, MetricError> {
    if prediction.is_empty() {
        return Err(MetricError::Empty);
    }
    if prediction.len() != target.len() {
        return Err(MetricError::ShapeMismatch);
    }
    let sum: f64 = prediction
        .iter()
        .zip(target)
        .map(|(&p, &t)| f64::from((p - t).abs()))
        .sum();
    Ok((sum / prediction.len() as f64) as f32)
}
/// Index of the largest element of `values`, or `None` if `values` is empty.
///
/// Ties resolve to the earliest index. NaN entries are skipped no matter where
/// they appear, so a leading NaN can no longer mask real maxima that follow it.
/// If every entry is NaN, the result is `Some(0)`.
pub fn argmax(values: &[f32]) -> Option<usize> {
    if values.is_empty() {
        return None;
    }
    let mut best: Option<(usize, f32)> = None;
    for (idx, &value) in values.iter().enumerate() {
        if value.is_nan() {
            continue;
        }
        match best {
            // Both sides are non-NaN here. Keeping the incumbent when
            // `value <= best_value` means the first index wins a tie.
            Some((_, best_value)) if value <= best_value => {}
            _ => best = Some((idx, value)),
        }
    }
    Some(best.map_or(0, |(idx, _)| idx))
}
/// Top-1 accuracy for a single sample against a one-hot target.
///
/// Yields `1.0` when [`argmax`] of `prediction` equals [`argmax`] of
/// `one_hot_target`, otherwise `0.0`.
///
/// # Errors
///
/// * [`MetricError::Empty`] if `prediction` is empty.
/// * [`MetricError::ShapeMismatch`] if the two slices differ in length.
pub fn accuracy_top1_from_one_hot(
    prediction: &[f32],
    one_hot_target: &[f32],
) -> Result<f32, MetricError> {
    match (prediction.len(), one_hot_target.len()) {
        (0, _) => Err(MetricError::Empty),
        (n_pred, n_target) if n_pred != n_target => Err(MetricError::ShapeMismatch),
        _ => match (argmax(prediction), argmax(one_hot_target)) {
            (Some(predicted), Some(expected)) => {
                Ok(if predicted == expected { 1.0 } else { 0.0 })
            }
            // `argmax` returns `None` only for an empty slice. The checks
            // above exclude that, so this arm is defensive.
            _ => Err(MetricError::Empty),
        },
    }
}
/// Cross-entropy `-Σ one_hot_target[i] * lnf(p[i])` of a probability vector
/// against a target vector.
///
/// Any probability below `eps` is raised to `eps` before the logarithm,
/// which avoids taking the log of zero. The logarithm is
/// `crate::math::lnf` (presumably the natural log; confirm in `crate::math`).
/// The probabilities are not checked to sum to 1, and target values are used
/// as given.
///
/// # Errors
///
/// * [`MetricError::Empty`] if `probabilities` is empty.
/// * [`MetricError::ShapeMismatch`] if the two slices differ in length.
/// * [`MetricError::InvalidProbabilities`] if `eps` is not finite and strictly
///   positive, or if any probability is NaN or infinite.
pub fn cross_entropy_from_probabilities(
    probabilities: &[f32],
    one_hot_target: &[f32],
    eps: f32,
) -> Result<f32, MetricError> {
    if probabilities.is_empty() {
        return Err(MetricError::Empty);
    }
    if probabilities.len() != one_hot_target.len() {
        return Err(MetricError::ShapeMismatch);
    }
    let eps_is_valid = eps.is_finite() && eps > 0.0;
    if !eps_is_valid {
        return Err(MetricError::InvalidProbabilities);
    }
    let mut total = 0.0f32;
    for (&raw, &weight) in probabilities.iter().zip(one_hot_target) {
        // Use an explicit comparison, not `f32::max`: a NaN must survive
        // clamping so the check below rejects it (`max` would return `eps`).
        let clamped = if raw < eps { eps } else { raw };
        let clamped_is_valid = clamped.is_finite() && clamped > 0.0;
        if !clamped_is_valid {
            return Err(MetricError::InvalidProbabilities);
        }
        total -= weight * crate::math::lnf(clamped);
    }
    Ok(total)
}