use crate::error::{NeuralError, Result};
use crate::utils::colors::{
colored_metric_cell, colorize, stylize, Color, ColorOptions, Style, RESET,
};
use crate::utils::evaluation::helpers::draw_line_with_coords;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::Float;
use std::fmt::{Debug, Display};
/// Receiver operating characteristic (ROC) curve for a binary classifier.
///
/// `fpr`, `tpr` and `thresholds` are parallel arrays: entry `i` pairs the false
/// positive rate `fpr[i]` and true positive rate `tpr[i]` with the score threshold
/// `thresholds[i]`. Curves built by [`ROCCurve::new`] start at `(0, 0)` with an
/// infinite threshold and end at `(1, 1)`.
#[derive(Debug, Clone)]
pub struct ROCCurve<F: Float + Debug + Display> {
    /// False positive rate at each threshold, in `[0, 1]`.
    pub fpr: Array1<F>,
    /// True positive rate at each threshold, in `[0, 1]`.
    pub tpr: Array1<F>,
    /// Score thresholds in descending order; the first entry is `+inf`.
    pub thresholds: Array1<F>,
    /// Area under the curve, computed with the trapezoidal rule.
    pub auc: F,
}
impl<F: Float + Debug + Display> ROCCurve<F> {
    /// Builds a ROC curve (and its AUC) from binary ground-truth labels and scores.
    ///
    /// Samples are ranked by descending score. One curve point is emitted per
    /// *distinct* score, after every sample sharing that score has been counted. This
    /// makes the curve and the AUC independent of the input order of tied scores; for
    /// example, a constant predictor gets an AUC of 0.5. The curve always starts at
    /// `(fpr, tpr) = (0, 0)` with an infinite threshold and ends at `(1, 1)`.
    ///
    /// # Arguments
    ///
    /// * `y_true` - ground-truth labels; every entry must be `0` or `1`.
    /// * `yscore` - predicted scores, where larger values rank a sample as more likely
    ///   to be positive. NaN scores are rejected.
    ///
    /// # Errors
    ///
    /// Returns `NeuralError::ValidationError` in any of these cases:
    /// * the two views differ in length;
    /// * a label is neither `0` nor `1`;
    /// * a score is NaN;
    /// * the labels do not contain both classes, which also covers empty input.
    pub fn new(y_true: &ArrayView1<usize>, yscore: &ArrayView1<F>) -> Result<Self> {
        if y_true.len() != yscore.len() {
            return Err(NeuralError::ValidationError(
                "Labels and scores must have the same length".to_string(),
            ));
        }
        if y_true.iter().any(|&label| label != 0 && label != 1) {
            return Err(NeuralError::ValidationError(
                "Labels must be binary (0 or 1)".to_string(),
            ));
        }
        // NaN has no place in a ranking; without this check the sort below would
        // order NaN scores arbitrarily.
        if yscore.iter().any(|&score| score.is_nan()) {
            return Err(NeuralError::ValidationError(
                "Scores must not contain NaN".to_string(),
            ));
        }
        let n_pos = y_true.iter().filter(|&&label| label == 1).count();
        let n_neg = y_true.len() - n_pos;
        if n_pos == 0 || n_neg == 0 {
            return Err(NeuralError::ValidationError(
                "Both positive and negative samples are required".to_string(),
            ));
        }

        let mut ranked: Vec<(F, usize)> = yscore
            .iter()
            .copied()
            .zip(y_true.iter().copied())
            .collect();
        // NaN was rejected above, so `partial_cmp` always succeeds here.
        ranked.sort_unstable_by(|a, b| {
            b.0.partial_cmp(&a.0)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let pos_total = F::from(n_pos).expect("Failed to convert to float");
        let neg_total = F::from(n_neg).expect("Failed to convert to float");

        let mut fpr_points = Vec::with_capacity(ranked.len() + 1);
        let mut tpr_points = Vec::with_capacity(ranked.len() + 1);
        let mut threshold_points = Vec::with_capacity(ranked.len() + 1);
        fpr_points.push(F::zero());
        tpr_points.push(F::zero());
        threshold_points.push(F::infinity());

        let mut tp = 0usize;
        let mut fp = 0usize;
        for (i, &(score, label)) in ranked.iter().enumerate() {
            if label == 1 {
                tp += 1;
            } else {
                fp += 1;
            }
            // A threshold cannot separate samples that share a score, so only emit a
            // point once the last sample of a run of equal scores has been counted.
            let continues_group =
                matches!(ranked.get(i + 1), Some(&(next, _)) if next == score);
            if !continues_group {
                threshold_points.push(score);
                tpr_points.push(F::from(tp).expect("Failed to convert to float") / pos_total);
                fpr_points.push(F::from(fp).expect("Failed to convert to float") / neg_total);
            }
        }

        // Trapezoidal rule over consecutive curve points.
        let half = F::from(0.5).expect("Failed to convert constant to float");
        let auc = fpr_points
            .windows(2)
            .zip(tpr_points.windows(2))
            .fold(F::zero(), |acc, (f, t)| {
                acc + (f[1] - f[0]) * (t[0] + t[1]) * half
            });

        Ok(ROCCurve {
            fpr: Array1::from(fpr_points),
            tpr: Array1::from(tpr_points),
            thresholds: Array1::from(threshold_points),
            auc,
        })
    }

    /// Renders the curve as an ASCII plot using the default [`ColorOptions`].
    ///
    /// See [`ROCCurve::to_ascii_with_options`] for the layout and size handling.
    pub fn to_ascii(&self, title: Option<&str>, width: usize, height: usize) -> String {
        self.to_ascii_with_options(title, width, height, &ColorOptions::default())
    }

    /// Renders the curve as an ASCII plot of `width` x `height` cells.
    ///
    /// The output contains, in order:
    /// * a header line with the title (default `"ROC Curve"`) and the AUC;
    /// * the plot grid, showing the curve (`●`) and the random-classifier diagonal (`.`);
    /// * TPR tick labels on the left of the grid;
    /// * an FPR axis with tick labels beneath the grid.
    ///
    /// When `color_options.enabled` is set, ANSI styling is applied and a legend line
    /// is appended.
    ///
    /// `width` and `height` are clamped to at least 2 so the plot can always be drawn.
    /// Rates outside `[0, 1]` (possible because the fields are public) are clamped to
    /// the plot edges.
    pub fn to_ascii_with_options(
        &self,
        title: Option<&str>,
        width: usize,
        height: usize,
        color_options: &ColorOptions,
    ) -> String {
        // Left margin of the plot: matches the width of a labelled gutter such as "0.0 |".
        const PLOT_INDENT: &str = "     ";

        // The `(n - 1)` scale factors below must be non-zero.
        let width = width.max(2);
        let height = height.max(2);
        let mut result = String::with_capacity(width * height * 2);

        let title_text = title.unwrap_or("ROC Curve");
        if color_options.enabled {
            let styled_title = stylize(title_text, Style::Bold);
            let auc_value = self.auc.to_f64().unwrap_or(0.0);
            let colored_auc =
                colored_metric_cell(format!("{:.3}", self.auc), auc_value, color_options);
            result.push_str(&format!("{styled_title} (AUC = {colored_auc})\n\n"));
        } else {
            result.push_str(&format!("{} (AUC = {:.3})\n\n", title_text, self.auc));
        }

        let x_span = (width - 1) as f64;
        let y_span = (height - 1) as f64;
        // Clamp into [0, 1] so the row/column arithmetic cannot leave the grid. A NaN
        // rate becomes 0 through the saturating float-to-int cast.
        let to_unit = |rate: F| rate.to_f64().unwrap_or(0.0).clamp(0.0, 1.0);
        let column_of = |rate: F| (to_unit(rate) * x_span).round() as usize;
        let row_of = |rate: F| height - 1 - (to_unit(rate) * y_span).round() as usize;

        let mut grid = vec![vec![' '; width]; height];

        // Random-classifier reference line (TPR == FPR), from bottom-left to top-right.
        // Step along the longer axis so the line has no gaps.
        let steps = width.max(height);
        for step in 0..steps {
            let t = step as f64 / (steps - 1) as f64;
            let x = (t * x_span).round() as usize;
            let y = height - 1 - (t * y_span).round() as usize;
            grid[y][x] = '.';
        }

        // The curve itself, starting from (fpr, tpr) = (0, 0) at the bottom-left corner.
        let mut prev = (0, height - 1);
        for (&fpr, &tpr) in self.fpr.iter().zip(self.tpr.iter()).skip(1) {
            let point = (column_of(fpr), row_of(tpr));
            if point != prev {
                for (line_x, line_y) in draw_line_with_coords(
                    prev.0,
                    prev.1,
                    point.0,
                    point.1,
                    Some(width),
                    Some(height),
                ) {
                    grid[line_y][line_x] = '●';
                }
                prev = point;
            }
        }

        for (y, row) in grid.iter().enumerate() {
            let tick = if y == height - 1 {
                Some("0.0")
            } else if y == 0 {
                Some("1.0")
            } else if y == height / 2 {
                Some("0.5")
            } else {
                None
            };
            match tick {
                Some(text) if color_options.enabled => {
                    let fg_code = Color::BrightCyan.fg_code();
                    result.push_str(&format!("{fg_code}{text}{RESET} |"));
                }
                Some(text) => {
                    result.push_str(text);
                    result.push_str(" |");
                }
                // Same width as a labelled gutter so every row's "|" lines up.
                None => result.push_str("    |"),
            }
            for &cell in row {
                match cell {
                    '●' if color_options.enabled => {
                        result.push_str(&colorize("●", Color::BrightGreen));
                    }
                    '.' if color_options.enabled => {
                        result.push_str(&colorize(".", Color::BrightBlack));
                    }
                    _ => result.push(cell),
                }
            }
            result.push('\n');
        }

        result.push_str("    +");
        result.push_str(&"-".repeat(width));
        result.push('\n');

        // "0.0" starts under the first plot column and "1.0" ends under the last one.
        // Keep at least one space between them on very narrow plots.
        let gap = " ".repeat(width.saturating_sub(6).max(1));
        result.push_str(PLOT_INDENT);
        if color_options.enabled {
            result.push_str(&colorize("0.0", Color::BrightCyan));
            result.push_str(&gap);
            result.push_str(&colorize("1.0", Color::BrightCyan));
        } else {
            result.push_str("0.0");
            result.push_str(&gap);
            result.push_str("1.0");
        }
        result.push('\n');

        if color_options.enabled {
            result.push_str(&format!(
                "{}{}\n",
                PLOT_INDENT,
                stylize("False Positive Rate (FPR)", Style::Bold)
            ));
            result.push_str(&format!(
                "{}{} ROC curve {} Random classifier\n",
                PLOT_INDENT,
                colorize("●", Color::BrightGreen),
                colorize(".", Color::BrightBlack)
            ));
        } else {
            result.push_str(PLOT_INDENT);
            result.push_str("False Positive Rate (FPR)\n");
        }
        result
    }
}