use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
use std::error::Error;
use super::{MetricVisualizer, PlotType, VisualizationData, VisualizationMetadata};
use crate::classification::curves::roc_curve;
use crate::error::{MetricsError, Result};
/// Result of an ROC computation: `(fpr, tpr, thresholds, auc)`.
///
/// `thresholds` is empty when no thresholds were supplied or computed; `auc` is `None`
/// only when a precomputed curve was given without an AUC value.
pub(crate) type ROCComputeResult = (Vec<f64>, Vec<f64>, Vec<f64>, Option<f64>);
/// Visualizer that turns ROC curve data into line-plot data.
///
/// Constructed either from a precomputed curve (`new`) or from borrowed labels and
/// scores (`from_labels`); in the latter case the curve is computed when plot data is
/// prepared.
#[derive(Debug, Clone)]
// NOTE(review): the reason for allowing dead_code is not visible here — confirm it is still needed.
#[allow(dead_code)]
pub struct ROCCurveVisualizer<'a, T, S>
where
T: Clone + PartialOrd,
S: Data<Elem = T>,
{
/// Precomputed true positive rates; `None` when built from labels.
tpr: Option<Vec<f64>>,
/// Precomputed false positive rates, paired element-wise with `tpr`; `None` when built from labels.
fpr: Option<Vec<f64>>,
/// Optional decision thresholds accompanying a precomputed curve.
thresholds: Option<Vec<f64>>,
/// Area under the curve; when `None` for label-based input it is estimated on demand.
auc: Option<f64>,
/// Plot title (defaults to "ROC Curve").
title: String,
/// Whether the legend label includes the AUC value.
show_auc: bool,
/// Whether the random-classifier diagonal is added to the plot data.
show_baseline: bool,
/// Borrowed ground-truth labels used when no precomputed curve is available.
y_true: Option<&'a ArrayBase<S, Ix1>>,
/// Borrowed prediction scores used when no precomputed curve is available.
y_score: Option<&'a ArrayBase<S, Ix1>>,
/// Positive class label supplied to `from_labels`.
/// NOTE(review): not read by any code visible in this file — confirm intended use.
pos_label: Option<T>,
}
impl<'a, T, S> ROCCurveVisualizer<'a, T, S>
where
    T: Clone + PartialOrd + 'static,
    S: Data<Elem = T>,
    f64: From<T>,
{
    /// Creates a visualizer from a precomputed ROC curve.
    ///
    /// # Arguments
    /// * `fpr` - False positive rates, paired element-wise with `tpr`.
    /// * `tpr` - True positive rates.
    /// * `thresholds` - Optional decision thresholds for the curve points.
    /// * `auc` - Optional area under the curve, shown in the legend when present.
    ///
    /// The title defaults to `"ROC Curve"`; the AUC label and baseline are enabled.
    pub fn new(
        fpr: Vec<f64>,
        tpr: Vec<f64>,
        thresholds: Option<Vec<f64>>,
        auc: Option<f64>,
    ) -> Self {
        ROCCurveVisualizer {
            tpr: Some(tpr),
            fpr: Some(fpr),
            thresholds,
            auc,
            title: "ROC Curve".to_string(),
            show_auc: true,
            show_baseline: true,
            y_true: None,
            y_score: None,
            pos_label: None,
        }
    }

    /// Creates a visualizer whose curve is computed from labels and scores when needed.
    ///
    /// # Arguments
    /// * `y_true` - Ground-truth labels (borrowed for the visualizer's lifetime).
    /// * `y_score` - Prediction scores aligned with `y_true`.
    /// * `pos_label` - Positive class label.
    ///   NOTE(review): stored but not forwarded to `roc_curve` — confirm intended use.
    pub fn from_labels(
        y_true: &'a ArrayBase<S, Ix1>,
        y_score: &'a ArrayBase<S, Ix1>,
        pos_label: Option<T>,
    ) -> Self {
        ROCCurveVisualizer {
            tpr: None,
            fpr: None,
            thresholds: None,
            auc: None,
            title: "ROC Curve".to_string(),
            show_auc: true,
            show_baseline: true,
            y_true: Some(y_true),
            y_score: Some(y_score),
            pos_label,
        }
    }

    /// Sets the plot title (defaults to `"ROC Curve"`).
    pub fn with_title(mut self, title: String) -> Self {
        self.title = title;
        self
    }

    /// Controls whether the legend label includes the AUC value when one is available.
    pub fn with_show_auc(mut self, show_auc: bool) -> Self {
        self.show_auc = show_auc;
        self
    }

    /// Controls whether the random-classifier diagonal is added to the plot data.
    pub fn with_show_baseline(mut self, show_baseline: bool) -> Self {
        self.show_baseline = show_baseline;
        self
    }

    /// Sets the AUC explicitly; label-based computation then uses it instead of estimating one.
    pub fn with_auc(mut self, auc: f64) -> Self {
        self.auc = Some(auc);
        self
    }

    /// Returns `(fpr, tpr, thresholds, auc)` for the curve.
    ///
    /// A precomputed curve (from [`Self::new`]) takes precedence and is returned as-is,
    /// with missing thresholds replaced by an empty vector. Otherwise the curve is
    /// computed from the stored labels and scores via `roc_curve`, and when no AUC was
    /// supplied it is estimated with the trapezoidal rule.
    ///
    /// # Errors
    /// Returns `MetricsError::InvalidInput` when neither a precomputed curve nor
    /// labels/scores are available, and propagates any error from `roc_curve`.
    fn compute_roc(&self) -> Result<ROCComputeResult> {
        if let (Some(fpr), Some(tpr)) = (&self.fpr, &self.tpr) {
            return Ok((
                fpr.clone(),
                tpr.clone(),
                self.thresholds.clone().unwrap_or_default(),
                self.auc,
            ));
        }

        let (y_true, y_score) = match (self.y_true, self.y_score) {
            (Some(y_true), Some(y_score)) => (y_true, y_score),
            _ => {
                return Err(MetricsError::InvalidInput(
                    "No data provided for ROC curve computation".to_string(),
                ))
            }
        };

        // NOTE(review): `self.pos_label` is not forwarded here; confirm whether
        // `roc_curve` should be told which class is positive.
        let (fpr, tpr, thresholds) = roc_curve(y_true, y_score)?;
        let fpr = fpr.to_vec();
        let tpr = tpr.to_vec();
        let auc = self
            .auc
            .unwrap_or_else(|| Self::trapezoid_auc(&fpr, &tpr));
        Ok((fpr, tpr, thresholds.to_vec(), Some(auc)))
    }

    /// Area under the piecewise-linear curve through `(x[i], y[i])` (trapezoidal rule).
    ///
    /// Fewer than two points yield `0.0`; surplus points in the longer slice are ignored
    /// rather than causing an out-of-bounds panic.
    fn trapezoid_auc(x: &[f64], y: &[f64]) -> f64 {
        // Explicit fold from +0.0 (not `sum`) so an empty curve reports 0.000, not -0.000.
        x.windows(2)
            .zip(y.windows(2))
            .fold(0.0, |area, (xs, ys)| {
                area + (xs[1] - xs[0]) * (ys[1] + ys[0]) / 2.0
            })
    }
}
impl<T, S> MetricVisualizer for ROCCurveVisualizer<'_, T, S>
where
    T: Clone + PartialOrd + 'static,
    S: Data<Elem = T>,
    f64: From<T>,
{
    /// Builds line-plot data from the ROC curve.
    ///
    /// `x` holds false positive rates and `y` true positive rates. When the baseline is
    /// enabled, the points `(0, 0)` and `(1, 1)` are appended and a second series name
    /// is added.
    ///
    /// # Errors
    /// Propagates (boxed) any error from computing the ROC curve.
    fn prepare_data(&self) -> std::result::Result<VisualizationData, Box<dyn Error>> {
        // `?` boxes the MetricsError via std's `From<E: Error> for Box<dyn Error>`.
        let (mut x, mut y, _thresholds, auc) = self.compute_roc()?;

        if self.show_baseline {
            // NOTE(review): the baseline is appended to the same x/y vectors as the ROC
            // curve rather than kept as a separate series; a renderer that draws x/y as a
            // single polyline will join the last ROC point to (0, 0) — confirm how the
            // renderer splits the two series names.
            x.extend_from_slice(&[0.0, 1.0]);
            y.extend_from_slice(&[0.0, 1.0]);
        }

        let curve_name = match auc {
            Some(area) if self.show_auc => format!("ROC curve (AUC = {:.3})", area),
            _ => "ROC curve".to_string(),
        };
        let mut series_names = vec![curve_name];
        if self.show_baseline {
            series_names.push("Random classifier".to_string());
        }

        Ok(VisualizationData {
            x,
            y,
            z: None,
            series_names: Some(series_names),
            x_labels: None,
            y_labels: None,
            auxiliary_data: std::collections::HashMap::new(),
            auxiliary_metadata: std::collections::HashMap::new(),
            series: std::collections::HashMap::new(),
        })
    }

    /// Returns axis labels, title, and a line plot type for the ROC curve.
    fn get_metadata(&self) -> VisualizationMetadata {
        VisualizationMetadata {
            title: self.title.clone(),
            x_label: "False Positive Rate".to_string(),
            y_label: "True Positive Rate".to_string(),
            plot_type: PlotType::Line,
            description: Some(
                "ROC curve showing the trade-off between true positive rate and false positive rate"
                    .to_string(),
            ),
        }
    }
}
/// Builds a [`ROCCurveVisualizer`] from an already-computed ROC curve.
///
/// Thin wrapper over `ROCCurveVisualizer::new`. No label/score arrays are borrowed in this
/// mode, so the generic parameters are pinned to `f64` with owned storage and a
/// `'static` lifetime.
// NOTE(review): reason for allowing dead_code is not visible here — confirm it is still needed.
#[allow(dead_code)]
pub fn roc_curve_visualization(
    fpr: Vec<f64>,
    tpr: Vec<f64>,
    thresholds: Option<Vec<f64>>,
    auc: Option<f64>,
) -> ROCCurveVisualizer<'static, f64, scirs2_core::ndarray::OwnedRepr<f64>> {
    ROCCurveVisualizer::<'static, f64, scirs2_core::ndarray::OwnedRepr<f64>>::new(
        fpr, tpr, thresholds, auc,
    )
}
/// Builds a [`ROCCurveVisualizer`] that derives its curve from labels and scores.
///
/// Thin wrapper over `ROCCurveVisualizer::from_labels`; both arrays stay borrowed for
/// `'a`, and the curve is computed when plot data is prepared.
// NOTE(review): reason for allowing dead_code is not visible here — confirm it is still needed.
#[allow(dead_code)]
pub fn roc_curve_from_labels<'a, T, S>(
    y_true: &'a ArrayBase<S, Ix1>,
    y_score: &'a ArrayBase<S, Ix1>,
    pos_label: Option<T>,
) -> ROCCurveVisualizer<'a, T, S>
where
    T: Clone + PartialOrd + 'static,
    S: Data<Elem = T>,
    f64: From<T>,
{
    ROCCurveVisualizer::<'a, T, S>::from_labels(y_true, y_score, pos_label)
}