burn_train/metric/auroc.rs

1use core::f64;
2use core::marker::PhantomData;
3
4use super::state::{FormatOptions, NumericMetricState};
5use super::{MetricEntry, MetricMetadata};
6use crate::metric::{Metric, MetricName, Numeric};
7use burn_core::tensor::backend::Backend;
8use burn_core::tensor::{ElementConversion, Int, Tensor};
9
10/// The Area Under the Receiver Operating Characteristic Curve (AUROC, also referred to as [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)) for binary classification.
11#[derive(Clone)]
12pub struct AurocMetric<B: Backend> {
13    name: MetricName,
14    state: NumericMetricState,
15    _b: PhantomData<B>,
16}
17
18/// The [AUROC metric](AurocMetric) input type.
19#[derive(new)]
20pub struct AurocInput<B: Backend> {
21    outputs: Tensor<B, 2>,
22    targets: Tensor<B, 1, Int>,
23}
24
25impl<B: Backend> Default for AurocMetric<B> {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl<B: Backend> AurocMetric<B> {
32    /// Creates the metric.
33    pub fn new() -> Self {
34        Self {
35            name: MetricName::new("AUROC".to_string()),
36            state: Default::default(),
37            _b: PhantomData,
38        }
39    }
40
41    fn binary_auroc(&self, probabilities: &Tensor<B, 1>, targets: &Tensor<B, 1, Int>) -> f64 {
42        let n = targets.dims()[0];
43
44        let n_pos = targets.clone().sum().into_scalar().elem::<u64>() as usize;
45
46        // Early return if we don't have both positive and negative samples
47        if n_pos == 0 || n_pos == n {
48            if n_pos == 0 {
49                log::warn!("Metric cannot be computed because all target values are negative.")
50            } else {
51                log::warn!("Metric cannot be computed because all target values are positive.")
52            }
53            return 0.0;
54        }
55
56        let pos_mask = targets.clone().equal_elem(1).int().reshape([n, 1]);
57        let neg_mask = targets.clone().equal_elem(0).int().reshape([1, n]);
58
59        let valid_pairs = pos_mask * neg_mask;
60
61        let prob_i = probabilities.clone().reshape([n, 1]).repeat_dim(1, n);
62        let prob_j = probabilities.clone().reshape([1, n]).repeat_dim(0, n);
63
64        let correct_order = prob_i.clone().greater(prob_j.clone()).int();
65
66        let ties = prob_i.equal(prob_j).int();
67
68        // Calculate AUC components
69        let num_pairs = valid_pairs.clone().sum().into_scalar().elem::<f64>();
70        let correct_pairs = (correct_order * valid_pairs.clone())
71            .sum()
72            .into_scalar()
73            .elem::<f64>();
74        let tied_pairs = (ties * valid_pairs).sum().into_scalar().elem::<f64>();
75
76        (correct_pairs + 0.5 * tied_pairs) / num_pairs
77    }
78}
79
80impl<B: Backend> Metric for AurocMetric<B> {
81    type Input = AurocInput<B>;
82
83    fn update(&mut self, input: &AurocInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
84        let [batch_size, num_classes] = input.outputs.dims();
85
86        assert_eq!(
87            num_classes, 2,
88            "Currently only binary classification is supported"
89        );
90
91        let probabilities = {
92            let exponents = input.outputs.clone().exp();
93            let sum = exponents.clone().sum_dim(1);
94            (exponents / sum)
95                .select(1, Tensor::arange(1..2, &input.outputs.device()))
96                .squeeze_dim(1)
97        };
98
99        let area_under_curve = self.binary_auroc(&probabilities, &input.targets);
100
101        self.state.update(
102            100.0 * area_under_curve,
103            batch_size,
104            FormatOptions::new(self.name()).unit("%").precision(2),
105        )
106    }
107
108    fn clear(&mut self) {
109        self.state.reset()
110    }
111
112    fn name(&self) -> MetricName {
113        self.name.clone()
114    }
115}
116
117impl<B: Backend> Numeric for AurocMetric<B> {
118    fn value(&self) -> super::NumericEntry {
119        self.state.value()
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Feeds one batch to a fresh metric and returns the resulting value (in percent).
    fn auroc_percent<const N: usize, const C: usize>(
        outputs: [[f64; C]; N],
        targets: [i32; N],
    ) -> f64 {
        let device = Default::default();
        let mut metric = AurocMetric::<TestBackend>::new();
        let input = AurocInput::new(
            Tensor::from_data(outputs, &device),
            Tensor::from_data(targets, &device),
        );
        let _entry = metric.update(&input, &MetricMetadata::fake());
        metric.value().current()
    }

    #[test]
    fn test_auroc() {
        let outputs = [
            [0.1, 0.9], // High confidence positive
            [0.7, 0.3], // Low confidence negative
            [0.6, 0.4], // Low confidence negative
            [0.2, 0.8], // High confidence positive
        ];
        let labels = [1, 0, 0, 1]; // True labels
        assert_eq!(auroc_percent(outputs, labels), 100.0);
    }

    #[test]
    fn test_auroc_perfect_separation() {
        let outputs = [[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        // Perfect AUC
        assert_eq!(auroc_percent(outputs, [1, 0, 0, 1]), 100.0);
    }

    #[test]
    fn test_auroc_random() {
        // Identical (uninformative) predictions for every sample: every pair is a tie.
        let outputs = [[0.5, 0.5]; 4];
        assert_eq!(auroc_percent(outputs, [1, 0, 0, 1]), 50.0);
    }

    #[test]
    fn test_auroc_all_one_class() {
        let outputs = [
            [0.1, 0.9], // All positives predictions
            [0.2, 0.8],
            [0.3, 0.7],
            [0.4, 0.6],
        ];
        // Only the positive class is present, so the metric falls back to zero.
        assert_eq!(auroc_percent(outputs, [1, 1, 1, 1]), 0.0);
    }

    #[test]
    #[should_panic(expected = "Currently only binary classification is supported")]
    fn test_auroc_multiclass_error() {
        // More than 2 classes not supported
        let outputs = [[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]];
        auroc_percent(outputs, [2, 1]);
    }
}
227}