1use core::marker::PhantomData;
2
3use super::MetricMetadata;
4use super::state::{FormatOptions, NumericMetricState};
5use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, SerializedEntry};
6use burn_core::tensor::backend::Backend;
7use burn_core::tensor::{ElementConversion, Int, Tensor};
8
9#[derive(Clone)]
11pub struct AccuracyMetric<B: Backend> {
12 name: MetricName,
13 state: NumericMetricState,
14 pad_token: Option<usize>,
15 _b: PhantomData<B>,
16}
17
/// Input consumed by [`AccuracyMetric`]: per-class scores and target class indices.
///
/// `#[derive(new)]` generates `AccuracyInput::new(outputs, targets)`; the
/// argument order follows the field order below.
#[derive(new)]
pub struct AccuracyInput<B: Backend> {
    /// Scores of shape `[batch_size, n_classes]`; only the arg-max over
    /// dimension 1 is used by the metric.
    outputs: Tensor<B, 2>,
    /// Expected class index for each of the `batch_size` rows of `outputs`.
    targets: Tensor<B, 1, Int>,
}
24
25impl<B: Backend> Default for AccuracyMetric<B> {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl<B: Backend> AccuracyMetric<B> {
32 pub fn new() -> Self {
34 Self {
35 name: MetricName::new("Accuracy".to_string()),
36 state: Default::default(),
37 pad_token: Default::default(),
38 _b: PhantomData,
39 }
40 }
41
42 pub fn with_pad_token(mut self, index: usize) -> Self {
44 self.pad_token = Some(index);
45 self
46 }
47}
48
49impl<B: Backend> Metric for AccuracyMetric<B> {
50 type Input = AccuracyInput<B>;
51
52 fn update(&mut self, input: &AccuracyInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
53 let targets = input.targets.clone();
54 let outputs = input.outputs.clone();
55
56 let [batch_size, _n_classes] = outputs.dims();
57
58 let outputs = outputs.argmax(1).reshape([batch_size]);
59
60 let accuracy = match self.pad_token {
61 Some(pad_token) => {
62 let mask = targets.clone().equal_elem(pad_token as i64);
63 let matches = outputs.equal(targets).float().mask_fill(mask.clone(), 0);
64 let num_pad = mask.float().sum();
65
66 let acc = matches.sum() / (num_pad.neg() + batch_size as f32);
67
68 acc.into_scalar().elem::<f64>()
69 }
70 None => {
71 outputs
72 .equal(targets)
73 .int()
74 .sum()
75 .into_scalar()
76 .elem::<f64>()
77 / batch_size as f64
78 }
79 };
80
81 self.state.update(
82 100.0 * accuracy,
83 batch_size,
84 FormatOptions::new(self.name()).unit("%").precision(2),
85 )
86 }
87
88 fn clear(&mut self) {
89 self.state.reset()
90 }
91
92 fn name(&self) -> MetricName {
93 self.name.clone()
94 }
95
96 fn attributes(&self) -> MetricAttributes {
97 super::NumericAttributes {
98 unit: Some("%".to_string()),
99 higher_is_better: true,
100 }
101 .into()
102 }
103}
104
105impl<B: Backend> Numeric for AccuracyMetric<B> {
106 fn value(&self) -> super::NumericEntry {
107 self.state.current_value()
108 }
109
110 fn running_value(&self) -> super::NumericEntry {
111 self.state.running_value()
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_accuracy_without_padding() {
        let dev = Default::default();
        let mut metric = AccuracyMetric::<TestBackend>::new();

        // Arg-max per row: [2, 1, 0, 1]; rows 0 and 3 match the targets.
        let logits: Tensor<TestBackend, 2> = Tensor::from_data(
            [
                [0.0, 0.2, 0.8],
                [1.0, 2.0, 0.5],
                [0.4, 0.1, 0.2],
                [0.6, 0.7, 0.2],
            ],
            &dev,
        );
        let labels: Tensor<TestBackend, 1, Int> = Tensor::from_data([2, 2, 1, 1], &dev);

        let batch = AccuracyInput::new(logits, labels);
        let _entry = metric.update(&batch, &MetricMetadata::fake());

        assert_eq!(50.0, metric.value().current());
    }

    #[test]
    fn test_accuracy_with_padding() {
        let dev = Default::default();
        let mut metric = AccuracyMetric::<TestBackend>::new().with_pad_token(3);

        // The last three targets are padding (index 3) and are ignored;
        // among the first four rows, predictions [2, 1, 0, 1] hit twice.
        let logits: Tensor<TestBackend, 2> = Tensor::from_data(
            [
                [0.0, 0.2, 0.8, 0.0],
                [1.0, 2.0, 0.5, 0.0],
                [0.4, 0.1, 0.2, 0.0],
                [0.6, 0.7, 0.2, 0.0],
                [0.0, 0.1, 0.2, 5.0],
                [0.0, 0.1, 0.2, 0.0],
                [0.6, 0.0, 0.2, 0.0],
            ],
            &dev,
        );
        let labels: Tensor<TestBackend, 1, Int> =
            Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &dev);

        let batch = AccuracyInput::new(logits, labels);
        let _entry = metric.update(&batch, &MetricMetadata::fake());

        assert_eq!(50.0, metric.value().current());
    }
}