1use core::marker::PhantomData;
2use std::sync::Arc;
3
4use super::state::{FormatOptions, NumericMetricState};
5use super::{MetricEntry, MetricMetadata};
6use crate::metric::{Metric, MetricName, Numeric};
7use burn_core::tensor::backend::Backend;
8use burn_core::tensor::{ElementConversion, Int, Tensor};
9
10#[derive(Default, Clone)]
14pub struct TopKAccuracyMetric<B: Backend> {
15 name: Arc<String>,
16 k: usize,
17 state: NumericMetricState,
18 pad_token: Option<usize>,
21 _b: PhantomData<B>,
22}
23
24#[derive(new)]
26pub struct TopKAccuracyInput<B: Backend> {
27 outputs: Tensor<B, 2>,
29 targets: Tensor<B, 1, Int>,
31}
32
33impl<B: Backend> TopKAccuracyMetric<B> {
34 pub fn new(k: usize) -> Self {
36 Self {
37 name: Arc::new(format!("Top-K Accuracy @ TopK({})", k)),
38 k,
39 ..Default::default()
40 }
41 }
42
43 pub fn with_pad_token(mut self, index: usize) -> Self {
45 self.pad_token = Some(index);
46 self
47 }
48}
49
50impl<B: Backend> Metric for TopKAccuracyMetric<B> {
51 type Input = TopKAccuracyInput<B>;
52
53 fn update(&mut self, input: &TopKAccuracyInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
54 let [batch_size, _n_classes] = input.outputs.dims();
55
56 let targets = input.targets.clone().to_device(&B::Device::default());
57
58 let outputs = input
59 .outputs
60 .clone()
61 .argsort_descending(1)
62 .narrow(1, 0, self.k)
63 .to_device(&B::Device::default())
64 .reshape([batch_size, self.k]);
65
66 let (targets, num_pad) = match self.pad_token {
67 Some(pad_token) => {
68 let mask = targets.clone().equal_elem(pad_token as i64);
70 let num_pad = mask.clone().int().sum().into_scalar().elem::<f64>();
71 (targets.clone().mask_fill(mask, -1_i64), num_pad)
72 }
73 None => (targets.clone(), 0_f64),
74 };
75
76 let accuracy = targets
77 .reshape([batch_size, 1])
78 .repeat_dim(1, self.k)
79 .equal(outputs)
80 .int()
81 .sum()
82 .into_scalar()
83 .elem::<f64>()
84 / (batch_size as f64 - num_pad);
85
86 self.state.update(
87 100.0 * accuracy,
88 batch_size,
89 FormatOptions::new(self.name()).unit("%").precision(2),
90 )
91 }
92
93 fn clear(&mut self) {
94 self.state.reset()
95 }
96
97 fn name(&self) -> MetricName {
98 self.name.clone()
99 }
100}
101
102impl<B: Backend> Numeric for TopKAccuracyMetric<B> {
103 fn value(&self) -> super::NumericEntry {
104 self.state.value()
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_accuracy_without_padding() {
        let device = Default::default();
        // k = 2: rows 0 and 3 hit, rows 1 and 2 miss -> 2 / 4 = 50%.
        let outputs = Tensor::<TestBackend, 2>::from_data(
            [
                [0.0, 0.2, 0.8], // top-2 = {2, 1}, target 2: hit
                [1.0, 2.0, 0.5], // top-2 = {1, 0}, target 2: miss
                [0.4, 0.1, 0.2], // top-2 = {0, 2}, target 1: miss
                [0.6, 0.7, 0.2], // top-2 = {1, 0}, target 1: hit
            ],
            &device,
        );
        let targets = Tensor::<TestBackend, 1, Int>::from_data([2, 2, 1, 1], &device);

        let mut metric = TopKAccuracyMetric::<TestBackend>::new(2);
        let _entry = metric.update(
            &TopKAccuracyInput::new(outputs, targets),
            &MetricMetadata::fake(),
        );

        assert_eq!(50.0, metric.value().current());
    }

    #[test]
    fn test_accuracy_with_padding() {
        let device = Default::default();
        // k = 2, pad token 3: the last three rows are ignored -> 2 / (7 - 3) = 50%.
        let outputs = Tensor::<TestBackend, 2>::from_data(
            [
                [0.0, 0.2, 0.8, 0.0], // target 2: hit
                [1.0, 2.0, 0.5, 0.0], // target 2: miss
                [0.4, 0.1, 0.2, 0.0], // target 1: miss
                [0.6, 0.7, 0.2, 0.0], // target 1: hit
                [0.0, 0.1, 0.2, 5.0], // padding
                [0.0, 0.1, 0.2, 0.0], // padding
                [0.6, 0.0, 0.2, 0.0], // padding
            ],
            &device,
        );
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data([2, 2, 1, 1, 3, 3, 3], &device);

        let mut metric = TopKAccuracyMetric::<TestBackend>::new(2).with_pad_token(3);
        let _entry = metric.update(
            &TopKAccuracyInput::new(outputs, targets),
            &MetricMetadata::fake(),
        );

        assert_eq!(50.0, metric.value().current());
    }

    #[test]
    fn test_parameterized_unique_name() {
        let name_for = |k: usize| TopKAccuracyMetric::<TestBackend>::new(k).name();

        assert_ne!(name_for(2), name_for(1));
        assert_eq!(name_for(2), name_for(2));
    }
}