burn_train/metric/acc.rs

1use core::marker::PhantomData;
2
3use super::state::{FormatOptions, NumericMetricState};
4use super::{MetricEntry, MetricMetadata};
5use crate::metric::{Metric, Numeric};
6use burn_core::tensor::backend::Backend;
7use burn_core::tensor::{ElementConversion, Int, Tensor};
8
9/// The accuracy metric.
10#[derive(Default)]
11pub struct AccuracyMetric<B: Backend> {
12    state: NumericMetricState,
13    pad_token: Option<usize>,
14    _b: PhantomData<B>,
15}
16
17/// The [accuracy metric](AccuracyMetric) input type.
18#[derive(new)]
19pub struct AccuracyInput<B: Backend> {
20    outputs: Tensor<B, 2>,
21    targets: Tensor<B, 1, Int>,
22}
23
24impl<B: Backend> AccuracyMetric<B> {
25    /// Creates the metric.
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    /// Sets the pad token.
31    pub fn with_pad_token(mut self, index: usize) -> Self {
32        self.pad_token = Some(index);
33        self
34    }
35}
36
37impl<B: Backend> Metric for AccuracyMetric<B> {
38    type Input = AccuracyInput<B>;
39
40    fn update(&mut self, input: &AccuracyInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
41        let targets = input.targets.clone();
42        let outputs = input.outputs.clone();
43
44        let [batch_size, _n_classes] = outputs.dims();
45
46        let outputs = outputs.argmax(1).reshape([batch_size]);
47
48        let accuracy = match self.pad_token {
49            Some(pad_token) => {
50                let mask = targets.clone().equal_elem(pad_token as i64);
51                let matches = outputs.equal(targets).float().mask_fill(mask.clone(), 0);
52                let num_pad = mask.float().sum();
53
54                let acc = matches.sum() / (num_pad.neg() + batch_size as f32);
55
56                acc.into_scalar().elem::<f64>()
57            }
58            None => {
59                outputs
60                    .equal(targets)
61                    .int()
62                    .sum()
63                    .into_scalar()
64                    .elem::<f64>()
65                    / batch_size as f64
66            }
67        };
68
69        self.state.update(
70            100.0 * accuracy,
71            batch_size,
72            FormatOptions::new(self.name()).unit("%").precision(2),
73        )
74    }
75
76    fn clear(&mut self) {
77        self.state.reset()
78    }
79
80    fn name(&self) -> String {
81        "Accuracy".to_string()
82    }
83}
84
85impl<B: Backend> Numeric for AccuracyMetric<B> {
86    fn value(&self) -> f64 {
87        self.state.value()
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Feeds a single batch into `metric` and returns the value the metric then reports.
    fn evaluate(
        metric: &mut AccuracyMetric<TestBackend>,
        outputs: Tensor<TestBackend, 2>,
        targets: Tensor<TestBackend, 1, Int>,
    ) -> f64 {
        let input = AccuracyInput::new(outputs, targets);
        let _ = metric.update(&input, &MetricMetadata::fake());
        metric.value()
    }

    #[test]
    fn test_accuracy_without_padding() {
        let device = Default::default();
        // Row-wise argmax is 2, 1, 0, 1 against targets 2, 2, 1, 1: two of four correct.
        let outputs: Tensor<TestBackend, 2> = Tensor::from_data(
            [
                [0.0, 0.2, 0.8],
                [1.0, 2.0, 0.5],
                [0.4, 0.1, 0.2],
                [0.6, 0.7, 0.2],
            ],
            &device,
        );
        let targets: Tensor<TestBackend, 1, Int> = Tensor::from_data([2, 2, 1, 1], &device);

        let mut metric = AccuracyMetric::<TestBackend>::new();
        assert_eq!(50.0, evaluate(&mut metric, outputs, targets));
    }

    #[test]
    fn test_accuracy_with_padding() {
        let device = Default::default();
        // The first four rows match the unpadded case (two of four correct). The last three
        // rows have the pad target 3 and must be ignored, whether the prediction is the pad
        // class (row 5) or a wrong class (rows 6 and 7).
        let outputs: Tensor<TestBackend, 2> = Tensor::from_data(
            [
                [0.0, 0.2, 0.8, 0.0],
                [1.0, 2.0, 0.5, 0.0],
                [0.4, 0.1, 0.2, 0.0],
                [0.6, 0.7, 0.2, 0.0],
                [0.0, 0.1, 0.2, 5.0],
                [0.0, 0.1, 0.2, 0.0],
                [0.6, 0.0, 0.2, 0.0],
            ],
            &device,
        );
        let targets: Tensor<TestBackend, 1, Int> =
            Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device);

        let mut metric = AccuracyMetric::<TestBackend>::new().with_pad_token(3);
        assert_eq!(50.0, evaluate(&mut metric, outputs, targets));
    }
}
140}