1use core::marker::PhantomData;
2use std::sync::Arc;
3
4use super::state::{FormatOptions, NumericMetricState};
5use super::{MetricMetadata, SerializedEntry};
6use crate::metric::{
7 Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,
8};
9use burn_core::tensor::{ElementConversion, Int, Tensor, activation::sigmoid, backend::Backend};
10
11#[derive(Clone)]
13pub struct HammingScore<B: Backend> {
14 name: MetricName,
15 state: NumericMetricState,
16 threshold: f32,
17 sigmoid: bool,
18 _b: PhantomData<B>,
19}
20
21#[derive(new)]
23pub struct HammingScoreInput<B: Backend> {
24 outputs: Tensor<B, 2>,
25 targets: Tensor<B, 2, Int>,
26}
27
28impl<B: Backend> HammingScore<B> {
29 pub fn new() -> Self {
31 Self::default()
32 }
33
34 fn update_name(&mut self) {
35 self.name = Arc::new(format!("Hamming Score @ Threshold({})", self.threshold));
36 }
37
38 pub fn with_threshold(mut self, threshold: f32) -> Self {
40 self.threshold = threshold;
41 self.update_name();
42 self
43 }
44
45 pub fn with_sigmoid(mut self, sigmoid: bool) -> Self {
47 self.sigmoid = sigmoid;
48 self.update_name();
49 self
50 }
51}
52
53impl<B: Backend> Default for HammingScore<B> {
54 fn default() -> Self {
56 let threshold = 0.5;
57 let name = Arc::new(format!("Hamming Score @ Threshold({})", threshold));
58
59 Self {
60 name,
61 state: NumericMetricState::default(),
62 threshold,
63 sigmoid: false,
64 _b: PhantomData,
65 }
66 }
67}
68
69impl<B: Backend> Metric for HammingScore<B> {
70 type Input = HammingScoreInput<B>;
71
72 fn update(
73 &mut self,
74 input: &HammingScoreInput<B>,
75 _metadata: &MetricMetadata,
76 ) -> SerializedEntry {
77 let [batch_size, _n_classes] = input.outputs.dims();
78
79 let targets = input.targets.clone();
80
81 let mut outputs = input.outputs.clone();
82
83 if self.sigmoid {
84 outputs = sigmoid(outputs);
85 }
86
87 let score = outputs
88 .greater_elem(self.threshold)
89 .equal(targets.bool())
90 .float()
91 .mean()
92 .into_scalar()
93 .elem::<f64>();
94
95 self.state.update(
96 100.0 * score,
97 batch_size,
98 FormatOptions::new(self.name()).unit("%").precision(2),
99 )
100 }
101
102 fn clear(&mut self) {
103 self.state.reset()
104 }
105
106 fn name(&self) -> MetricName {
107 self.name.clone()
108 }
109
110 fn attributes(&self) -> MetricAttributes {
111 NumericAttributes {
112 unit: Some("%".to_string()),
113 higher_is_better: true,
114 }
115 .into()
116 }
117}
118
119impl<B: Backend> Numeric for HammingScore<B> {
120 fn value(&self) -> NumericEntry {
121 self.state.current_value()
122 }
123
124 fn running_value(&self) -> NumericEntry {
125 self.state.running_value()
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Feeds a single batch into `metric`, ignoring the serialized entry.
    fn feed(
        metric: &mut HammingScore<TestBackend>,
        outputs: Tensor<TestBackend, 2>,
        targets: Tensor<TestBackend, 2, Int>,
    ) {
        let _entry = metric.update(
            &HammingScoreInput::new(outputs, targets),
            &MetricMetadata::fake(),
        );
    }

    #[test]
    fn test_hamming_score() {
        let device = Default::default();
        let mut metric = HammingScore::<TestBackend>::new();

        // At the default 0.5 threshold these outputs predict exactly `targets`.
        let outputs = Tensor::from_data(
            [
                [0.32, 0.52, 0.38, 0.68, 0.61],
                [0.43, 0.31, 0.21, 0.63, 0.53],
                [0.44, 0.25, 0.71, 0.39, 0.73],
                [0.49, 0.37, 0.68, 0.39, 0.31],
            ],
            &device,
        );
        let targets = Tensor::from_data(
            [
                [0, 1, 0, 1, 1],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 0, 1],
                [0, 0, 1, 0, 0],
            ],
            &device,
        );

        // Every label agrees.
        feed(&mut metric, outputs.clone(), targets.clone());
        assert_eq!(100.0, metric.value().current());

        // Every label flipped: nothing agrees.
        let flipped = targets.neg().add_scalar(1);
        feed(&mut metric, outputs.clone(), flipped);
        assert_eq!(0.0, metric.value().current());

        // 15 of the 20 labels agree.
        let partial = Tensor::from_data(
            [
                [0, 1, 1, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 1, 1, 0, 0],
            ],
            &device,
        );
        feed(&mut metric, outputs, partial);
        assert_eq!(75.0, metric.value().current());
    }

    #[test]
    fn test_parameterized_unique_name() {
        let make = |threshold| HammingScore::<TestBackend>::new().with_threshold(threshold);
        let (metric_a, metric_b, metric_c) = (make(0.5), make(0.75), make(0.5));

        assert_ne!(metric_a.name(), metric_b.name());
        assert_eq!(metric_a.name(), metric_c.name());
    }
}