1use core::marker::PhantomData;
2
3use super::state::{FormatOptions, NumericMetricState};
4use super::{MetricEntry, MetricMetadata};
5use crate::metric::{Metric, MetricName, Numeric};
6use burn_core::tensor::backend::Backend;
7use burn_core::tensor::{ElementConversion, Int, Tensor};
8
9#[derive(Clone)]
11pub struct AccuracyMetric<B: Backend> {
12 name: MetricName,
13 state: NumericMetricState,
14 pad_token: Option<usize>,
15 _b: PhantomData<B>,
16}
17
/// Input consumed by [`AccuracyMetric`]: model scores plus the expected class indices.
///
/// `#[derive(new)]` generates `AccuracyInput::new(outputs, targets)`; the
/// argument order follows the field order below, so keep them in sync with callers.
#[derive(new)]
pub struct AccuracyInput<B: Backend> {
    // Scores destructured as `[batch_size, n_classes]` by the metric; the
    // predicted class is the argmax over dimension 1.
    outputs: Tensor<B, 2>,
    // Expected class index for each of the `batch_size` samples.
    targets: Tensor<B, 1, Int>,
}
24
25impl<B: Backend> Default for AccuracyMetric<B> {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl<B: Backend> AccuracyMetric<B> {
32 pub fn new() -> Self {
34 Self {
35 name: MetricName::new("Accuracy".to_string()),
36 state: Default::default(),
37 pad_token: Default::default(),
38 _b: PhantomData,
39 }
40 }
41
42 pub fn with_pad_token(mut self, index: usize) -> Self {
44 self.pad_token = Some(index);
45 self
46 }
47}
48
49impl<B: Backend> Metric for AccuracyMetric<B> {
50 type Input = AccuracyInput<B>;
51
52 fn update(&mut self, input: &AccuracyInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
53 let targets = input.targets.clone();
54 let outputs = input.outputs.clone();
55
56 let [batch_size, _n_classes] = outputs.dims();
57
58 let outputs = outputs.argmax(1).reshape([batch_size]);
59
60 let accuracy = match self.pad_token {
61 Some(pad_token) => {
62 let mask = targets.clone().equal_elem(pad_token as i64);
63 let matches = outputs.equal(targets).float().mask_fill(mask.clone(), 0);
64 let num_pad = mask.float().sum();
65
66 let acc = matches.sum() / (num_pad.neg() + batch_size as f32);
67
68 acc.into_scalar().elem::<f64>()
69 }
70 None => {
71 outputs
72 .equal(targets)
73 .int()
74 .sum()
75 .into_scalar()
76 .elem::<f64>()
77 / batch_size as f64
78 }
79 };
80
81 self.state.update(
82 100.0 * accuracy,
83 batch_size,
84 FormatOptions::new(self.name()).unit("%").precision(2),
85 )
86 }
87
88 fn clear(&mut self) {
89 self.state.reset()
90 }
91
92 fn name(&self) -> MetricName {
93 self.name.clone()
94 }
95}
96
97impl<B: Backend> Numeric for AccuracyMetric<B> {
98 fn value(&self) -> super::NumericEntry {
99 self.state.value()
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Predictions [2, 1, 0, 1] against targets [2, 2, 1, 1]: 2 of 4 correct.
    #[test]
    fn test_accuracy_without_padding() {
        let device = Default::default();
        let mut metric = AccuracyMetric::<TestBackend>::new();

        let scores = Tensor::from_data(
            [
                [0.0, 0.2, 0.8], // argmax 2
                [1.0, 2.0, 0.5], // argmax 1
                [0.4, 0.1, 0.2], // argmax 0
                [0.6, 0.7, 0.2], // argmax 1
            ],
            &device,
        );
        let labels = Tensor::from_data([2, 2, 1, 1], &device);
        let batch = AccuracyInput::new(scores, labels);

        let _entry = metric.update(&batch, &MetricMetadata::fake());

        assert_eq!(50.0, metric.value().current());
    }

    /// The last three targets are the pad token (3) and must be ignored;
    /// of the remaining four samples, two are predicted correctly.
    #[test]
    fn test_accuracy_with_padding() {
        let device = Default::default();
        let mut metric = AccuracyMetric::<TestBackend>::new().with_pad_token(3);

        let scores = Tensor::from_data(
            [
                [0.0, 0.2, 0.8, 0.0], // argmax 2
                [1.0, 2.0, 0.5, 0.0], // argmax 1
                [0.4, 0.1, 0.2, 0.0], // argmax 0
                [0.6, 0.7, 0.2, 0.0], // argmax 1
                [0.0, 0.1, 0.2, 5.0], // argmax 3 (padded target)
                [0.0, 0.1, 0.2, 0.0], // argmax 2 (padded target)
                [0.6, 0.0, 0.2, 0.0], // argmax 0 (padded target)
            ],
            &device,
        );
        let labels = Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device);
        let batch = AccuracyInput::new(scores, labels);

        let _entry = metric.update(&batch, &MetricMetadata::fake());

        assert_eq!(50.0, metric.value().current());
    }
}