use crate::primitives::Matrix;
/// Strategy for aggregating per-class metric values into a single score.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Average {
/// Unweighted mean of the per-class values.
Macro,
/// Pool TP/FP/FN counts across all classes before dividing.
Micro,
/// Mean of per-class values weighted by each class's support
/// (its number of occurrences in `y_true`).
Weighted,
}
/// Fraction of positions where `y_pred` matches `y_true`, in `[0, 1]`.
///
/// # Panics
/// Panics if the slices differ in length or are empty.
#[must_use]
#[provable_contracts_macros::contract("metrics-classification-v1", equation = "accuracy")]
pub fn accuracy(y_pred: &[usize], y_true: &[usize]) -> f32 {
contract_pre_accuracy!();
assert_eq!(y_pred.len(), y_true.len(), "Vectors must have same length");
assert!(!y_true.is_empty(), "Vectors cannot be empty");
// Count positions where prediction and label agree.
let mut hits: usize = 0;
for (pred, truth) in y_pred.iter().zip(y_true) {
    if pred == truth {
        hits += 1;
    }
}
hits as f32 / y_true.len() as f32
}
/// Precision of integer class predictions against ground truth,
/// aggregated according to `average`.
///
/// Per-class precision is `TP / (TP + FP)`, defined as `0.0` for classes
/// that were never predicted.
///
/// # Panics
/// Panics if the slices differ in length or are empty.
#[must_use]
#[provable_contracts_macros::contract("metrics-classification-v1", equation = "precision")]
pub fn precision(y_pred: &[usize], y_true: &[usize], average: Average) -> f32 {
contract_pre_precision!();
assert_eq!(y_pred.len(), y_true.len(), "Vectors must have same length");
assert!(!y_true.is_empty(), "Vectors cannot be empty");
// Number of classes = highest label seen in either vector, plus one.
let n_classes = y_true
    .iter()
    .chain(y_pred.iter())
    .max()
    .map_or(0, |&m| m + 1);
if n_classes == 0 {
    // Defensive: unreachable after the non-empty assert above; kept so the
    // divisions below can never see a zero class count.
    return 0.0;
}
let (tp, fp, _, support) = compute_tp_fp_fn(y_pred, y_true, n_classes);
// Single definition of per-class precision, shared by the Macro and
// Weighted arms (previously duplicated inline in both).
let per_class = |i: usize| -> f32 {
    let predicted = tp[i] + fp[i];
    if predicted == 0 {
        0.0
    } else {
        tp[i] as f32 / predicted as f32
    }
};
match average {
    Average::Micro => {
        // Pool counts across classes, then divide once.
        let total_tp: usize = tp.iter().sum();
        let total_fp: usize = fp.iter().sum();
        let denom = total_tp + total_fp;
        if denom == 0 {
            0.0
        } else {
            total_tp as f32 / denom as f32
        }
    }
    Average::Macro => {
        // Unweighted mean over classes; summed lazily (no intermediate Vec).
        (0..n_classes).map(per_class).sum::<f32>() / n_classes as f32
    }
    Average::Weighted => {
        // Mean weighted by each class's true-label count (support).
        let total_support: usize = support.iter().sum();
        if total_support == 0 {
            return 0.0;
        }
        (0..n_classes)
            .map(|i| per_class(i) * support[i] as f32 / total_support as f32)
            .sum()
    }
}
}
/// Recall of integer class predictions against ground truth,
/// aggregated according to `average`.
///
/// Per-class recall is `TP / (TP + FN)`, defined as `0.0` for classes
/// that never occur in `y_true` or `y_pred`.
///
/// # Panics
/// Panics if the slices differ in length or are empty.
#[must_use]
#[provable_contracts_macros::contract("metrics-classification-v1", equation = "recall")]
pub fn recall(y_pred: &[usize], y_true: &[usize], average: Average) -> f32 {
contract_pre_recall!();
assert_eq!(y_pred.len(), y_true.len(), "Vectors must have same length");
assert!(!y_true.is_empty(), "Vectors cannot be empty");
// Number of classes = highest label seen in either vector, plus one.
let n_classes = y_true
    .iter()
    .chain(y_pred.iter())
    .max()
    .map_or(0, |&m| m + 1);
if n_classes == 0 {
    // Defensive: unreachable after the non-empty assert above; kept so the
    // divisions below can never see a zero class count.
    return 0.0;
}
let (tp, _, fn_counts, support) = compute_tp_fp_fn(y_pred, y_true, n_classes);
// Single definition of per-class recall, shared by the Macro and
// Weighted arms (previously duplicated inline in both).
let per_class = |i: usize| -> f32 {
    let actual = tp[i] + fn_counts[i];
    if actual == 0 {
        0.0
    } else {
        tp[i] as f32 / actual as f32
    }
};
match average {
    Average::Micro => {
        // Pool counts across classes, then divide once.
        let total_tp: usize = tp.iter().sum();
        let total_fn: usize = fn_counts.iter().sum();
        let denom = total_tp + total_fn;
        if denom == 0 {
            0.0
        } else {
            total_tp as f32 / denom as f32
        }
    }
    Average::Macro => {
        // Unweighted mean over classes; summed lazily (no intermediate Vec).
        (0..n_classes).map(per_class).sum::<f32>() / n_classes as f32
    }
    Average::Weighted => {
        // Mean weighted by each class's true-label count (support).
        let total_support: usize = support.iter().sum();
        if total_support == 0 {
            return 0.0;
        }
        (0..n_classes)
            .map(|i| per_class(i) * support[i] as f32 / total_support as f32)
            .sum()
    }
}
}
/// Precision for a single class: `TP / (TP + FP)`, or `0.0` when the class
/// was never predicted (zero denominator).
fn class_precision(tp: usize, fp: usize) -> f32 {
    let predicted = tp + fp;
    match predicted {
        0 => 0.0,
        _ => tp as f32 / predicted as f32,
    }
}
/// Recall for a single class: `TP / (TP + FN)`, or `0.0` when the class
/// never actually occurred (zero denominator).
fn class_recall(tp: usize, fn_count: usize) -> f32 {
    let actual = tp + fn_count;
    match actual {
        0 => 0.0,
        _ => tp as f32 / actual as f32,
    }
}
/// Harmonic mean of precision and recall; `0.0` when both inputs are zero.
fn f1_from_prec_rec(precision: f32, recall: f32) -> f32 {
    let denom = precision + recall;
    if denom == 0.0 {
        return 0.0;
    }
    2.0 * precision * recall / denom
}
/// F1 score for a single class from raw TP/FP/FN counts: the harmonic mean
/// of per-class precision and recall, with each ratio (and the final mean)
/// defined as `0.0` when its denominator is zero.
fn class_f1(tp: usize, fp: usize, fn_count: usize) -> f32 {
    let predicted = tp + fp;
    let prec = if predicted == 0 {
        0.0
    } else {
        tp as f32 / predicted as f32
    };
    let actual = tp + fn_count;
    let rec = if actual == 0 {
        0.0
    } else {
        tp as f32 / actual as f32
    };
    if prec + rec == 0.0 {
        0.0
    } else {
        2.0 * prec * rec / (prec + rec)
    }
}
/// F1 score of integer class predictions against ground truth,
/// aggregated according to `average`.
///
/// # Panics
/// Panics if the slices differ in length or are empty.
#[must_use]
#[provable_contracts_macros::contract("metrics-classification-v1", equation = "f1_score")]
pub fn f1_score(y_pred: &[usize], y_true: &[usize], average: Average) -> f32 {
contract_pre_f1_score!();
assert_eq!(y_pred.len(), y_true.len(), "Vectors must have same length");
assert!(!y_true.is_empty(), "Vectors cannot be empty");
// Number of classes = highest label seen in either vector, plus one.
let n_classes = y_true
    .iter()
    .chain(y_pred.iter())
    .max()
    .map_or(0, |&m| m + 1);
if n_classes == 0 {
    return 0.0;
}
let (true_pos, false_pos, false_neg, support) = compute_tp_fp_fn(y_pred, y_true, n_classes);
match average {
    Average::Micro => {
        // Pool all counts across classes, then compute a single F1.
        let tp_sum: usize = true_pos.iter().sum();
        let fp_sum: usize = false_pos.iter().sum();
        let fn_sum: usize = false_neg.iter().sum();
        class_f1(tp_sum, fp_sum, fn_sum)
    }
    Average::Macro => {
        // Unweighted mean of the per-class F1 scores.
        let total: f32 = (0..n_classes)
            .map(|c| class_f1(true_pos[c], false_pos[c], false_neg[c]))
            .sum();
        total / n_classes as f32
    }
    Average::Weighted => {
        // Weight each class's F1 by its share of the true labels.
        let total_support: usize = support.iter().sum();
        if total_support == 0 {
            return 0.0;
        }
        (0..n_classes)
            .map(|c| {
                class_f1(true_pos[c], false_pos[c], false_neg[c]) * support[c] as f32
                    / total_support as f32
            })
            .sum()
    }
}
}
/// Per-class precision values, indexed by class label `0..n_classes`.
///
/// # Panics
/// Panics if the slices differ in length or are empty.
#[must_use]
pub fn precision_per_class(y_pred: &[usize], y_true: &[usize]) -> Vec<f32> {
assert_eq!(y_pred.len(), y_true.len(), "Vectors must have same length");
assert!(!y_true.is_empty(), "Vectors cannot be empty");
// Highest label across both vectors determines the class count.
let n_classes = match y_pred.iter().chain(y_true.iter()).max() {
    Some(&max_label) => max_label + 1,
    None => 0,
};
let (tp, fp, _, _) = compute_tp_fp_fn(y_pred, y_true, n_classes);
(0..n_classes)
    .map(|class| class_precision(tp[class], fp[class]))
    .collect()
}
/// Per-class recall values, indexed by class label `0..n_classes`.
///
/// # Panics
/// Panics if the slices differ in length or are empty.
#[must_use]
pub fn recall_per_class(y_pred: &[usize], y_true: &[usize]) -> Vec<f32> {
assert_eq!(y_pred.len(), y_true.len(), "Vectors must have same length");
assert!(!y_true.is_empty(), "Vectors cannot be empty");
// Highest label across both vectors determines the class count.
let n_classes = match y_pred.iter().chain(y_true.iter()).max() {
    Some(&max_label) => max_label + 1,
    None => 0,
};
let (tp, _, fn_counts, _) = compute_tp_fp_fn(y_pred, y_true, n_classes);
(0..n_classes)
    .map(|class| class_recall(tp[class], fn_counts[class]))
    .collect()
}
include!("classification_include_01.rs");
#[cfg(test)]
#[path = "tests_classification_contract.rs"]
mod tests_classification_contract;