/// Error conditions reported by the metric functions in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MetricError {
/// The primary input slice contains no elements.
Empty,
/// The two input slices have different lengths.
ShapeMismatch,
/// A probability or the clamping epsilon failed validation
/// (for example a non-finite value, or a non-positive epsilon).
InvalidProbabilities,
}
/// Mean squared error between `prediction` and `target`, computed in `f32`.
///
/// The squared element-wise differences are summed left to right and the
/// total is divided by the number of elements.
///
/// # Errors
///
/// * [`MetricError::Empty`] when `prediction` has no elements.
/// * [`MetricError::ShapeMismatch`] when the two slices differ in length.
pub fn mse_f32(prediction: &[f32], target: &[f32]) -> Result<f32, MetricError> {
    let count = prediction.len();
    if count == 0 {
        return Err(MetricError::Empty);
    }
    if target.len() != count {
        return Err(MetricError::ShapeMismatch);
    }
    // Explicit fold from 0.0 keeps the accumulation order and start value fixed.
    let squared_sum = prediction
        .iter()
        .zip(target)
        .fold(0.0f32, |running, (&predicted, &expected)| {
            let residual = predicted - expected;
            running + residual * residual
        });
    Ok(squared_sum / count as f32)
}
/// Mean squared error between `prediction` and `target`, computed in `f64`.
///
/// The squared element-wise differences are summed left to right and the
/// total is divided by the number of elements.
///
/// # Errors
///
/// * [`MetricError::Empty`] when `prediction` has no elements.
/// * [`MetricError::ShapeMismatch`] when the two slices differ in length.
pub fn mse_f64(prediction: &[f64], target: &[f64]) -> Result<f64, MetricError> {
    let count = prediction.len();
    if count == 0 {
        return Err(MetricError::Empty);
    }
    if target.len() != count {
        return Err(MetricError::ShapeMismatch);
    }
    // Explicit fold from 0.0 keeps the accumulation order and start value fixed.
    let squared_sum = prediction
        .iter()
        .zip(target)
        .fold(0.0f64, |running, (&predicted, &expected)| {
            let residual = predicted - expected;
            running + residual * residual
        });
    Ok(squared_sum / count as f64)
}
/// Mean absolute error between `prediction` and `target`, computed in `f32`.
///
/// The absolute element-wise differences are summed left to right and the
/// total is divided by the number of elements.
///
/// # Errors
///
/// * [`MetricError::Empty`] when `prediction` has no elements.
/// * [`MetricError::ShapeMismatch`] when the two slices differ in length.
pub fn mae_f32(prediction: &[f32], target: &[f32]) -> Result<f32, MetricError> {
    let count = prediction.len();
    if count == 0 {
        return Err(MetricError::Empty);
    }
    if target.len() != count {
        return Err(MetricError::ShapeMismatch);
    }
    let abs_sum = prediction
        .iter()
        .zip(target)
        .fold(0.0f32, |running, (&predicted, &expected)| {
            running + (predicted - expected).abs()
        });
    Ok(abs_sum / count as f32)
}
/// Mean absolute error between `prediction` and `target`, computed in `f64`.
///
/// The absolute element-wise differences are summed left to right and the
/// total is divided by the number of elements.
///
/// # Errors
///
/// * [`MetricError::Empty`] when `prediction` has no elements.
/// * [`MetricError::ShapeMismatch`] when the two slices differ in length.
pub fn mae_f64(prediction: &[f64], target: &[f64]) -> Result<f64, MetricError> {
    let count = prediction.len();
    if count == 0 {
        return Err(MetricError::Empty);
    }
    if target.len() != count {
        return Err(MetricError::ShapeMismatch);
    }
    let abs_sum = prediction
        .iter()
        .zip(target)
        .fold(0.0f64, |running, (&predicted, &expected)| {
            running + (predicted - expected).abs()
        });
    Ok(abs_sum / count as f64)
}
/// Index of the largest element of `values`, or `None` for an empty slice.
///
/// Ties resolve to the earliest index because a later element must be
/// strictly greater to replace the current best.
///
/// NaN never compares greater than anything, so a NaN after position 0 is
/// skipped, while a NaN at position 0 is never displaced and index 0 is
/// returned.
pub fn argmax_f32(values: &[f32]) -> Option<usize> {
    let (&head, tail) = values.split_first()?;
    let (winner, _) = tail
        .iter()
        .enumerate()
        .fold((0usize, head), |(top_idx, top_val), (offset, &candidate)| {
            if candidate > top_val {
                // `tail` starts at original index 1.
                (offset + 1, candidate)
            } else {
                (top_idx, top_val)
            }
        });
    Some(winner)
}
/// Index of the largest element of `values`, or `None` for an empty slice.
///
/// Ties resolve to the earliest index because a later element must be
/// strictly greater to replace the current best.
///
/// NaN never compares greater than anything, so a NaN after position 0 is
/// skipped, while a NaN at position 0 is never displaced and index 0 is
/// returned.
pub fn argmax_f64(values: &[f64]) -> Option<usize> {
    let (&head, tail) = values.split_first()?;
    let (winner, _) = tail
        .iter()
        .enumerate()
        .fold((0usize, head), |(top_idx, top_val), (offset, &candidate)| {
            if candidate > top_val {
                // `tail` starts at original index 1.
                (offset + 1, candidate)
            } else {
                (top_idx, top_val)
            }
        });
    Some(winner)
}
/// Top-1 accuracy for a single sample: `1.0` when the arg-max of `prediction`
/// equals the arg-max of `one_hot_target`, otherwise `0.0`.
///
/// Both indices are obtained with `argmax_f32`; the target is not checked
/// for actually being one-hot.
///
/// # Errors
///
/// * [`MetricError::Empty`] when `prediction` has no elements.
/// * [`MetricError::ShapeMismatch`] when the two slices differ in length.
pub fn accuracy_top1_from_one_hot_f32(
    prediction: &[f32],
    one_hot_target: &[f32],
) -> Result<f32, MetricError> {
    let classes = prediction.len();
    if classes == 0 {
        return Err(MetricError::Empty);
    }
    if one_hot_target.len() != classes {
        return Err(MetricError::ShapeMismatch);
    }
    match (argmax_f32(prediction), argmax_f32(one_hot_target)) {
        (Some(predicted), Some(expected)) if predicted == expected => Ok(1.0),
        (Some(_), Some(_)) => Ok(0.0),
        // Both slices are non-empty here, so this arm only preserves the
        // `None` -> `Empty` mapping.
        _ => Err(MetricError::Empty),
    }
}
/// Top-1 accuracy for a single sample: `1.0` when the arg-max of `prediction`
/// equals the arg-max of `one_hot_target`, otherwise `0.0`.
///
/// Both indices are obtained with `argmax_f64`; the target is not checked
/// for actually being one-hot.
///
/// # Errors
///
/// * [`MetricError::Empty`] when `prediction` has no elements.
/// * [`MetricError::ShapeMismatch`] when the two slices differ in length.
pub fn accuracy_top1_from_one_hot_f64(
    prediction: &[f64],
    one_hot_target: &[f64],
) -> Result<f64, MetricError> {
    let classes = prediction.len();
    if classes == 0 {
        return Err(MetricError::Empty);
    }
    if one_hot_target.len() != classes {
        return Err(MetricError::ShapeMismatch);
    }
    match (argmax_f64(prediction), argmax_f64(one_hot_target)) {
        (Some(predicted), Some(expected)) if predicted == expected => Ok(1.0),
        (Some(_), Some(_)) => Ok(0.0),
        // Both slices are non-empty here, so this arm only preserves the
        // `None` -> `Empty` mapping.
        _ => Err(MetricError::Empty),
    }
}
/// Cross-entropy of `probabilities` against `one_hot_target`, in `f32`.
///
/// Returns the negated sum of `one_hot_target[i] * ln(p_i)`, where `p_i` is
/// `probabilities[i]` raised to at least `eps` so the logarithm never sees
/// zero. The logarithm is `crate::math::lnf` (assumed here to be the natural
/// logarithm). The target is used as a weight vector and is not checked for
/// being one-hot.
///
/// Each probability is validated *before* clamping. Validating after the
/// clamp would let negative values and `-inf` be silently replaced by `eps`
/// while `+inf` and NaN are rejected.
///
/// # Errors
///
/// * [`MetricError::Empty`] when `probabilities` has no elements.
/// * [`MetricError::ShapeMismatch`] when the two slices differ in length.
/// * [`MetricError::InvalidProbabilities`] when `eps` is not a finite
///   positive number, or any probability is NaN, infinite, or negative
///   (regardless of the matching target weight).
pub fn cross_entropy_from_probabilities_f32(
    probabilities: &[f32],
    one_hot_target: &[f32],
    eps: f32,
) -> Result<f32, MetricError> {
    if probabilities.is_empty() {
        return Err(MetricError::Empty);
    }
    if probabilities.len() != one_hot_target.len() {
        return Err(MetricError::ShapeMismatch);
    }
    if !eps.is_finite() || eps <= 0.0 {
        return Err(MetricError::InvalidProbabilities);
    }
    let mut loss = 0.0f32;
    for (&raw, &weight) in probabilities.iter().zip(one_hot_target) {
        // Check the raw input: once clamped, a negative or `-inf` value would
        // be indistinguishable from a legitimately tiny probability.
        if !raw.is_finite() || raw < 0.0 {
            return Err(MetricError::InvalidProbabilities);
        }
        // Values in `[0, eps)` are lifted to `eps` to keep the log finite.
        let clamped = if raw < eps { eps } else { raw };
        loss -= weight * crate::math::lnf(clamped);
    }
    Ok(loss)
}
/// Cross-entropy of `probabilities` against `one_hot_target`, in `f64`.
///
/// Returns the negated sum of `one_hot_target[i] * ln(p_i)`, where `p_i` is
/// `probabilities[i]` raised to at least `eps` so the logarithm never sees
/// zero. The logarithm is `crate::math::lnd` (assumed here to be the natural
/// logarithm). The target is used as a weight vector and is not checked for
/// being one-hot.
///
/// Each probability is validated *before* clamping. Validating after the
/// clamp would let negative values and `-inf` be silently replaced by `eps`
/// while `+inf` and NaN are rejected.
///
/// # Errors
///
/// * [`MetricError::Empty`] when `probabilities` has no elements.
/// * [`MetricError::ShapeMismatch`] when the two slices differ in length.
/// * [`MetricError::InvalidProbabilities`] when `eps` is not a finite
///   positive number, or any probability is NaN, infinite, or negative
///   (regardless of the matching target weight).
pub fn cross_entropy_from_probabilities_f64(
    probabilities: &[f64],
    one_hot_target: &[f64],
    eps: f64,
) -> Result<f64, MetricError> {
    if probabilities.is_empty() {
        return Err(MetricError::Empty);
    }
    if probabilities.len() != one_hot_target.len() {
        return Err(MetricError::ShapeMismatch);
    }
    if !eps.is_finite() || eps <= 0.0 {
        return Err(MetricError::InvalidProbabilities);
    }
    let mut loss = 0.0f64;
    for (&raw, &weight) in probabilities.iter().zip(one_hot_target) {
        // Check the raw input: once clamped, a negative or `-inf` value would
        // be indistinguishable from a legitimately tiny probability.
        if !raw.is_finite() || raw < 0.0 {
            return Err(MetricError::InvalidProbabilities);
        }
        // Values in `[0, eps)` are lifted to `eps` to keep the log finite.
        let clamped = if raw < eps { eps } else { raw };
        loss -= weight * crate::math::lnd(clamped);
    }
    Ok(loss)
}