1use core::marker::PhantomData;
2
3use super::state::{FormatOptions, NumericMetricState};
4use super::{MetricEntry, MetricMetadata};
5use crate::metric::{Metric, Numeric};
6use burn_core::tensor::{activation::sigmoid, backend::Backend, ElementConversion, Int, Tensor};
7
/// Hamming score metric for multi-label classification.
///
/// Reports, as a percentage, the fraction of individual labels whose thresholded
/// prediction matches the corresponding target, computed over the whole batch.
pub struct HammingScore<B: Backend> {
    /// Numeric state that records and formats the computed score.
    state: NumericMetricState,
    /// Outputs strictly greater than this value are predicted as positive labels.
    threshold: f32,
    /// When `true`, a sigmoid is applied to the outputs before thresholding.
    sigmoid: bool,
    /// Ties the metric to backend `B` without storing a value of that type.
    _b: PhantomData<B>,
}
15
/// Input type for the [`HammingScore`] metric.
#[derive(new)]
pub struct HammingScoreInput<B: Backend> {
    /// Model outputs, a rank-2 tensor read as `[batch_size, n_classes]`. These are
    /// thresholded directly, or passed through a sigmoid first when the metric is
    /// configured with `with_sigmoid(true)`.
    outputs: Tensor<B, 2>,
    /// Target labels, compared element-wise with the thresholded outputs.
    /// NOTE(review): presumably 0/1 per label with the same shape as `outputs` — confirm.
    targets: Tensor<B, 2, Int>,
}
22
23impl<B: Backend> HammingScore<B> {
24 pub fn new() -> Self {
26 Self::default()
27 }
28
29 pub fn with_threshold(mut self, threshold: f32) -> Self {
31 self.threshold = threshold;
32 self
33 }
34
35 pub fn with_sigmoid(mut self, sigmoid: bool) -> Self {
37 self.sigmoid = sigmoid;
38 self
39 }
40}
41
42impl<B: Backend> Default for HammingScore<B> {
43 fn default() -> Self {
45 Self {
46 state: NumericMetricState::default(),
47 threshold: 0.5,
48 sigmoid: false,
49 _b: PhantomData,
50 }
51 }
52}
53
54impl<B: Backend> Metric for HammingScore<B> {
55 const NAME: &'static str = "Hamming Score";
56
57 type Input = HammingScoreInput<B>;
58
59 fn update(&mut self, input: &HammingScoreInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
60 let [batch_size, _n_classes] = input.outputs.dims();
61
62 let targets = input.targets.clone();
63
64 let mut outputs = input.outputs.clone();
65
66 if self.sigmoid {
67 outputs = sigmoid(outputs);
68 }
69
70 let score = outputs
71 .greater_elem(self.threshold)
72 .equal(targets.bool())
73 .float()
74 .mean()
75 .into_scalar()
76 .elem::<f64>();
77
78 self.state.update(
79 100.0 * score,
80 batch_size,
81 FormatOptions::new(Self::NAME).unit("%").precision(2),
82 )
83 }
84
85 fn clear(&mut self) {
86 self.state.reset()
87 }
88}
89
90impl<B: Backend> Numeric for HammingScore<B> {
91 fn value(&self) -> f64 {
92 self.state.value()
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Predictions shared by all tests: 4 samples x 5 labels, all in (0, 1).
    fn predictions(device: &<TestBackend as Backend>::Device) -> Tensor<TestBackend, 2> {
        Tensor::from_data(
            [
                [0.32, 0.52, 0.38, 0.68, 0.61],
                [0.43, 0.31, 0.21, 0.63, 0.53],
                [0.44, 0.25, 0.71, 0.39, 0.73],
                [0.49, 0.37, 0.68, 0.39, 0.31],
            ],
            device,
        )
    }

    /// Targets that exactly match `predictions` thresholded at 0.5.
    fn targets(device: &<TestBackend as Backend>::Device) -> Tensor<TestBackend, 2, Int> {
        Tensor::from_data(
            [
                [0, 1, 0, 1, 1],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 0, 1],
                [0, 0, 1, 0, 0],
            ],
            device,
        )
    }

    #[test]
    fn test_hamming_score() {
        let device = Default::default();
        let mut metric = HammingScore::<TestBackend>::new();

        let x = predictions(&device);
        let y = targets(&device);

        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y.clone()),
            &MetricMetadata::fake(),
        );
        assert_eq!(100.0, metric.value());

        // Invert every label (0 <-> 1): no prediction matches any more.
        let y = y.neg().add_scalar(1);
        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y),
            &MetricMetadata::fake(),
        );
        assert_eq!(0.0, metric.value());

        // Two mismatches in row 0 and one in each other row: 15 of 20 labels match.
        let y = Tensor::from_data(
            [
                [0, 1, 1, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 1, 1, 0, 0],
            ],
            &device,
        );
        let _entry = metric.update(&HammingScoreInput::new(x, y), &MetricMetadata::fake());
        assert_eq!(75.0, metric.value());
    }

    #[test]
    fn test_hamming_score_with_threshold() {
        let device = Default::default();
        let mut metric = HammingScore::<TestBackend>::new().with_threshold(0.6);

        let _entry = metric.update(
            &HammingScoreInput::new(predictions(&device), targets(&device)),
            &MetricMetadata::fake(),
        );
        // Raising the threshold to 0.6 turns the 0.52 (row 0) and 0.53 (row 1)
        // predictions negative, so 18 of 20 labels still match. 0.9 is not exact in
        // f32, hence the tolerance.
        assert!((metric.value() - 90.0).abs() < 1e-3);
    }

    #[test]
    fn test_hamming_score_with_sigmoid() {
        let device = Default::default();
        let mut metric = HammingScore::<TestBackend>::new().with_sigmoid(true);

        let _entry = metric.update(
            &HammingScoreInput::new(predictions(&device), targets(&device)),
            &MetricMetadata::fake(),
        );
        // Every prediction is positive, so sigmoid(x) > 0.5 everywhere and all labels
        // are predicted positive; only the 8 positive targets match (8 / 20 = 40%).
        assert!((metric.value() - 40.0).abs() < 1e-3);
    }
}