1use core::marker::PhantomData;
2
3use super::state::{FormatOptions, NumericMetricState};
4use super::{MetricEntry, MetricMetadata};
5use crate::metric::{Metric, Numeric};
6use burn_core::tensor::backend::Backend;
7use burn_core::tensor::{ElementConversion, Int, Tensor};
8
9#[derive(Default)]
11pub struct AccuracyMetric<B: Backend> {
12 state: NumericMetricState,
13 pad_token: Option<usize>,
14 _b: PhantomData<B>,
15}
16
17#[derive(new)]
19pub struct AccuracyInput<B: Backend> {
20 outputs: Tensor<B, 2>,
21 targets: Tensor<B, 1, Int>,
22}
23
24impl<B: Backend> AccuracyMetric<B> {
25 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn with_pad_token(mut self, index: usize) -> Self {
32 self.pad_token = Some(index);
33 self
34 }
35}
36
37impl<B: Backend> Metric for AccuracyMetric<B> {
38 type Input = AccuracyInput<B>;
39
40 fn update(&mut self, input: &AccuracyInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
41 let targets = input.targets.clone();
42 let outputs = input.outputs.clone();
43
44 let [batch_size, _n_classes] = outputs.dims();
45
46 let outputs = outputs.argmax(1).reshape([batch_size]);
47
48 let accuracy = match self.pad_token {
49 Some(pad_token) => {
50 let mask = targets.clone().equal_elem(pad_token as i64);
51 let matches = outputs.equal(targets).float().mask_fill(mask.clone(), 0);
52 let num_pad = mask.float().sum();
53
54 let acc = matches.sum() / (num_pad.neg() + batch_size as f32);
55
56 acc.into_scalar().elem::<f64>()
57 }
58 None => {
59 outputs
60 .equal(targets)
61 .int()
62 .sum()
63 .into_scalar()
64 .elem::<f64>()
65 / batch_size as f64
66 }
67 };
68
69 self.state.update(
70 100.0 * accuracy,
71 batch_size,
72 FormatOptions::new(self.name()).unit("%").precision(2),
73 )
74 }
75
76 fn clear(&mut self) {
77 self.state.reset()
78 }
79
80 fn name(&self) -> String {
81 "Accuracy".to_string()
82 }
83}
84
85impl<B: Backend> Numeric for AccuracyMetric<B> {
86 fn value(&self) -> f64 {
87 self.state.value()
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Without a padding token every row counts: 2 of 4 predictions are hits.
    #[test]
    fn test_accuracy_without_padding() {
        let device = Default::default();
        let outputs: Tensor<TestBackend, 2> = Tensor::from_data(
            [
                [0.0, 0.2, 0.8], // predicted 2, target 2 -> hit
                [1.0, 2.0, 0.5], // predicted 1, target 2 -> miss
                [0.4, 0.1, 0.2], // predicted 0, target 1 -> miss
                [0.6, 0.7, 0.2], // predicted 1, target 1 -> hit
            ],
            &device,
        );
        let targets: Tensor<TestBackend, 1, Int> = Tensor::from_data([2, 2, 1, 1], &device);
        let input = AccuracyInput::new(outputs, targets);
        let mut metric = AccuracyMetric::<TestBackend>::new();

        let _entry = metric.update(&input, &MetricMetadata::fake());

        assert_eq!(50.0, metric.value());
    }

    /// With class 3 as padding, the last three rows are ignored and 2 of the
    /// remaining 4 predictions are hits.
    #[test]
    fn test_accuracy_with_padding() {
        let device = Default::default();
        let outputs: Tensor<TestBackend, 2> = Tensor::from_data(
            [
                [0.0, 0.2, 0.8, 0.0], // predicted 2, target 2 -> hit
                [1.0, 2.0, 0.5, 0.0], // predicted 1, target 2 -> miss
                [0.4, 0.1, 0.2, 0.0], // predicted 0, target 1 -> miss
                [0.6, 0.7, 0.2, 0.0], // predicted 1, target 1 -> hit
                [0.0, 0.1, 0.2, 5.0], // target is padding -> ignored
                [0.0, 0.1, 0.2, 0.0], // target is padding -> ignored
                [0.6, 0.0, 0.2, 0.0], // target is padding -> ignored
            ],
            &device,
        );
        let targets: Tensor<TestBackend, 1, Int> =
            Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device);
        let input = AccuracyInput::new(outputs, targets);
        let mut metric = AccuracyMetric::<TestBackend>::new().with_pad_token(3);

        let _entry = metric.update(&input, &MetricMetadata::fake());

        assert_eq!(50.0, metric.value());
    }
}