Skip to main content

burn_train/metric/
acc.rs

1use core::marker::PhantomData;
2
3use super::MetricMetadata;
4use super::state::{FormatOptions, NumericMetricState};
5use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, SerializedEntry};
6use burn_core::tensor::backend::Backend;
7use burn_core::tensor::{ElementConversion, Int, Tensor};
8
9/// The accuracy metric.
10#[derive(Clone)]
11pub struct AccuracyMetric<B: Backend> {
12    name: MetricName,
13    state: NumericMetricState,
14    pad_token: Option<usize>,
15    _b: PhantomData<B>,
16}
17
18/// The [accuracy metric](AccuracyMetric) input type.
19#[derive(new)]
20pub struct AccuracyInput<B: Backend> {
21    outputs: Tensor<B, 2>,
22    targets: Tensor<B, 1, Int>,
23}
24
25impl<B: Backend> Default for AccuracyMetric<B> {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl<B: Backend> AccuracyMetric<B> {
32    /// Creates the metric.
33    pub fn new() -> Self {
34        Self {
35            name: MetricName::new("Accuracy".to_string()),
36            state: Default::default(),
37            pad_token: Default::default(),
38            _b: PhantomData,
39        }
40    }
41
42    /// Sets the pad token.
43    pub fn with_pad_token(mut self, index: usize) -> Self {
44        self.pad_token = Some(index);
45        self
46    }
47}
48
49impl<B: Backend> Metric for AccuracyMetric<B> {
50    type Input = AccuracyInput<B>;
51
52    fn update(&mut self, input: &AccuracyInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
53        let targets = input.targets.clone();
54        let outputs = input.outputs.clone();
55
56        let [batch_size, _n_classes] = outputs.dims();
57
58        let outputs = outputs.argmax(1).reshape([batch_size]);
59
60        let accuracy = match self.pad_token {
61            Some(pad_token) => {
62                let mask = targets.clone().equal_elem(pad_token as i64);
63                let matches = outputs.equal(targets).float().mask_fill(mask.clone(), 0);
64                let num_pad = mask.float().sum();
65
66                let acc = matches.sum() / (num_pad.neg() + batch_size as f32);
67
68                acc.into_scalar().elem::<f64>()
69            }
70            None => {
71                outputs
72                    .equal(targets)
73                    .int()
74                    .sum()
75                    .into_scalar()
76                    .elem::<f64>()
77                    / batch_size as f64
78            }
79        };
80
81        self.state.update(
82            100.0 * accuracy,
83            batch_size,
84            FormatOptions::new(self.name()).unit("%").precision(2),
85        )
86    }
87
88    fn clear(&mut self) {
89        self.state.reset()
90    }
91
92    fn name(&self) -> MetricName {
93        self.name.clone()
94    }
95
96    fn attributes(&self) -> MetricAttributes {
97        super::NumericAttributes {
98            unit: Some("%".to_string()),
99            higher_is_better: true,
100        }
101        .into()
102    }
103}
104
105impl<B: Backend> Numeric for AccuracyMetric<B> {
106    fn value(&self) -> super::NumericEntry {
107        self.state.current_value()
108    }
109
110    fn running_value(&self) -> super::NumericEntry {
111        self.state.running_value()
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_accuracy_without_padding() {
        let device = Default::default();

        // Row-wise argmax is 2, 1, 0, 1; rows 0 and 3 match their targets -> 2 / 4 = 50%.
        let outputs = Tensor::<TestBackend, 2>::from_data(
            [
                [0.0, 0.2, 0.8],
                [1.0, 2.0, 0.5],
                [0.4, 0.1, 0.2],
                [0.6, 0.7, 0.2],
            ],
            &device,
        );
        let targets = Tensor::<TestBackend, 1, Int>::from_data([2, 2, 1, 1], &device);

        let mut accuracy = AccuracyMetric::<TestBackend>::new();
        let batch = AccuracyInput::new(outputs, targets);
        let _entry = accuracy.update(&batch, &MetricMetadata::fake());

        assert_eq!(accuracy.value().current(), 50.0);
    }

    #[test]
    fn test_accuracy_with_padding() {
        let device = Default::default();

        // The first four rows are scored as in the unpadded test (2 / 4 correct). The last
        // three rows have the pad target 3, so neither a predicted pad (row 4) nor a wrong
        // prediction on a pad (rows 5 and 6) may affect the result.
        let outputs = Tensor::<TestBackend, 2>::from_data(
            [
                [0.0, 0.2, 0.8, 0.0],
                [1.0, 2.0, 0.5, 0.0],
                [0.4, 0.1, 0.2, 0.0],
                [0.6, 0.7, 0.2, 0.0],
                [0.0, 0.1, 0.2, 5.0],
                [0.0, 0.1, 0.2, 0.0],
                [0.6, 0.0, 0.2, 0.0],
            ],
            &device,
        );
        let targets = Tensor::<TestBackend, 1, Int>::from_data([2, 2, 1, 1, 3, 3, 3], &device);

        let mut accuracy = AccuracyMetric::<TestBackend>::new().with_pad_token(3);
        let batch = AccuracyInput::new(outputs, targets);
        let _entry = accuracy.update(&batch, &MetricMetadata::fake());

        assert_eq!(accuracy.value().current(), 50.0);
    }
}