use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
use std::error::Error;
use super::{ColorMap, MetricVisualizer, PlotType, VisualizationData, VisualizationMetadata};
use crate::classification::confusion_matrix;
use crate::error::Result;
/// Heatmap visualizer for a confusion matrix.
///
/// The matrix either comes from a precomputed `Array2<f64>` or is computed on demand
/// from borrowed `y_true` / `y_pred` label arrays. In the second case those arrays must
/// outlive the visualizer (lifetime `'a`).
#[derive(Debug, Clone)]
pub struct ConfusionMatrixVisualizer<'a, T, S>
where
T: Clone + PartialEq + std::fmt::Debug + std::hash::Hash + Ord + scirs2_core::numeric::NumCast,
S: Data<Elem = T>,
{
/// Precomputed matrix. It is only consulted when `y_true` and `y_pred` are not both set.
matrix: Array2<f64>,
/// Optional class names, used as tick labels on both axes (indices are used when `None`).
labels: Option<Vec<String>>,
/// Plot title.
title: String,
/// When `true`, each row is divided by its sum before plotting.
/// Rows whose sum is not strictly positive become all zeros.
normalize: bool,
/// Color map setting.
/// NOTE(review): not read by the data/metadata code in this module — confirm a renderer consumes it.
color_map: ColorMap,
/// Presumably controls whether cell values are drawn as text.
/// NOTE(review): likewise not read here — confirm.
includetext: bool,
/// Borrowed ground-truth labels. Together with `y_pred`, these take precedence over `matrix`.
y_true: Option<&'a ArrayBase<S, Ix2>>,
/// Borrowed predicted labels (see `y_true`).
y_pred: Option<&'a ArrayBase<S, Ix2>>,
}
impl<'a, T, S> ConfusionMatrixVisualizer<'a, T, S>
where
    T: Clone
        + PartialEq
        + std::fmt::Debug
        + std::hash::Hash
        + Ord
        + scirs2_core::numeric::NumCast
        + 'static,
    S: Data<Elem = T>,
{
    /// Creates a visualizer for a precomputed confusion matrix.
    ///
    /// `labels`, when given, are used as tick labels on both heatmap axes.
    /// Defaults:
    /// - title "Confusion Matrix"
    /// - no normalization
    /// - `ColorMap::BlueRed`
    /// - text enabled
    pub fn new(matrix: Array2<f64>, labels: Option<Vec<String>>) -> Self {
        ConfusionMatrixVisualizer {
            matrix,
            labels,
            title: "Confusion Matrix".to_string(),
            normalize: false,
            color_map: ColorMap::BlueRed,
            includetext: true,
            y_true: None,
            y_pred: None,
        }
    }

    /// Creates a visualizer that computes its matrix from borrowed label arrays.
    ///
    /// The matrix is not computed here. `confusion_matrix` is only called when the data
    /// is prepared, so any error from it is reported at that point.
    ///
    /// # Errors
    ///
    /// Currently never fails. The `Result` is kept for API compatibility.
    pub fn from_labels(
        y_true: &'a ArrayBase<S, Ix2>,
        y_pred: &'a ArrayBase<S, Ix2>,
        labels: Option<Vec<String>>,
    ) -> Result<Self> {
        Ok(ConfusionMatrixVisualizer {
            matrix: Array2::zeros((0, 0)),
            labels,
            title: "Confusion Matrix".to_string(),
            normalize: false,
            color_map: ColorMap::BlueRed,
            includetext: true,
            y_true: Some(y_true),
            y_pred: Some(y_pred),
        })
    }

    /// Enables or disables row-wise normalization.
    pub fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Sets the plot title.
    pub fn with_title(mut self, title: String) -> Self {
        self.title = title;
        self
    }

    /// Sets the color map.
    pub fn with_color_map(mut self, colormap: ColorMap) -> Self {
        self.color_map = colormap;
        self
    }

    /// Sets the `includetext` flag.
    pub fn with_includetext(mut self, includetext: bool) -> Self {
        self.includetext = includetext;
        self
    }

    /// Returns the matrix to plot, normalized row-wise when `normalize` is set.
    ///
    /// If both `y_true` and `y_pred` are present, the matrix is recomputed from them via
    /// `confusion_matrix`. The label list it returns is currently discarded. Otherwise
    /// the stored `matrix` is used.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `confusion_matrix`.
    fn get_matrix(&self) -> Result<Array2<f64>> {
        // `Option<&_>` is `Copy`, so zipping the two borrows needs no unwrap/expect.
        let matrix = match self.y_true.zip(self.y_pred) {
            Some((y_true, y_pred)) => {
                let (cm, _labels) = confusion_matrix(y_true, y_pred, None)?;
                // Convert first, so row sums are taken in f64 and cannot overflow the
                // integer count type.
                cm.mapv(|x| x as f64)
            }
            None => self.matrix.clone(),
        };
        Ok(if self.normalize {
            Self::normalize_rows(matrix)
        } else {
            matrix
        })
    }

    /// Divides every row by its sum, in place.
    ///
    /// Rows whose sum is not strictly positive (all zeros, negative, or NaN) are set to
    /// zero, so normalization never introduces NaN or infinity.
    fn normalize_rows(mut matrix: Array2<f64>) -> Array2<f64> {
        for mut row in matrix.outer_iter_mut() {
            let row_sum = row.sum();
            if row_sum > 0.0 {
                row /= row_sum;
            } else {
                row.fill(0.0);
            }
        }
        matrix
    }
}
impl<T, S> MetricVisualizer for ConfusionMatrixVisualizer<'_, T, S>
where
    T: Clone
        + PartialEq
        + std::fmt::Debug
        + std::hash::Hash
        + Ord
        + scirs2_core::numeric::NumCast
        + 'static,
    S: Data<Elem = T>,
{
    /// Converts the (optionally normalized) confusion matrix into heatmap data.
    ///
    /// Row `i` of the matrix becomes row `i` of `z`. Both axes use positions
    /// `0.0..n_classes`. Tick labels are the user-supplied `labels`, or the class indices
    /// as strings when none were given.
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    /// - computing the matrix from labels fails;
    /// - the matrix is not square. A confusion matrix needs one row and one column per
    ///   class; a non-square input would otherwise panic or silently lose columns.
    fn prepare_data(&self) -> std::result::Result<VisualizationData, Box<dyn Error>> {
        let matrix = self
            .get_matrix()
            .map_err(|e| Box::new(e) as Box<dyn Error>)?;

        let (n_rows, n_cols) = matrix.dim();
        if n_rows != n_cols {
            return Err(format!(
                "confusion matrix must be square, got {} rows and {} columns",
                n_rows, n_cols
            )
            .into());
        }
        let n_classes = n_rows;

        let z: Vec<Vec<f64>> = matrix.outer_iter().map(|row| row.to_vec()).collect();
        let positions: Vec<f64> = (0..n_classes).map(|i| i as f64).collect();
        let tick_labels: Vec<String> = self
            .labels
            .clone()
            .unwrap_or_else(|| (0..n_classes).map(|i| i.to_string()).collect());

        Ok(VisualizationData {
            x: positions.clone(),
            y: positions,
            z: Some(z),
            series_names: None,
            x_labels: Some(tick_labels.clone()),
            y_labels: Some(tick_labels),
            auxiliary_data: std::collections::HashMap::new(),
            auxiliary_metadata: std::collections::HashMap::new(),
            series: std::collections::HashMap::new(),
        })
    }

    /// Heatmap metadata: the configured title, "Predicted label" on x and "True label" on y.
    ///
    /// NOTE(review): `color_map` and `includetext` are not forwarded here or in
    /// `prepare_data` — confirm whether a renderer is expected to read them.
    fn get_metadata(&self) -> VisualizationMetadata {
        VisualizationMetadata {
            title: self.title.clone(),
            x_label: "Predicted label".to_string(),
            y_label: "True label".to_string(),
            plot_type: PlotType::Heatmap,
            description: Some(
                "Confusion matrix showing the counts of true vs. predicted class labels"
                    .to_string(),
            ),
        }
    }
}
/// Builds a boxed heatmap visualizer for a precomputed `f64` confusion matrix.
///
/// `labels`, when given, are used as tick labels on both axes. When `normalize` is
/// `true`, each row is divided by its sum. Rows whose sum is not strictly positive become
/// all zeros.
// NOTE(review): presumably allowed because this helper may be unused inside the crate.
#[allow(dead_code)]
pub fn confusion_matrix_visualization(
    matrix: Array2<f64>,
    labels: Option<Vec<String>>,
    normalize: bool,
) -> Box<dyn MetricVisualizer> {
    /// Non-generic counterpart of `ConfusionMatrixVisualizer` for plain `f64` matrices.
    /// This is likely needed because `f64` is neither `Ord` nor `Hash`, which that type's
    /// element parameter requires.
    // `color_map` and `includetext` are stored but never read, hence the allow.
    #[allow(dead_code)]
    struct F64ConfusionMatrixVisualizer {
        matrix: Array2<f64>,
        labels: Option<Vec<String>>,
        title: String,
        normalize: bool,
        color_map: ColorMap,
        includetext: bool,
    }

    impl MetricVisualizer for F64ConfusionMatrixVisualizer {
        /// Converts the (optionally normalized) matrix into heatmap data.
        ///
        /// # Errors
        ///
        /// Returns an error if the matrix is not square. Such input would otherwise panic
        /// on out-of-bounds indexing or silently drop columns.
        fn prepare_data(&self) -> std::result::Result<VisualizationData, Box<dyn Error>> {
            let (n_rows, n_cols) = self.matrix.dim();
            if n_rows != n_cols {
                return Err(format!(
                    "confusion matrix must be square, got {} rows and {} columns",
                    n_rows, n_cols
                )
                .into());
            }
            let n_classes = n_rows;

            let mut matrix = self.matrix.clone();
            if self.normalize {
                // Rows whose sum is not strictly positive (zeros, negative, NaN) are
                // zeroed instead of producing NaN/inf.
                for mut row in matrix.outer_iter_mut() {
                    let row_sum = row.sum();
                    if row_sum > 0.0 {
                        row /= row_sum;
                    } else {
                        row.fill(0.0);
                    }
                }
            }

            let z: Vec<Vec<f64>> = matrix.outer_iter().map(|row| row.to_vec()).collect();
            let positions: Vec<f64> = (0..n_classes).map(|i| i as f64).collect();
            let tick_labels: Vec<String> = self
                .labels
                .clone()
                .unwrap_or_else(|| (0..n_classes).map(|i| i.to_string()).collect());

            Ok(VisualizationData {
                x: positions.clone(),
                y: positions,
                z: Some(z),
                series_names: None,
                x_labels: Some(tick_labels.clone()),
                y_labels: Some(tick_labels),
                auxiliary_data: std::collections::HashMap::new(),
                auxiliary_metadata: std::collections::HashMap::new(),
                series: std::collections::HashMap::new(),
            })
        }

        /// Heatmap metadata: title, "Predicted label" on x, "True label" on y.
        fn get_metadata(&self) -> VisualizationMetadata {
            VisualizationMetadata {
                title: self.title.clone(),
                x_label: "Predicted label".to_string(),
                y_label: "True label".to_string(),
                plot_type: PlotType::Heatmap,
                description: Some(
                    "Confusion matrix showing the counts of true vs. predicted class labels"
                        .to_string(),
                ),
            }
        }
    }

    Box::new(F64ConfusionMatrixVisualizer {
        matrix,
        labels,
        title: "Confusion Matrix".to_string(),
        normalize,
        color_map: ColorMap::BlueRed,
        includetext: true,
    })
}
/// Builds a boxed confusion-matrix visualizer from borrowed true/predicted label arrays.
///
/// The returned visualizer borrows `y_true` and `y_pred` for `'a`. The matrix itself is
/// computed when the visualization data is prepared. `normalize` enables row-wise
/// normalization.
///
/// # Errors
///
/// Propagates any error from `ConfusionMatrixVisualizer::from_labels`.
// NOTE(review): presumably allowed because this helper may be unused inside the crate.
#[allow(dead_code)]
pub fn confusion_matrix_from_labels<'a, T, S>(
    y_true: &'a ArrayBase<S, Ix2>,
    y_pred: &'a ArrayBase<S, Ix2>,
    labels: Option<Vec<String>>,
    normalize: bool,
) -> Result<Box<dyn MetricVisualizer + 'a>>
where
    T: Clone
        + PartialEq
        + std::fmt::Debug
        + std::hash::Hash
        + Ord
        + scirs2_core::numeric::NumCast
        + 'static,
    S: Data<Elem = T>,
{
    let configured =
        ConfusionMatrixVisualizer::from_labels(y_true, y_pred, labels)?.with_normalize(normalize);
    let boxed: Box<dyn MetricVisualizer + 'a> = Box::new(configured);
    Ok(boxed)
}