use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
use std::collections::HashMap;
use std::error::Error;
use crate::classification::curves::roc_curve;
use crate::error::{MetricsError, Result};
use crate::visualization::{
MetricVisualizer, PlotType, VisualizationData, VisualizationMetadata, VisualizationOptions,
};
/// Result of ROC computation, in `(fpr, tpr, thresholds, auc)` order.
///
/// The three vectors are parallel point-wise data; `auc` is `None` only when the
/// curve was supplied directly without an area value.
pub(crate) type ROCComputeResult = (
    Vec<f64>,    // false positive rates
    Vec<f64>,    // true positive rates
    Vec<f64>,    // score thresholds (may be empty if none were supplied)
    Option<f64>, // area under the curve
);

/// Confusion-matrix counts in `(tp, fp, tn, fn)` order.
pub(crate) type ConfusionMatrixValues = (usize, usize, usize, usize);
/// Interactive ROC-curve visualizer.
///
/// Holds either a precomputed curve (`fpr`/`tpr`/`thresholds`) or borrowed
/// label/score arrays from which the curve is derived, plus display settings and
/// the currently selected threshold.
#[derive(Debug, Clone)]
pub struct InteractiveROCVisualizer<'a, T: Clone + PartialOrd, S: Data<Elem = T>> {
    /// Precomputed true positive rates, if the curve was supplied directly.
    tpr: Option<Vec<f64>>,
    /// Precomputed false positive rates, if the curve was supplied directly.
    fpr: Option<Vec<f64>>,
    /// Precomputed thresholds matching the curve points, if known.
    thresholds: Option<Vec<f64>>,
    /// Area under the curve, if known or explicitly overridden.
    auc: Option<f64>,
    /// Plot title.
    title: String,
    /// Whether the AUC value is shown in the curve's legend entry.
    show_auc: bool,
    /// Whether the diagonal "random classifier" baseline is drawn.
    show_baseline: bool,
    /// Ground-truth labels, when the curve is derived from raw data.
    y_true: Option<&'a ArrayBase<S, Ix1>>,
    /// Classifier scores, when the curve is derived from raw data.
    y_score: Option<&'a ArrayBase<S, Ix1>>,
    /// Label value treated as the positive class.
    pos_label: Option<T>,
    /// Index of the selected threshold, if one has been chosen.
    current_threshold_idx: Option<usize>,
    /// Whether classification metrics are attached for the selected threshold.
    show_metrics: bool,
    /// Layout and widget settings for interactive rendering.
    interactive_options: InteractiveOptions,
}
/// Layout and widget settings for an interactive ROC visualization.
///
/// Every field is forwarded to the renderer as auxiliary metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InteractiveOptions {
    /// Canvas width (default `800`).
    pub width: usize,
    /// Canvas height (default `600`).
    pub height: usize,
    /// Whether to show a slider for choosing the threshold (default `true`).
    pub show_threshold_slider: bool,
    /// Whether to display metric values for the selected threshold (default `true`).
    pub show_metric_values: bool,
    /// Whether to display the confusion matrix (default `true`).
    pub show_confusion_matrix: bool,
    /// Free-form layout overrides. They are emitted with a `layout_` key prefix
    /// (default: empty).
    pub custom_layout: HashMap<String, String>,
}

impl Default for InteractiveOptions {
    /// Returns an 800x600 layout with every interactive widget enabled and no
    /// custom layout entries.
    fn default() -> Self {
        Self {
            width: 800,
            height: 600,
            show_threshold_slider: true,
            show_metric_values: true,
            show_confusion_matrix: true,
            custom_layout: HashMap::new(),
        }
    }
}
impl<'a, T, S> InteractiveROCVisualizer<'a, T, S>
where
    T: Clone + PartialOrd + 'static,
    S: Data<Elem = T>,
    f64: From<T>,
{
    /// Creates a visualizer from a precomputed ROC curve.
    ///
    /// `fpr` and `tpr` are the curve coordinates. `thresholds` are the optional score
    /// cut-offs for each point, and `auc` is an optional precomputed area.
    ///
    /// No label data is attached. As a result, [`Self::calculate_confusion_matrix`] and
    /// [`Self::calculate_metrics`] return an error for visualizers built this way.
    /// When `thresholds` is `None`, no threshold can be selected.
    pub fn new(
        fpr: Vec<f64>,
        tpr: Vec<f64>,
        thresholds: Option<Vec<f64>>,
        auc: Option<f64>,
    ) -> Self {
        InteractiveROCVisualizer {
            tpr: Some(tpr),
            fpr: Some(fpr),
            thresholds,
            auc,
            title: "Interactive ROC Curve".to_string(),
            show_auc: true,
            show_baseline: true,
            y_true: None,
            y_score: None,
            pos_label: None,
            current_threshold_idx: None,
            show_metrics: true,
            interactive_options: InteractiveOptions::default(),
        }
    }

    /// Creates a visualizer from ground-truth labels and classifier scores.
    ///
    /// The ROC curve is derived from the borrowed arrays each time it is needed.
    /// `pos_label` selects the positive class for the confusion matrix; it
    /// defaults to `1.0` when `None`.
    pub fn from_labels(
        y_true: &'a ArrayBase<S, Ix1>,
        y_score: &'a ArrayBase<S, Ix1>,
        pos_label: Option<T>,
    ) -> Self {
        InteractiveROCVisualizer {
            tpr: None,
            fpr: None,
            thresholds: None,
            auc: None,
            title: "Interactive ROC Curve".to_string(),
            show_auc: true,
            show_baseline: true,
            y_true: Some(y_true),
            y_score: Some(y_score),
            pos_label,
            current_threshold_idx: None,
            show_metrics: true,
            interactive_options: InteractiveOptions::default(),
        }
    }

    /// Sets the plot title.
    pub fn with_title(mut self, title: String) -> Self {
        self.title = title;
        self
    }

    /// Controls whether the AUC value appears in the curve's legend entry.
    pub fn with_show_auc(mut self, showauc: bool) -> Self {
        self.show_auc = showauc;
        self
    }

    /// Controls whether the diagonal random-classifier baseline is drawn.
    pub fn with_show_baseline(mut self, showbaseline: bool) -> Self {
        self.show_baseline = showbaseline;
        self
    }

    /// Overrides the AUC value (also suppresses recomputing it from labels).
    pub fn with_auc(mut self, auc: f64) -> Self {
        self.auc = Some(auc);
        self
    }

    /// Controls whether classification metrics are attached for the selected threshold.
    pub fn with_show_metrics(mut self, showmetrics: bool) -> Self {
        self.show_metrics = showmetrics;
        self
    }

    /// Replaces the interactive layout options.
    pub fn with_interactive_options(mut self, options: InteractiveOptions) -> Self {
        self.interactive_options = options;
        self
    }

    /// Selects a threshold by index.
    ///
    /// The index is not validated here. If it is out of range,
    /// [`Self::get_current_threshold_idx`] falls back to the middle threshold.
    pub fn with_threshold_index(mut self, idx: usize) -> Self {
        self.current_threshold_idx = Some(idx);
        self
    }

    /// Selects the threshold whose value is closest to `threshold`.
    ///
    /// Ties resolve to the earliest index. NaN distances never win.
    ///
    /// # Errors
    /// Propagates ROC computation errors. Returns `MetricsError::InvalidInput`
    /// when no thresholds are available.
    pub fn with_threshold_value(mut self, threshold: f64) -> Result<Self> {
        let (_, _, thresholds, _) = self.compute_roc()?;
        if thresholds.is_empty() {
            return Err(MetricsError::InvalidInput(
                "No thresholds available".to_string(),
            ));
        }
        let (closest_idx, _) = thresholds
            .iter()
            .map(|&t| (t - threshold).abs())
            .enumerate()
            .fold((0, f64::INFINITY), |best, (idx, diff)| {
                if diff < best.1 {
                    (idx, diff)
                } else {
                    best
                }
            });
        self.current_threshold_idx = Some(closest_idx);
        Ok(self)
    }

    /// Returns the ROC curve as `(fpr, tpr, thresholds, auc)`.
    ///
    /// A directly supplied curve takes precedence. Otherwise the curve is derived
    /// from the attached labels via `roc_curve`. In that case the AUC is computed
    /// with the trapezoidal rule unless one was set explicitly.
    fn compute_roc(&self) -> Result<ROCComputeResult> {
        if let (Some(fpr), Some(tpr)) = (&self.fpr, &self.tpr) {
            return Ok((
                fpr.clone(),
                tpr.clone(),
                self.thresholds.clone().unwrap_or_default(),
                self.auc,
            ));
        }
        let (y_true, y_score) = match (self.y_true, self.y_score) {
            (Some(y_true), Some(y_score)) => (y_true, y_score),
            _ => {
                return Err(MetricsError::InvalidInput(
                    "No data provided for ROC curve computation".to_string(),
                ))
            }
        };
        // NOTE(review): `roc_curve` is not given `pos_label`. If it assumes a fixed
        // positive class, the curve and the confusion-matrix metrics (which honour
        // `pos_label`) may disagree for other labels. TODO: confirm.
        let (fpr, tpr, thresholds) = roc_curve(y_true, y_score)?;
        let (fpr, tpr, thresholds) = (fpr.to_vec(), tpr.to_vec(), thresholds.to_vec());
        let auc = self
            .auc
            .unwrap_or_else(|| Self::trapezoidal_auc(&fpr, &tpr));
        Ok((fpr, tpr, thresholds, Some(auc)))
    }

    /// Computes the area under the piecewise-linear curve through the points
    /// `(fpr[i], tpr[i])` using the trapezoidal rule. Summation runs left to right
    /// starting from `0.0`.
    fn trapezoidal_auc(fpr: &[f64], tpr: &[f64]) -> f64 {
        fpr.windows(2)
            .zip(tpr.windows(2))
            .fold(0.0, |area, (x, y)| area + (x[1] - x[0]) * (y[1] + y[0]) / 2.0)
    }

    /// Returns the attached `(y_true, y_score)` arrays, or an error when the
    /// visualizer was built from a precomputed curve.
    fn required_labels(&self) -> Result<(&'a ArrayBase<S, Ix1>, &'a ArrayBase<S, Ix1>)> {
        match (self.y_true, self.y_score) {
            (Some(y_true), Some(y_score)) => Ok((y_true, y_score)),
            _ => Err(MetricsError::InvalidInput(
                "Original data required for confusion matrix calculation".to_string(),
            )),
        }
    }

    /// Looks up the threshold value at `threshold_idx` on the ROC curve.
    fn threshold_at(&self, threshold_idx: usize) -> Result<f64> {
        let (_, _, thresholds, _) = self.compute_roc()?;
        thresholds.get(threshold_idx).copied().ok_or_else(|| {
            MetricsError::InvalidArgument("Threshold index out of range".to_string())
        })
    }

    /// Tallies `(tp, fp, tn, fn)`, predicting positive whenever `score >= threshold`.
    ///
    /// A sample is actually positive when its label equals `pos_label` (or `1.0`
    /// if unset). The prediction is tracked as a boolean rather than encoded as a
    /// label value, so a positive label of `0` is handled correctly.
    fn count_outcomes(
        &self,
        y_true: &ArrayBase<S, Ix1>,
        y_score: &ArrayBase<S, Ix1>,
        threshold: f64,
    ) -> Result<ConfusionMatrixValues> {
        if y_true.len() != y_score.len() {
            return Err(MetricsError::InvalidInput(format!(
                "y_true and y_score must have the same length (got {} and {})",
                y_true.len(),
                y_score.len()
            )));
        }
        let pos_label = self
            .pos_label
            .as_ref()
            .map_or(1.0, |label| f64::from(label.clone()));
        let (mut tp, mut fp, mut tn, mut fn_) = (0, 0, 0, 0);
        for (truth, score) in y_true.iter().zip(y_score.iter()) {
            let actually_positive = f64::from(truth.clone()) == pos_label;
            let predicted_positive = f64::from(score.clone()) >= threshold;
            match (predicted_positive, actually_positive) {
                (true, true) => tp += 1,
                (true, false) => fp += 1,
                (false, false) => tn += 1,
                (false, true) => fn_ += 1,
            }
        }
        Ok((tp, fp, tn, fn_))
    }

    /// Computes the confusion matrix `(tp, fp, tn, fn)` at the threshold with
    /// index `threshold_idx`.
    ///
    /// # Errors
    /// - `MetricsError::InvalidInput` if no label data is attached, or if
    ///   `y_true`/`y_score` lengths differ.
    /// - `MetricsError::InvalidArgument` if `threshold_idx` is out of range.
    /// - Any ROC computation error, propagated as-is.
    pub fn calculate_confusion_matrix(
        &self,
        threshold_idx: usize,
    ) -> Result<ConfusionMatrixValues> {
        let (y_true, y_score) = self.required_labels()?;
        let threshold = self.threshold_at(threshold_idx)?;
        self.count_outcomes(y_true, y_score, threshold)
    }

    /// Computes classification metrics at the threshold with index `thresholdidx`.
    ///
    /// The returned keys are `accuracy`, `precision`, `recall`, `specificity`,
    /// `f1_score` and `threshold`. Ratios with a zero denominator are reported as
    /// `0.0`.
    ///
    /// # Errors
    /// Same conditions as [`Self::calculate_confusion_matrix`].
    pub fn calculate_metrics(&self, thresholdidx: usize) -> Result<HashMap<String, f64>> {
        let (y_true, y_score) = self.required_labels()?;
        // Compute the curve once and reuse the threshold for both counts and output.
        let threshold = self.threshold_at(thresholdidx)?;
        let (tp, fp, tn, fn_) = self.count_outcomes(y_true, y_score, threshold)?;

        let ratio = |num: usize, den: usize| -> f64 {
            if den > 0 {
                num as f64 / den as f64
            } else {
                0.0
            }
        };
        let accuracy = ratio(tp + tn, tp + fp + tn + fn_);
        let precision = ratio(tp, tp + fp);
        let recall = ratio(tp, tp + fn_);
        let specificity = ratio(tn, tn + fp);
        let f1 = if precision + recall > 0.0 {
            2.0 * precision * recall / (precision + recall)
        } else {
            0.0
        };

        Ok(HashMap::from([
            ("accuracy".to_string(), accuracy),
            ("precision".to_string(), precision),
            ("recall".to_string(), recall),
            ("specificity".to_string(), specificity),
            ("f1_score".to_string(), f1),
            ("threshold".to_string(), threshold),
        ]))
    }

    /// Returns the selected threshold index.
    ///
    /// Falls back to the middle threshold when no index was selected or the
    /// selected one is out of range.
    ///
    /// # Errors
    /// Propagates ROC computation errors. Returns `MetricsError::InvalidInput`
    /// when no thresholds are available.
    pub fn get_current_threshold_idx(&self) -> Result<usize> {
        let (_, _, thresholds, _) = self.compute_roc()?;
        if thresholds.is_empty() {
            return Err(MetricsError::InvalidInput(
                "No thresholds available".to_string(),
            ));
        }
        let len = thresholds.len();
        Ok(self
            .current_threshold_idx
            .filter(|&idx| idx < len)
            .unwrap_or(len / 2))
    }
}
impl<T, S> MetricVisualizer for InteractiveROCVisualizer<'_, T, S>
where
    T: Clone + PartialOrd + 'static,
    S: Data<Elem = T>,
    f64: From<T>,
{
    /// Assembles the data for rendering. This includes:
    /// - the ROC curve (`x` = FPR, `y` = TPR) and its thresholds;
    /// - the AUC, when known;
    /// - the current operating point and, optionally, its metrics;
    /// - the interactive layout settings;
    /// - the baseline and the legend names.
    ///
    /// # Errors
    /// Returns the (boxed) ROC computation error when no curve data is available.
    fn prepare_data(&self) -> std::result::Result<VisualizationData, Box<dyn Error>> {
        let (fpr, tpr, thresholds, auc) = self.compute_roc()?;

        // Resolve the current operating point before the vectors are moved into
        // `data`. `.get` avoids a panic when caller-supplied `fpr`/`tpr` (via
        // `new`) are shorter than `thresholds`; in that case the point is skipped.
        let current_point = self.get_current_threshold_idx().ok().and_then(|idx| {
            Some((idx, *fpr.get(idx)?, *tpr.get(idx)?, *thresholds.get(idx)?))
        });

        let mut data = VisualizationData::new();
        data.x = fpr;
        data.y = tpr;
        data.add_auxiliary_data("thresholds".to_string(), thresholds);
        if let Some(auc_val) = auc {
            data.add_auxiliary_metadata("auc".to_string(), auc_val.to_string());
        }

        if let Some((idx, point_x, point_y, threshold)) = current_point {
            data.add_auxiliary_data("current_point_x".to_string(), vec![point_x]);
            data.add_auxiliary_data("current_point_y".to_string(), vec![point_y]);
            data.add_auxiliary_metadata("current_threshold".to_string(), threshold.to_string());
            if self.show_metrics {
                // Metrics need label data; their absence is not an error here.
                if let Ok(metrics) = self.calculate_metrics(idx) {
                    for (name, value) in metrics {
                        data.add_auxiliary_metadata(format!("metric_{name}"), value.to_string());
                    }
                }
            }
        }

        let options = &self.interactive_options;
        for (key, value) in [
            ("interactive_width", options.width.to_string()),
            ("interactive_height", options.height.to_string()),
            ("show_threshold_slider", options.show_threshold_slider.to_string()),
            ("show_metric_values", options.show_metric_values.to_string()),
            ("show_confusion_matrix", options.show_confusion_matrix.to_string()),
        ] {
            data.add_auxiliary_metadata(key.to_string(), value);
        }
        for (key, value) in &options.custom_layout {
            data.add_auxiliary_metadata(format!("layout_{key}"), value.clone());
        }

        if self.show_baseline {
            data.add_auxiliary_data("baseline_x".to_string(), vec![0.0, 1.0]);
            data.add_auxiliary_data("baseline_y".to_string(), vec![0.0, 1.0]);
        }

        let curve_name = match auc {
            Some(auc_val) if self.show_auc => format!("ROC curve (AUC = {auc_val:.3})"),
            _ => "ROC curve".to_string(),
        };
        let mut series_names = vec![curve_name];
        if self.show_baseline {
            series_names.push("Random classifier".to_string());
        }
        series_names.push("Current threshold".to_string());
        data.add_series_names(series_names);
        Ok(data)
    }

    /// Returns the plot metadata: the title, a line plot type, and labels for
    /// both axes.
    fn get_metadata(&self) -> VisualizationMetadata {
        let mut metadata = VisualizationMetadata::new(self.title.clone());
        metadata.set_plot_type(PlotType::Line);
        metadata.set_x_label("False Positive Rate".to_string());
        metadata.set_y_label("True Positive Rate".to_string());
        metadata.set_description("Interactive ROC curve showing the trade-off between true positive rate and false positive rate. Adjust the threshold to see performance metrics.".to_string());
        metadata
    }
}
/// Builds an [`InteractiveROCVisualizer`] from precomputed ROC points.
///
/// This is a thin wrapper over [`InteractiveROCVisualizer::new`]. The label-type
/// parameters are pinned to `f64` with owned storage, because no label arrays
/// are attached.
// Public convenience entry point; it may have no callers inside the crate itself.
#[allow(dead_code)]
pub fn interactive_roc_curve_visualization(
    fpr: Vec<f64>,
    tpr: Vec<f64>,
    thresholds: Option<Vec<f64>>,
    auc: Option<f64>,
) -> InteractiveROCVisualizer<'static, f64, scirs2_core::ndarray::OwnedRepr<f64>> {
    InteractiveROCVisualizer::<'static, f64, scirs2_core::ndarray::OwnedRepr<f64>>::new(
        fpr, tpr, thresholds, auc,
    )
}
/// Builds an [`InteractiveROCVisualizer`] that derives its ROC curve from
/// ground-truth labels and classifier scores.
///
/// This is a thin wrapper over [`InteractiveROCVisualizer::from_labels`]. It borrows
/// `y_true` and `y_score` for the visualizer's lifetime `'a`.
// Public convenience entry point; it may have no callers inside the crate itself.
#[allow(dead_code)]
pub fn interactive_roc_curve_from_labels<'a, T, S>(
    y_true: &'a ArrayBase<S, Ix1>,
    y_score: &'a ArrayBase<S, Ix1>,
    pos_label: Option<T>,
) -> InteractiveROCVisualizer<'a, T, S>
where
    T: Clone + PartialOrd + 'static,
    S: Data<Elem = T>,
    f64: From<T>,
{
    <InteractiveROCVisualizer<'a, T, S>>::from_labels(y_true, y_score, pos_label)
}