burn_train/metric/
top_k_acc.rs

1use core::marker::PhantomData;
2use std::sync::Arc;
3
4use super::state::{FormatOptions, NumericMetricState};
5use super::{MetricEntry, MetricMetadata};
6use crate::metric::{Metric, MetricName, Numeric};
7use burn_core::tensor::backend::Backend;
8use burn_core::tensor::{ElementConversion, Int, Tensor};
9
10/// The Top-K accuracy metric.
11///
12/// For K=1, this is equivalent to the [accuracy metric](`super::acc::AccuracyMetric`).
13#[derive(Default, Clone)]
14pub struct TopKAccuracyMetric<B: Backend> {
15    name: Arc<String>,
16    k: usize,
17    state: NumericMetricState,
18    /// If specified, targets equal to this value will be considered padding and will not count
19    /// towards the metric
20    pad_token: Option<usize>,
21    _b: PhantomData<B>,
22}
23
24/// The [top-k accuracy metric](TopKAccuracyMetric) input type.
25#[derive(new)]
26pub struct TopKAccuracyInput<B: Backend> {
27    /// The outputs (batch_size, num_classes)
28    outputs: Tensor<B, 2>,
29    /// The labels (batch_size)
30    targets: Tensor<B, 1, Int>,
31}
32
33impl<B: Backend> TopKAccuracyMetric<B> {
34    /// Creates the metric.
35    pub fn new(k: usize) -> Self {
36        Self {
37            name: Arc::new(format!("Top-K Accuracy @ TopK({})", k)),
38            k,
39            ..Default::default()
40        }
41    }
42
43    /// Sets the pad token.
44    pub fn with_pad_token(mut self, index: usize) -> Self {
45        self.pad_token = Some(index);
46        self
47    }
48}
49
50impl<B: Backend> Metric for TopKAccuracyMetric<B> {
51    type Input = TopKAccuracyInput<B>;
52
53    fn update(&mut self, input: &TopKAccuracyInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
54        let [batch_size, _n_classes] = input.outputs.dims();
55
56        let targets = input.targets.clone().to_device(&B::Device::default());
57
58        let outputs = input
59            .outputs
60            .clone()
61            .argsort_descending(1)
62            .narrow(1, 0, self.k)
63            .to_device(&B::Device::default())
64            .reshape([batch_size, self.k]);
65
66        let (targets, num_pad) = match self.pad_token {
67            Some(pad_token) => {
68                // we ignore the samples where the target is equal to the pad token
69                let mask = targets.clone().equal_elem(pad_token as i64);
70                let num_pad = mask.clone().int().sum().into_scalar().elem::<f64>();
71                (targets.clone().mask_fill(mask, -1_i64), num_pad)
72            }
73            None => (targets.clone(), 0_f64),
74        };
75
76        let accuracy = targets
77            .reshape([batch_size, 1])
78            .repeat_dim(1, self.k)
79            .equal(outputs)
80            .int()
81            .sum()
82            .into_scalar()
83            .elem::<f64>()
84            / (batch_size as f64 - num_pad);
85
86        self.state.update(
87            100.0 * accuracy,
88            batch_size,
89            FormatOptions::new(self.name()).unit("%").precision(2),
90        )
91    }
92
93    fn clear(&mut self) {
94        self.state.reset()
95    }
96
97    fn name(&self) -> MetricName {
98        self.name.clone()
99    }
100}
101
102impl<B: Backend> Numeric for TopKAccuracyMetric<B> {
103    fn value(&self) -> super::NumericEntry {
104        self.state.value()
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Top-2 accuracy on four samples, no pad token configured.
    #[test]
    fn test_accuracy_without_padding() {
        let device = Default::default();

        // The trailing comment on each row lists its two highest-scoring class indices.
        let scores: Tensor<TestBackend, 2> = Tensor::from_data(
            [
                [0.0, 0.2, 0.8], // 2, 1
                [1.0, 2.0, 0.5], // 1, 0
                [0.4, 0.1, 0.2], // 0, 2
                [0.6, 0.7, 0.2], // 1, 0
            ],
            &device,
        );
        let labels: Tensor<TestBackend, 1, Int> = Tensor::from_data([2, 2, 1, 1], &device);
        let batch = TopKAccuracyInput::new(scores, labels);

        let mut metric = TopKAccuracyMetric::<TestBackend>::new(2);
        let _ = metric.update(&batch, &MetricMetadata::fake());

        // The first and last rows contain their label in the top two: 2 of 4 samples.
        assert_eq!(50.0, metric.value().current());
    }

    /// Top-2 accuracy where class 3 is the pad token; padded samples must be ignored.
    #[test]
    fn test_accuracy_with_padding() {
        let device = Default::default();

        let scores: Tensor<TestBackend, 2> = Tensor::from_data(
            [
                [0.0, 0.2, 0.8, 0.0], // 2, 1
                [1.0, 2.0, 0.5, 0.0], // 1, 0
                [0.4, 0.1, 0.2, 0.0], // 0, 2
                [0.6, 0.7, 0.2, 0.0], // 1, 0
                [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count
                [0.0, 0.1, 0.2, 0.0], // Error on padding should not count
                [0.6, 0.0, 0.2, 0.0], // Error on padding should not count
            ],
            &device,
        );
        let labels: Tensor<TestBackend, 1, Int> =
            Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device);
        let batch = TopKAccuracyInput::new(scores, labels);

        let mut metric = TopKAccuracyMetric::<TestBackend>::new(2).with_pad_token(3);
        let _ = metric.update(&batch, &MetricMetadata::fake());

        // Three of the seven samples are padding; 2 of the remaining 4 are correct.
        assert_eq!(50.0, metric.value().current());
    }

    /// The metric name depends on `k` only: equal `k` gives equal names.
    #[test]
    fn test_parameterized_unique_name() {
        let top_two = TopKAccuracyMetric::<TestBackend>::new(2);
        let top_one = TopKAccuracyMetric::<TestBackend>::new(1);
        let top_two_again = TopKAccuracyMetric::<TestBackend>::new(2);

        assert_ne!(top_two.name(), top_one.name());
        assert_eq!(top_two.name(), top_two_again.name());
    }
}