1use core::marker::PhantomData;
2use std::sync::Arc;
3
4use super::state::{FormatOptions, NumericMetricState};
5use super::{MetricEntry, MetricMetadata};
6use crate::metric::{Metric, MetricName, Numeric};
7use burn_core::tensor::{ElementConversion, Int, Tensor, activation::sigmoid, backend::Backend};
8
9#[derive(Clone)]
11pub struct HammingScore<B: Backend> {
12 name: MetricName,
13 state: NumericMetricState,
14 threshold: f32,
15 sigmoid: bool,
16 _b: PhantomData<B>,
17}
18
19#[derive(new)]
21pub struct HammingScoreInput<B: Backend> {
22 outputs: Tensor<B, 2>,
23 targets: Tensor<B, 2, Int>,
24}
25
26impl<B: Backend> HammingScore<B> {
27 pub fn new() -> Self {
29 Self::default()
30 }
31
32 fn update_name(&mut self) {
33 self.name = Arc::new(format!("Hamming Score @ Threshold({})", self.threshold));
34 }
35
36 pub fn with_threshold(mut self, threshold: f32) -> Self {
38 self.threshold = threshold;
39 self.update_name();
40 self
41 }
42
43 pub fn with_sigmoid(mut self, sigmoid: bool) -> Self {
45 self.sigmoid = sigmoid;
46 self.update_name();
47 self
48 }
49}
50
51impl<B: Backend> Default for HammingScore<B> {
52 fn default() -> Self {
54 let threshold = 0.5;
55 let name = Arc::new(format!("Hamming Score @ Threshold({})", threshold));
56
57 Self {
58 name,
59 state: NumericMetricState::default(),
60 threshold,
61 sigmoid: false,
62 _b: PhantomData,
63 }
64 }
65}
66
67impl<B: Backend> Metric for HammingScore<B> {
68 type Input = HammingScoreInput<B>;
69
70 fn update(&mut self, input: &HammingScoreInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
71 let [batch_size, _n_classes] = input.outputs.dims();
72
73 let targets = input.targets.clone();
74
75 let mut outputs = input.outputs.clone();
76
77 if self.sigmoid {
78 outputs = sigmoid(outputs);
79 }
80
81 let score = outputs
82 .greater_elem(self.threshold)
83 .equal(targets.bool())
84 .float()
85 .mean()
86 .into_scalar()
87 .elem::<f64>();
88
89 self.state.update(
90 100.0 * score,
91 batch_size,
92 FormatOptions::new(self.name()).unit("%").precision(2),
93 )
94 }
95
96 fn clear(&mut self) {
97 self.state.reset()
98 }
99
100 fn name(&self) -> MetricName {
101 self.name.clone()
102 }
103}
104
105impl<B: Backend> Numeric for HammingScore<B> {
106 fn value(&self) -> super::NumericEntry {
107 self.state.value()
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Feeds the same outputs with three different target sets and checks the
    /// score reported for each batch: full agreement, full disagreement and a
    /// partial match.
    #[test]
    fn test_hamming_score() {
        let device = Default::default();
        let metadata = MetricMetadata::fake();
        let mut metric = HammingScore::<TestBackend>::new();

        let outputs = Tensor::from_data(
            [
                [0.32, 0.52, 0.38, 0.68, 0.61],
                [0.43, 0.31, 0.21, 0.63, 0.53],
                [0.44, 0.25, 0.71, 0.39, 0.73],
                [0.49, 0.37, 0.68, 0.39, 0.31],
            ],
            &device,
        );

        // Exactly what the outputs become once thresholded at the default 0.5.
        let matching_targets = Tensor::from_data(
            [
                [0, 1, 0, 1, 1],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 0, 1],
                [0, 0, 1, 0, 0],
            ],
            &device,
        );
        let input = HammingScoreInput::new(outputs.clone(), matching_targets.clone());
        let _entry = metric.update(&input, &metadata);
        assert_eq!(100.0, metric.value().current());

        // Flip every label (1 - y) so that no prediction agrees.
        let inverted_targets = matching_targets.neg().add_scalar(1);
        let input = HammingScoreInput::new(outputs.clone(), inverted_targets);
        let _entry = metric.update(&input, &metadata);
        assert_eq!(0.0, metric.value().current());

        // 15 of the 20 labels agree with the thresholded outputs.
        let partial_targets = Tensor::from_data(
            [
                [0, 1, 1, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 1, 1, 0, 0],
            ],
            &device,
        );
        let input = HammingScoreInput::new(outputs, partial_targets);
        let _entry = metric.update(&input, &metadata);
        assert_eq!(75.0, metric.value().current());
    }

    /// Metrics configured with different thresholds get different names, and
    /// equal thresholds give equal names.
    #[test]
    fn test_parameterized_unique_name() {
        let name_with = |threshold: f32| {
            HammingScore::<TestBackend>::new()
                .with_threshold(threshold)
                .name()
        };

        let name_a = name_with(0.5);
        let name_b = name_with(0.75);
        let name_c = name_with(0.5);

        assert_ne!(name_a, name_b);
        assert_eq!(name_a, name_c);
    }
}