1use core::f64;
2use core::marker::PhantomData;
3
4use super::state::{FormatOptions, NumericMetricState};
5use super::{MetricEntry, MetricMetadata};
6use crate::metric::{Metric, MetricName, Numeric};
7use burn_core::tensor::backend::Backend;
8use burn_core::tensor::{ElementConversion, Int, Tensor};
9
10#[derive(Clone)]
12pub struct AurocMetric<B: Backend> {
13 name: MetricName,
14 state: NumericMetricState,
15 _b: PhantomData<B>,
16}
17
18#[derive(new)]
20pub struct AurocInput<B: Backend> {
21 outputs: Tensor<B, 2>,
22 targets: Tensor<B, 1, Int>,
23}
24
25impl<B: Backend> Default for AurocMetric<B> {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl<B: Backend> AurocMetric<B> {
32 pub fn new() -> Self {
34 Self {
35 name: MetricName::new("AUROC".to_string()),
36 state: Default::default(),
37 _b: PhantomData,
38 }
39 }
40
41 fn binary_auroc(&self, probabilities: &Tensor<B, 1>, targets: &Tensor<B, 1, Int>) -> f64 {
42 let n = targets.dims()[0];
43
44 let n_pos = targets.clone().sum().into_scalar().elem::<u64>() as usize;
45
46 if n_pos == 0 || n_pos == n {
48 if n_pos == 0 {
49 log::warn!("Metric cannot be computed because all target values are negative.")
50 } else {
51 log::warn!("Metric cannot be computed because all target values are positive.")
52 }
53 return 0.0;
54 }
55
56 let pos_mask = targets.clone().equal_elem(1).int().reshape([n, 1]);
57 let neg_mask = targets.clone().equal_elem(0).int().reshape([1, n]);
58
59 let valid_pairs = pos_mask * neg_mask;
60
61 let prob_i = probabilities.clone().reshape([n, 1]).repeat_dim(1, n);
62 let prob_j = probabilities.clone().reshape([1, n]).repeat_dim(0, n);
63
64 let correct_order = prob_i.clone().greater(prob_j.clone()).int();
65
66 let ties = prob_i.equal(prob_j).int();
67
68 let num_pairs = valid_pairs.clone().sum().into_scalar().elem::<f64>();
70 let correct_pairs = (correct_order * valid_pairs.clone())
71 .sum()
72 .into_scalar()
73 .elem::<f64>();
74 let tied_pairs = (ties * valid_pairs).sum().into_scalar().elem::<f64>();
75
76 (correct_pairs + 0.5 * tied_pairs) / num_pairs
77 }
78}
79
80impl<B: Backend> Metric for AurocMetric<B> {
81 type Input = AurocInput<B>;
82
83 fn update(&mut self, input: &AurocInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
84 let [batch_size, num_classes] = input.outputs.dims();
85
86 assert_eq!(
87 num_classes, 2,
88 "Currently only binary classification is supported"
89 );
90
91 let probabilities = {
92 let exponents = input.outputs.clone().exp();
93 let sum = exponents.clone().sum_dim(1);
94 (exponents / sum)
95 .select(1, Tensor::arange(1..2, &input.outputs.device()))
96 .squeeze_dim(1)
97 };
98
99 let area_under_curve = self.binary_auroc(&probabilities, &input.targets);
100
101 self.state.update(
102 100.0 * area_under_curve,
103 batch_size,
104 FormatOptions::new(self.name()).unit("%").precision(2),
105 )
106 }
107
108 fn clear(&mut self) {
109 self.state.reset()
110 }
111
112 fn name(&self) -> MetricName {
113 self.name.clone()
114 }
115}
116
117impl<B: Backend> Numeric for AurocMetric<B> {
118 fn value(&self) -> super::NumericEntry {
119 self.state.value()
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_auroc() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        let input = AurocInput::new(
            Tensor::from_data(
                [
                    [0.1, 0.9], // positive, high class-1 score
                    [0.7, 0.3], // negative, low class-1 score
                    [0.6, 0.4], // negative, low class-1 score
                    [0.2, 0.8], // positive, high class-1 score
                ],
                &device,
            ),
            Tensor::from_data([1, 0, 0, 1], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(metric.value().current(), 100.0);
    }

    #[test]
    fn test_auroc_perfect_separation() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        let input = AurocInput::new(
            Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device),
            Tensor::from_data([1, 0, 0, 1], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(metric.value().current(), 100.0);
    }

    #[test]
    fn test_auroc_partial_ranking() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        // Class-1 scores rank as: pos(0.8) > neg(0.6) > pos(0.3) > neg(0.1).
        // 3 of the 4 (positive, negative) pairs are ordered correctly -> 75%.
        let input = AurocInput::new(
            Tensor::from_data(
                [
                    [0.2, 0.8], // positive
                    [0.4, 0.6], // negative
                    [0.7, 0.3], // positive
                    [0.9, 0.1], // negative
                ],
                &device,
            ),
            Tensor::from_data([1, 0, 1, 0], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(metric.value().current(), 75.0);
    }

    #[test]
    fn test_auroc_random() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        // Identical scores: every (positive, negative) pair is a tie -> 50%.
        let input = AurocInput::new(
            Tensor::from_data(
                [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]],
                &device,
            ),
            Tensor::from_data([1, 0, 0, 1], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(metric.value().current(), 50.0);
    }

    #[test]
    fn test_auroc_all_one_class() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        // No negatives: AUROC is undefined and the metric reports 0.
        let input = AurocInput::new(
            Tensor::from_data(
                [[0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6]],
                &device,
            ),
            Tensor::from_data([1, 1, 1, 1], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(metric.value().current(), 0.0);
    }

    #[test]
    #[should_panic(expected = "Currently only binary classification is supported")]
    fn test_auroc_multiclass_error() {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();

        let input = AurocInput::new(
            Tensor::from_data([[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], &device),
            Tensor::from_data([2, 1], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
    }
}