use core::f64;
use core::marker::PhantomData;
use super::MetricMetadata;
use super::state::{FormatOptions, NumericMetricState};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{ElementConversion, Int, Tensor};
/// Area under the ROC curve (AUROC) metric.
///
/// Only binary classification is currently supported (see the `num_classes == 2`
/// assertion in `update`). Values are reported as a percentage.
#[derive(Clone)]
pub struct AurocMetric<B: Backend> {
/// Display name of the metric ("AUROC").
name: MetricName,
/// Numeric state fed with each batch's AUROC value.
state: NumericMetricState,
/// Ties the metric to a backend without storing any backend data.
_b: PhantomData<B>,
}
/// Input for [`AurocMetric`].
///
/// NOTE: the field order defines the argument order of the derived `new` constructor.
#[derive(new)]
pub struct AurocInput<B: Backend> {
/// Raw per-class scores of shape `[batch_size, num_classes]`; `update` applies a softmax
/// over dimension 1 and requires `num_classes == 2`.
outputs: Tensor<B, 2>,
/// One label per sample: `1` marks a positive sample and `0` a negative one.
targets: Tensor<B, 1, Int>,
}
impl<B: Backend> Default for AurocMetric<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> AurocMetric<B> {
    /// Creates a new AUROC metric with an empty state.
    pub fn new() -> Self {
        Self {
            name: MetricName::new("AUROC".to_string()),
            state: Default::default(),
            _b: PhantomData,
        }
    }

    /// Computes the binary AUROC of one batch.
    ///
    /// The result is the fraction of (positive, negative) sample pairs in which the positive
    /// sample has the strictly higher probability, with tied pairs counted as one half.
    ///
    /// # Arguments
    /// - `probabilities`: score of the positive class, one entry per target.
    /// - `targets`: labels; `1` marks a positive sample and `0` a negative one.
    ///
    /// # Returns
    /// A value in `[0, 1]`. If the batch has no positive or no negative sample, AUROC is
    /// undefined; a warning is logged and `0.0` is returned.
    fn binary_auroc(&self, probabilities: &Tensor<B, 1>, targets: &Tensor<B, 1, Int>) -> f64 {
        let n = targets.dims()[0];

        // Class counts come from the same masks used to build the pairs below, so the
        // degenerate-case check and the pair count can never disagree.
        let pos_mask = targets.clone().equal_elem(1).int();
        let neg_mask = targets.clone().equal_elem(0).int();
        let n_pos = pos_mask.clone().sum().into_scalar().elem::<u64>();
        let n_neg = neg_mask.clone().sum().into_scalar().elem::<u64>();

        if n_pos == 0 {
            log::warn!("Metric cannot be computed because all target values are negative.");
            return 0.0;
        }
        if n_neg == 0 {
            log::warn!("Metric cannot be computed because all target values are positive.");
            return 0.0;
        }

        // valid_pairs[i, j] == 1 exactly when sample i is positive and sample j is negative.
        let valid_pairs = pos_mask.reshape([n, 1]) * neg_mask.reshape([1, n]);

        // prob_i[i, j] = p_i and prob_j[i, j] = p_j, so pairwise comparisons are element-wise.
        let prob_i = probabilities.clone().reshape([n, 1]).repeat_dim(1, n);
        let prob_j = probabilities.clone().reshape([1, n]).repeat_dim(0, n);
        let correct_order = prob_i.clone().greater(prob_j.clone()).int();
        let ties = prob_i.equal(prob_j).int();

        let correct_pairs = (correct_order * valid_pairs.clone())
            .sum()
            .into_scalar()
            .elem::<f64>();
        let tied_pairs = (ties * valid_pairs).sum().into_scalar().elem::<f64>();

        // Every positive is paired with every negative, so the pair count follows from the
        // class counts without reducing the full n x n mask.
        let num_pairs = n_pos as f64 * n_neg as f64;

        (correct_pairs + 0.5 * tied_pairs) / num_pairs
    }
}
impl<B: Backend> Metric for AurocMetric<B> {
    type Input = AurocInput<B>;

    /// Computes the AUROC of one batch and feeds it (as a percentage) to the metric state.
    ///
    /// `input.outputs` is treated as raw per-class scores of shape `[batch_size, 2]`. A softmax
    /// over the class dimension turns them into probabilities, and the probability of class `1`
    /// is used as the ranking score. The result is passed to the numeric state together with
    /// `batch_size`.
    ///
    /// # Panics
    /// - If `outputs` does not have exactly two columns.
    /// - If the number of targets differs from the number of output rows.
    fn update(&mut self, input: &AurocInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
        let [batch_size, num_classes] = input.outputs.dims();

        assert_eq!(
            num_classes, 2,
            "Currently only binary classification is supported"
        );
        assert_eq!(
            input.targets.dims()[0],
            batch_size,
            "The number of targets must match the number of output rows"
        );

        // Softmax over the class dimension, keeping only the probability of class 1.
        // The per-row maximum is subtracted first: softmax is invariant to that shift, and it
        // keeps `exp` from overflowing to `inf` on large scores. Without it, `inf / inf` would
        // produce NaN probabilities.
        let probabilities = {
            let logits = input.outputs.clone();
            let row_max = logits.clone().max_dim(1);
            let exponents = (logits - row_max).exp();
            let sum = exponents.clone().sum_dim(1);
            (exponents / sum)
                .select(1, Tensor::arange(1..2, &input.outputs.device()))
                .squeeze_dim(1)
        };

        let area_under_curve = self.binary_auroc(&probabilities, &input.targets);

        self.state.update(
            100.0 * area_under_curve,
            batch_size,
            FormatOptions::new(self.name()).unit("%").precision(2),
        )
    }

    /// Resets the accumulated metric state.
    fn clear(&mut self) {
        self.state.reset()
    }

    /// Returns the metric name ("AUROC").
    fn name(&self) -> MetricName {
        self.name.clone()
    }
}
impl<B: Backend> Numeric for AurocMetric<B> {
fn value(&self) -> super::NumericEntry {
self.state.current_value()
}
fn running_value(&self) -> super::NumericEntry {
self.state.running_value()
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`AurocMetric`], covering perfect, tied, partial and degenerate rankings.

    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_auroc() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        let input = AurocInput::new(
            Tensor::from_data(
                [
                    [0.1, 0.9], // positive, high score
                    [0.7, 0.3], // negative, low score
                    [0.6, 0.4], // negative, low score
                    [0.2, 0.8], // positive, high score
                ],
                &device,
            ),
            Tensor::from_data([1, 0, 0, 1], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(metric.value().current(), 100.0);
    }

    #[test]
    fn test_auroc_perfect_separation() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        let input = AurocInput::new(
            Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device),
            Tensor::from_data([1, 0, 0, 1], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(metric.value().current(), 100.0);
    }

    #[test]
    fn test_auroc_random() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        // Every score is identical, so every (positive, negative) pair is a tie.
        let input = AurocInput::new(
            Tensor::from_data(
                [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]],
                &device,
            ),
            Tensor::from_data([1, 0, 0, 1], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(metric.value().current(), 50.0);
    }

    #[test]
    fn test_auroc_partial_ordering() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        // The class-1 probability increases with the second column, so the ranking scores are
        // 3 > 2 > 1 > 0. Positives {3, 1} vs negatives {2, 0}: pairs (3,2), (3,0), (1,0) are
        // ordered correctly and (1,2) is not, giving 3 / 4.
        let input = AurocInput::new(
            Tensor::from_data(
                [[0.0, 3.0], [0.0, 2.0], [0.0, 1.0], [0.0, 0.0]],
                &device,
            ),
            Tensor::from_data([1, 0, 1, 0], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(metric.value().current(), 75.0);
    }

    #[test]
    fn test_auroc_all_one_class() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        let input = AurocInput::new(
            Tensor::from_data(
                [[0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6]],
                &device,
            ),
            Tensor::from_data([1, 1, 1, 1], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(metric.value().current(), 0.0);
    }

    #[test]
    fn test_auroc_all_negative_class() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        let input = AurocInput::new(
            Tensor::from_data(
                [[0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6]],
                &device,
            ),
            Tensor::from_data([0, 0, 0, 0], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(metric.value().current(), 0.0);
    }

    #[test]
    #[should_panic(expected = "Currently only binary classification is supported")]
    fn test_auroc_multiclass_error() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        let input = AurocInput::new(
            Tensor::from_data([[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], &device),
            Tensor::from_data([2, 1], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
    }
}