use crate::grammar::chart::Chart;
use crate::grammar::layer::{Layer, MarkType};
use crate::new_theme::NewTheme;
/// Extension trait that renders a ROC curve as a [`Chart`].
pub trait RocCurveExt {
    /// Builds the ROC chart styled with the caller-supplied `theme`.
    fn roc_figure_with_theme(&self, theme: NewTheme) -> Chart;

    /// Builds the ROC chart with the implementor's default styling.
    fn roc_figure(&self) -> Chart;
}
impl RocCurveExt for scry_learn::metrics::RocCurve {
    /// Builds the ROC chart using `NewTheme::default()`.
    fn roc_figure(&self) -> Chart {
        let default_theme = NewTheme::default();
        self.roc_figure_with_theme(default_theme)
    }

    /// Builds a ROC chart with the given `theme`.
    ///
    /// The chart contains two line layers:
    /// - the curve itself (`fpr` on x, `tpr` on y), labelled "ROC Curve";
    /// - the diagonal from (0, 0) to (1, 1), labelled "Random Classifier".
    ///
    /// The title reports the AUC to three decimal places, or "N/A" when the AUC is NaN.
    /// Both axes are fixed to the domain [0, 1], and the chart is sized 600 x 600.
    fn roc_figure_with_theme(&self, theme: NewTheme) -> Chart {
        // A NaN AUC cannot be formatted meaningfully, so show a placeholder instead.
        // NOTE(review): which inputs make scry_learn produce a NaN AUC is not visible here.
        let title = match self.auc {
            auc if auc.is_nan() => String::from("ROC (AUC = N/A)"),
            auc => format!("ROC (AUC = {:.3})", auc),
        };

        // The fitted model's curve. Cloned because `self` is only borrowed.
        let model_curve = Layer::new(MarkType::Line)
            .with_x(self.fpr.clone())
            .with_y(self.tpr.clone())
            .with_label("ROC Curve");

        // Chance-level reference line through (0, 0) and (1, 1).
        let chance_line = Layer::new(MarkType::Line)
            .with_x(vec![0.0, 1.0])
            .with_y(vec![0.0, 1.0])
            .with_label("Random Classifier");

        let chart = Chart::new().layer(model_curve).layer(chance_line);

        chart
            .title(title)
            .x_label("False Positive Rate")
            .y_label("True Positive Rate")
            // Both rates lie in [0, 1], so pin the axes to keep charts comparable.
            .x_domain(0.0, 1.0)
            .y_domain(0.0, 1.0)
            .theme(theme)
            .size(600.0, 600.0)
    }
}
pub fn roc_curve_figure(roc: &scry_learn::metrics::RocCurve) -> Chart {
roc.roc_figure()
}