use burn::module::AutodiffModule;
use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use serde::{Deserialize, Serialize};
use crate::error::Result;
use tsai_data::TSDataLoaders;
/// Confusion matrix for a classification task with a fixed number of classes.
///
/// `Serialize` and `Deserialize` are derived, so a matrix can be stored and
/// reloaded through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfusionMatrix {
/// Sample counts indexed as `matrix[target][prediction]`: the outer index is
/// the true class, the inner index is the predicted class.
pub matrix: Vec<Vec<usize>>,
/// Number of classes; `from_predictions` allocates `matrix` as
/// `n_classes` x `n_classes`.
pub n_classes: usize,
}
impl ConfusionMatrix {
    /// Builds a confusion matrix from parallel slices of predicted and true
    /// class indices.
    ///
    /// Each `(prediction, target)` pair increments `matrix[target][prediction]`.
    /// Pairs in which either index is `>= n_classes` are skipped. If the slices
    /// differ in length, the surplus elements of the longer one are ignored.
    pub fn from_predictions(predictions: &[usize], targets: &[usize], n_classes: usize) -> Self {
        let mut matrix = vec![vec![0usize; n_classes]; n_classes];
        for (&pred, &target) in predictions.iter().zip(targets) {
            // Checked lookup doubles as the range test for both indices.
            if let Some(cell) = matrix.get_mut(target).and_then(|row| row.get_mut(pred)) {
                *cell += 1;
            }
        }
        Self { matrix, n_classes }
    }

    /// Returns `part / whole` as `f32`, or `0.0` when `whole` is zero.
    fn ratio(part: usize, whole: usize) -> f32 {
        if whole == 0 {
            0.0
        } else {
            part as f32 / whole as f32
        }
    }

    /// Fraction of counted samples that lie on the diagonal.
    ///
    /// Returns `0.0` when the matrix holds no samples.
    ///
    /// # Panics
    ///
    /// Panics if `matrix` is smaller than `n_classes` x `n_classes`.
    pub fn accuracy(&self) -> f32 {
        let correct: usize = (0..self.n_classes).map(|i| self.matrix[i][i]).sum();
        let total: usize = self.matrix.iter().flatten().sum();
        Self::ratio(correct, total)
    }

    /// Precision for `class`: `TP / (TP + FP)`, where false positives are the
    /// other rows' entries in column `class`.
    ///
    /// Returns `0.0` when nothing was predicted as `class`.
    ///
    /// # Panics
    ///
    /// Panics if `class` is not a valid index into `matrix`.
    pub fn precision(&self, class: usize) -> f32 {
        let tp = self.matrix[class][class];
        let fp: usize = (0..self.n_classes)
            .filter(|&i| i != class)
            .map(|i| self.matrix[i][class])
            .sum();
        Self::ratio(tp, tp + fp)
    }

    /// Recall for `class`: `TP / (TP + FN)`, where false negatives are the
    /// other columns' entries in row `class`.
    ///
    /// Returns `0.0` when no sample has true class `class`.
    ///
    /// # Panics
    ///
    /// Panics if `class` is not a valid index into `matrix`.
    pub fn recall(&self, class: usize) -> f32 {
        let tp = self.matrix[class][class];
        let fn_val: usize = (0..self.n_classes)
            .filter(|&i| i != class)
            .map(|i| self.matrix[class][i])
            .sum();
        Self::ratio(tp, tp + fn_val)
    }

    /// F1 score for `class`: the harmonic mean of precision and recall.
    ///
    /// Returns `0.0` when both precision and recall are zero.
    ///
    /// # Panics
    ///
    /// Panics if `class` is not a valid index into `matrix`.
    pub fn f1(&self, class: usize) -> f32 {
        let p = self.precision(class);
        let r = self.recall(class);
        if p + r == 0.0 {
            0.0
        } else {
            2.0 * p * r / (p + r)
        }
    }

    /// Unweighted mean of the per-class F1 scores.
    ///
    /// Returns `0.0` for a matrix with zero classes, matching the other
    /// metrics, instead of the NaN that `0.0 / 0.0` would produce.
    pub fn macro_f1(&self) -> f32 {
        if self.n_classes == 0 {
            return 0.0;
        }
        let sum: f32 = (0..self.n_classes).map(|i| self.f1(i)).sum();
        sum / self.n_classes as f32
    }

    /// Renders the matrix as a plain-text table.
    ///
    /// The first line holds the `P<j>` (predicted) column headings; each
    /// following line starts with a `T<i>` (true) label followed by that row's
    /// counts, every column right-aligned to 7 characters.
    ///
    /// # Panics
    ///
    /// Panics if `matrix` is smaller than `n_classes` x `n_classes`.
    pub fn to_table(&self) -> String {
        // Width of the ` T<i>   ` row label (for i < 10000). The header is
        // padded by the same amount so each `P<j>` sits above its column.
        const LABEL_WIDTH: usize = 7;

        let mut s = String::new();
        s.push_str(&" ".repeat(LABEL_WIDTH));
        for j in 0..self.n_classes {
            s.push_str(&format!("{:>7}", format!("P{}", j)));
        }
        s.push('\n');
        for (i, row) in self.matrix[..self.n_classes].iter().enumerate() {
            s.push_str(&format!(" T{:<4} ", i));
            for count in &row[..self.n_classes] {
                s.push_str(&format!("{:>7}", count));
            }
            s.push('\n');
        }
        s
    }
}
/// Outcome of evaluating a classifier: per-sample outputs plus aggregate
/// accuracy figures.
#[derive(Debug, Clone)]
pub struct EvaluationResult {
/// Predicted class index for each evaluated sample.
pub predictions: Vec<usize>,
/// True class index for each sample, in the same order as `predictions`.
pub targets: Vec<usize>,
/// One inner vector of class probabilities per sample;
/// `evaluate_classification` fills it from the softmax of the model output.
pub probabilities: Vec<Vec<f32>>,
/// Fraction of correct predictions (`correct / total`), not a percentage;
/// `print_summary` multiplies it by 100 for display.
pub accuracy: f32,
/// Number of samples whose prediction equals the target.
pub correct: usize,
/// Number of evaluated samples.
pub total: usize,
}
impl EvaluationResult {
    /// Builds the confusion matrix of the stored predictions against the
    /// stored targets for `n_classes` classes.
    pub fn confusion_matrix(&self, n_classes: usize) -> ConfusionMatrix {
        let predicted = self.predictions.as_slice();
        let actual = self.targets.as_slice();
        ConfusionMatrix::from_predictions(predicted, actual, n_classes)
    }

    /// Prints the accuracy (as a percentage) and the correct/total counts to
    /// standard output.
    pub fn print_summary(&self) {
        let accuracy_pct = self.accuracy * 100.0;
        println!("Evaluation Results:");
        println!(" Accuracy: {:.2}%", accuracy_pct);
        println!(" Correct: {} / {}", self.correct, self.total);
    }

    /// Prints the confusion matrix table followed by per-class precision,
    /// recall and F1 (as percentages) and the macro-averaged F1 to standard
    /// output.
    pub fn print_confusion_matrix(&self, n_classes: usize) {
        let cm = self.confusion_matrix(n_classes);
        // Horizontal rule printed above and below the per-class rows.
        let rule = "-".repeat(42);

        println!("\nConfusion Matrix (T=True, P=Predicted):");
        println!("{}", cm.to_table());

        println!("Per-class metrics:");
        println!("{:<10} {:>10} {:>10} {:>10}", "Class", "Precision", "Recall", "F1");
        println!("{}", rule);
        (0..n_classes).for_each(|class| {
            let precision_pct = cm.precision(class) * 100.0;
            let recall_pct = cm.recall(class) * 100.0;
            let f1_pct = cm.f1(class) * 100.0;
            println!(
                "{:<10} {:>9.2}% {:>9.2}% {:>9.2}%",
                class, precision_pct, recall_pct, f1_pct
            );
        });
        println!("{}", rule);

        let macro_f1_pct = cm.macro_f1() * 100.0;
        println!("{:<10} {:>10} {:>10} {:>9.2}%", "Macro F1", "", "", macro_f1_pct);
    }
}
/// Evaluates a classifier on the validation loader of `dls` and collects the
/// per-sample predictions, targets and class probabilities.
///
/// The model is cloned and converted with `valid()` to its inner
/// (non-autodiff) module, which is what `forward_fn` receives together with
/// each input batch. Batches are created on the inner backend's default
/// device. For every batch, the class probabilities are the softmax of the
/// returned logits along dimension 1 and the prediction is the argmax along
/// that same dimension.
///
/// # Errors
///
/// Returns the error of the first batch the validation iterator fails to
/// produce.
///
/// # Panics
///
/// Panics if a batch carries no targets, if the prediction data is readable
/// neither as `i32` nor as `i64`, if targets or probabilities cannot be read
/// as `f32`, or if a batch reports a size of zero (division by zero when
/// deriving the class count).
pub fn evaluate_classification<B, M, G>(
    model: &M,
    dls: &TSDataLoaders,
    forward_fn: G,
) -> Result<EvaluationResult>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    G: Fn(&M::InnerModule, Tensor<B::InnerBackend, 3>) -> Tensor<B::InnerBackend, 2>,
{
    let eval_model = model.clone().valid();
    let device: <B::InnerBackend as Backend>::Device = Default::default();

    let mut predictions: Vec<usize> = Vec::new();
    let mut targets: Vec<usize> = Vec::new();
    let mut probabilities: Vec<Vec<f32>> = Vec::new();

    for batch_result in dls.valid().iter::<B::InnerBackend>(&device) {
        let batch = batch_result?;
        let inputs = batch.x.inner().clone();
        let labels = batch.y.expect("Evaluation requires targets");
        let [batch_size, _] = labels.dims();

        let logits = forward_fn(&eval_model, inputs);
        let probs = burn::tensor::activation::softmax(logits.clone(), 1);
        let class_ids = logits.argmax(1).squeeze::<1>(1);

        // The integer element type of the argmax output is not fixed here,
        // so try i32 first and fall back to i64.
        let raw_ids = class_ids.into_data();
        let batch_preds: Vec<usize> = match raw_ids.clone().to_vec::<i32>() {
            Ok(ids) => ids.into_iter().map(|id| id as usize).collect(),
            Err(_) => match raw_ids.to_vec::<i64>() {
                Ok(ids) => ids.into_iter().map(|id| id as usize).collect(),
                Err(_) => panic!("Unsupported prediction data type"),
            },
        };
        let batch_targets: Vec<f32> = labels.reshape([batch_size]).into_data().to_vec().unwrap();
        let flat_probs: Vec<f32> = probs.into_data().to_vec().unwrap();
        // The probabilities arrive flattened; recover the per-sample row width.
        let n_classes = flat_probs.len() / batch_size;

        predictions.extend_from_slice(&batch_preds[..batch_size]);
        targets.extend(batch_targets[..batch_size].iter().map(|&t| t as usize));
        probabilities.extend((0..batch_size).map(|row| {
            let start = row * n_classes;
            flat_probs[start..start + n_classes].to_vec()
        }));
    }

    let total = predictions.len();
    let correct = predictions
        .iter()
        .zip(targets.iter())
        .filter(|(predicted, actual)| predicted == actual)
        .count();
    let accuracy = match total {
        0 => 0.0,
        n => correct as f32 / n as f32,
    };

    Ok(EvaluationResult {
        predictions,
        targets,
        probabilities,
        accuracy,
        correct,
        total,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks individual cells and the overall accuracy on a four-sample,
    /// two-class example with one misclassification.
    #[test]
    fn test_confusion_matrix() {
        let cm = ConfusionMatrix::from_predictions(&[0, 1, 0, 1], &[0, 1, 1, 1], 2);

        // (true class, predicted class, expected count)
        for &(truth, predicted, expected) in &[(0, 0, 1), (1, 1, 2), (1, 0, 1)] {
            assert_eq!(cm.matrix[truth][predicted], expected);
        }

        let accuracy_error = (cm.accuracy() - 0.75).abs();
        assert!(accuracy_error < 1e-6);
    }

    /// Checks that a hand-built `EvaluationResult` yields a confusion matrix
    /// with the expected accuracy.
    #[test]
    fn test_evaluation_result() {
        let probabilities: Vec<Vec<f32>> = [[0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.2, 0.8]]
            .iter()
            .map(|row| row.to_vec())
            .collect();
        let result = EvaluationResult {
            predictions: vec![0, 1, 0, 1],
            targets: vec![0, 1, 1, 1],
            probabilities,
            accuracy: 0.75,
            correct: 3,
            total: 4,
        };

        assert_eq!(result.confusion_matrix(2).accuracy(), 0.75);
    }
}