use std::collections::{HashMap, HashSet};
/// Returns the mean squared error between predictions and targets.
///
/// The squared differences `(y_pred[i] - y_true[i])^2` are accumulated in index order and the
/// total is divided by the number of samples. For empty inputs the division is `0.0 / 0.0`,
/// so the result is NaN.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` do not have the same length.
pub fn mse(y_pred: &Vec<f64>, y_true: &Vec<f64>) -> f64 {
    let sample_count = y_pred.len();
    if sample_count != y_true.len() {
        panic!("Number of samples in predicted values does not match the number of samples in true values");
    }

    // Explicit fold from 0.0 keeps the summation order of a plain index loop.
    let squared_error_total = y_pred
        .iter()
        .zip(y_true.iter())
        .fold(0.0_f64, |running, (predicted, actual)| {
            running + (predicted - actual).powi(2)
        });

    squared_error_total / sample_count as f64
}
/// Returns the coefficient of determination (R²) of the predictions.
///
/// Computed as `1 - RSS / TSS`, where `RSS` is the sum of squared residuals
/// `(y_pred[i] - y_true[i])^2` and `TSS` is the sum of squared deviations of `y_true`
/// from its mean.
///
/// Special cases:
/// * When `TSS` is exactly zero (constant targets) the ratio is undefined. Instead of
///   returning NaN or negative infinity, the score is `1.0` if the predictions are perfect
///   (`RSS == 0`) and `0.0` otherwise. This matches the default behaviour of scikit-learn's
///   `r2_score`.
/// * For empty inputs the score is NaN.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` do not have the same length.
pub fn r2_score(y_pred: &Vec<f64>, y_true: &Vec<f64>) -> f64 {
    if y_pred.len() != y_true.len() {
        panic!("Number of samples in predicted values does not match the number of samples in true values");
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        // No samples: the mean itself is undefined, so the score is too.
        return f64::NAN;
    }

    let y_mean = y_true.iter().sum::<f64>() / n_samples as f64;

    let mut rss = 0.0_f64;
    let mut tss = 0.0_f64;
    for (&predicted, &actual) in y_pred.iter().zip(y_true.iter()) {
        rss += (predicted - actual).powi(2);
        tss += (actual - y_mean).powi(2);
    }

    if tss == 0.0 {
        // Constant targets: avoid 0/0 (NaN) and x/0 (-inf).
        return if rss == 0.0 { 1.0 } else { 0.0 };
    }

    1.0 - rss / tss
}
/// Computes overall and per-class classification accuracy.
///
/// The returned map contains:
/// * `"Overall"` — the fraction of samples where `y_pred[i] == y_true[i]`
///   (NaN for empty inputs, because that is `0 / 0`);
/// * one entry per distinct label in `y_true`, keyed by the label's decimal string
///   (e.g. `"0"`, `"3"`, `"-1"`), holding the fraction of samples with that true label
///   that were predicted correctly.
///
/// Labels may be any `i32` values; they do not have to be the contiguous range
/// `0..n_classes`.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` do not have the same length.
pub fn accuracy(y_pred: &Vec<i32>, y_true: &Vec<i32>) -> HashMap<String, f64> {
    if y_pred.len() != y_true.len() {
        panic!("Number of samples in predicted values does not match the number of samples in true values");
    }

    let mut correct = 0usize;
    // True label -> (correctly predicted, total occurrences). Keying by label rather than
    // indexing a vector by label value keeps non-contiguous or negative labels from panicking.
    let mut per_class: HashMap<i32, (usize, usize)> = HashMap::new();
    for (&predicted, &actual) in y_pred.iter().zip(y_true.iter()) {
        let tally = per_class.entry(actual).or_insert((0, 0));
        tally.1 += 1;
        if predicted == actual {
            tally.0 += 1;
            correct += 1;
        }
    }

    let mut scores = HashMap::with_capacity(per_class.len() + 1);
    scores.insert("Overall".to_string(), correct as f64 / y_true.len() as f64);
    for (label, (hits, total)) in per_class {
        // `total` is at least 1 for every label that made it into the map.
        scores.insert(label.to_string(), hits as f64 / total as f64);
    }
    scores
}
/// Builds the confusion matrix for integer class labels.
///
/// Labels are used directly as indices: `matrix[t][p]` is the number of samples whose true
/// label is `t` and whose predicted label is `p`. The matrix is square with side
/// `max_label + 1`, where `max_label` is the largest label found in either `y_true` or
/// `y_pred`, so a class that is predicted but never occurs in the truth (or the reverse)
/// still has a row and a column. Empty inputs produce an empty matrix.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` do not have the same length.
pub fn confusion_matrix(y_pred: &Vec<u8>, y_true: &Vec<u8>) -> Vec<Vec<u32>> {
    if y_pred.len() != y_true.len() {
        panic!("Number of samples in predicted values does not match the number of samples in true values");
    }

    // Size by the largest label in either vector; counting distinct true labels would
    // under-size the matrix whenever labels have gaps or a prediction is unseen in the truth.
    let n_classes = y_true
        .iter()
        .chain(y_pred.iter())
        .max()
        .map_or(0, |&largest| usize::from(largest) + 1);

    let mut matrix = vec![vec![0u32; n_classes]; n_classes];
    for (&actual, &predicted) in y_true.iter().zip(y_pred.iter()) {
        matrix[usize::from(actual)][usize::from(predicted)] += 1;
    }
    matrix
}
/// Computes per-class and averaged precision for integer class labels.
///
/// For a class `c`, precision is `TP_c / (number of samples predicted as c)`. A class that is
/// never predicted has an empty denominator and is scored `0.0`.
///
/// The returned map contains:
/// * `"precision_{c}"` for every label `c` that occurs in `y_true` or `y_pred`;
/// * `"overall"`, which depends on `method`:
///   * `"macro"` — the unweighted mean of the per-class precisions (`0.0` for empty input);
///   * `"micro"` — total true positives divided by total predictions (`0.0` for empty input).
///
/// The per-class entries are the same for both methods.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` do not have the same length, or if `method` is neither
/// `"macro"` nor `"micro"`.
pub fn precision(y_pred: &Vec<u8>, y_true: &Vec<u8>, method: &str) -> HashMap<String, f64> {
    // Number of distinct values a `u8` label can take.
    const LABEL_SPACE: usize = u8::MAX as usize + 1;

    // Divides two counts, mapping an empty denominator to 0.0 instead of NaN.
    fn ratio(numerator: usize, denominator: usize) -> f64 {
        if denominator == 0 {
            0.0
        } else {
            numerator as f64 / denominator as f64
        }
    }

    if y_pred.len() != y_true.len() {
        panic!("Number of samples in predicted values does not match the number of samples in true values");
    }
    let macro_average = match method {
        "macro" => true,
        "micro" => false,
        _ => panic!("Invalid method"),
    };

    // Fixed-size tables indexed by label keep the per-class iteration order deterministic.
    let mut true_positives = [0usize; LABEL_SPACE];
    let mut predicted_count = [0usize; LABEL_SPACE];
    let mut seen = [false; LABEL_SPACE];
    for (&predicted, &actual) in y_pred.iter().zip(y_true.iter()) {
        let (predicted, actual) = (usize::from(predicted), usize::from(actual));
        predicted_count[predicted] += 1;
        seen[predicted] = true;
        seen[actual] = true;
        if predicted == actual {
            true_positives[predicted] += 1;
        }
    }

    let mut scores = HashMap::new();
    let mut class_score_total = 0.0_f64;
    let mut class_count = 0usize;
    for label in (0..LABEL_SPACE).filter(|&label| seen[label]) {
        let score = ratio(true_positives[label], predicted_count[label]);
        scores.insert(format!("precision_{}", label), score);
        class_score_total += score;
        class_count += 1;
    }

    let overall = if macro_average {
        if class_count == 0 {
            0.0
        } else {
            class_score_total / class_count as f64
        }
    } else {
        // Every sample is predicted as exactly one class, so sum(TP + FP) == sample count.
        ratio(true_positives.iter().sum::<usize>(), y_pred.len())
    };
    scores.insert("overall".to_string(), overall);
    scores
}
/// Computes per-class and averaged recall for integer class labels.
///
/// For a class `c`, recall is `TP_c / (number of samples whose true label is c)`. A class
/// that never occurs in `y_true` has an empty denominator and is scored `0.0`.
///
/// The returned map contains:
/// * `"recall_{c}"` for every label `c` that occurs in `y_true` or `y_pred`;
/// * `"overall"`, which depends on `method`:
///   * `"macro"` — the unweighted mean of the per-class recalls (`0.0` for empty input);
///   * `"micro"` — total true positives divided by total samples (`0.0` for empty input).
///
/// The per-class entries are the same for both methods.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` do not have the same length, or if `method` is neither
/// `"macro"` nor `"micro"`.
pub fn recall(y_pred: &Vec<u8>, y_true: &Vec<u8>, method: &str) -> HashMap<String, f64> {
    // Number of distinct values a `u8` label can take.
    const LABEL_SPACE: usize = u8::MAX as usize + 1;

    // Divides two counts, mapping an empty denominator to 0.0 instead of NaN.
    fn ratio(numerator: usize, denominator: usize) -> f64 {
        if denominator == 0 {
            0.0
        } else {
            numerator as f64 / denominator as f64
        }
    }

    if y_pred.len() != y_true.len() {
        panic!("Number of samples in predicted values does not match the number of samples in true values");
    }
    let macro_average = match method {
        "macro" => true,
        "micro" => false,
        _ => panic!("Invalid method"),
    };

    // Fixed-size tables indexed by label keep the per-class iteration order deterministic.
    let mut true_positives = [0usize; LABEL_SPACE];
    let mut actual_count = [0usize; LABEL_SPACE];
    let mut seen = [false; LABEL_SPACE];
    for (&predicted, &actual) in y_pred.iter().zip(y_true.iter()) {
        let (predicted, actual) = (usize::from(predicted), usize::from(actual));
        actual_count[actual] += 1;
        seen[predicted] = true;
        seen[actual] = true;
        if predicted == actual {
            true_positives[actual] += 1;
        }
    }

    let mut scores = HashMap::new();
    let mut class_score_total = 0.0_f64;
    let mut class_count = 0usize;
    for label in (0..LABEL_SPACE).filter(|&label| seen[label]) {
        let score = ratio(true_positives[label], actual_count[label]);
        scores.insert(format!("recall_{}", label), score);
        class_score_total += score;
        class_count += 1;
    }

    let overall = if macro_average {
        if class_count == 0 {
            0.0
        } else {
            class_score_total / class_count as f64
        }
    } else {
        // Every sample has exactly one true class, so sum(TP + FN) == sample count.
        ratio(true_positives.iter().sum::<usize>(), y_true.len())
    };
    scores.insert("overall".to_string(), overall);
    scores
}
pub fn f1(y_pred: &Vec<u8>, y_true: &Vec<u8>, method: &str) -> HashMap<String, f64> {
let precision = precision(&y_pred, &y_true, method);
let recall = recall(&y_pred, &y_true, method);
let nunique = y_true.iter().collect::<HashSet<_>>().len();
let mut f1 = HashMap::new();
f1.insert(
"overall".to_string(),
2.0 * (precision.get("overall").unwrap() * recall.get("overall").unwrap())
/ (precision.get("overall").unwrap() + recall.get("overall").unwrap()),
);
for i in 0..nunique {
f1.insert(
format!("f1_{}", i),
2.0 * (precision.get(&format!("precision_{}", i)).unwrap()
* recall.get(&format!("recall_{}", i)).unwrap())
/ (precision.get(&format!("precision_{}", i)).unwrap()
+ recall.get(&format!("recall_{}", i)).unwrap()),
);
}
f1
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing floating-point metric values.
    const TOLERANCE: f64 = 1e-9;

    /// Asserts that two floats agree within `TOLERANCE`.
    ///
    /// Inputs such as 6.9 or 7.2 are not exactly representable in binary floating point, so
    /// the computed metrics differ from the decimal expectation by a few ULPs and an exact
    /// `assert_eq!` would fail.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_mse() {
        let y_pred = vec![6.9, 7.2];
        let y_true = vec![6.0, 7.0];
        let result = mse(&y_pred, &y_true);
        assert_close(result, 0.425);
    }

    #[test]
    fn test_r2_score() {
        let y_pred = vec![6.1, 7.2];
        let y_true = vec![6.0, 7.0];
        let result = r2_score(&y_pred, &y_true);
        assert_close(result, 0.9);
    }

    #[test]
    fn test_accuracy() {
        let y_pred = vec![0, 1, 1, 0, 1, 0, 1, 1, 1, 0];
        let y_true = vec![0, 1, 1, 0, 1, 0, 1, 0, 1, 0];
        let result = accuracy(&y_pred, &y_true);
        assert_close(*result.get("Overall").unwrap(), 0.9);
        assert_close(*result.get("0").unwrap(), 0.8);
        assert_close(*result.get("1").unwrap(), 1.0);
    }

    #[test]
    fn test_confusion_matrix() {
        let y_pred = vec![0, 1, 1, 0, 1, 0, 1, 1, 1, 0];
        let y_true = vec![0, 1, 1, 0, 1, 0, 1, 0, 1, 0];
        let result = confusion_matrix(&y_pred, &y_true);
        // Rows are true labels, columns are predicted labels.
        assert_eq!(result[0][0], 4);
        assert_eq!(result[0][1], 1);
        assert_eq!(result[1][0], 0);
        assert_eq!(result[1][1], 5);
    }
}