// burn_train/metric/hamming.rs

1use core::marker::PhantomData;
2use std::sync::Arc;
3
4use super::state::{FormatOptions, NumericMetricState};
5use super::{MetricEntry, MetricMetadata};
6use crate::metric::{Metric, MetricName, Numeric};
7use burn_core::tensor::{ElementConversion, Int, Tensor, activation::sigmoid, backend::Backend};
8
9/// The hamming score, sometimes referred to as multi-label or label-based accuracy.
10#[derive(Clone)]
11pub struct HammingScore<B: Backend> {
12    name: MetricName,
13    state: NumericMetricState,
14    threshold: f32,
15    sigmoid: bool,
16    _b: PhantomData<B>,
17}
18
19/// The [hamming score](HammingScore) input type.
20#[derive(new)]
21pub struct HammingScoreInput<B: Backend> {
22    outputs: Tensor<B, 2>,
23    targets: Tensor<B, 2, Int>,
24}
25
26impl<B: Backend> HammingScore<B> {
27    /// Creates the metric.
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    fn update_name(&mut self) {
33        self.name = Arc::new(format!("Hamming Score @ Threshold({})", self.threshold));
34    }
35
36    /// Sets the threshold.
37    pub fn with_threshold(mut self, threshold: f32) -> Self {
38        self.threshold = threshold;
39        self.update_name();
40        self
41    }
42
43    /// Sets the sigmoid activation function usage.
44    pub fn with_sigmoid(mut self, sigmoid: bool) -> Self {
45        self.sigmoid = sigmoid;
46        self.update_name();
47        self
48    }
49}
50
51impl<B: Backend> Default for HammingScore<B> {
52    /// Creates a new metric instance with default values.
53    fn default() -> Self {
54        let threshold = 0.5;
55        let name = Arc::new(format!("Hamming Score @ Threshold({})", threshold));
56
57        Self {
58            name,
59            state: NumericMetricState::default(),
60            threshold,
61            sigmoid: false,
62            _b: PhantomData,
63        }
64    }
65}
66
67impl<B: Backend> Metric for HammingScore<B> {
68    type Input = HammingScoreInput<B>;
69
70    fn update(&mut self, input: &HammingScoreInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
71        let [batch_size, _n_classes] = input.outputs.dims();
72
73        let targets = input.targets.clone();
74
75        let mut outputs = input.outputs.clone();
76
77        if self.sigmoid {
78            outputs = sigmoid(outputs);
79        }
80
81        let score = outputs
82            .greater_elem(self.threshold)
83            .equal(targets.bool())
84            .float()
85            .mean()
86            .into_scalar()
87            .elem::<f64>();
88
89        self.state.update(
90            100.0 * score,
91            batch_size,
92            FormatOptions::new(self.name()).unit("%").precision(2),
93        )
94    }
95
96    fn clear(&mut self) {
97        self.state.reset()
98    }
99
100    fn name(&self) -> MetricName {
101        self.name.clone()
102    }
103}
104
105impl<B: Backend> Numeric for HammingScore<B> {
106    fn value(&self) -> super::NumericEntry {
107        self.state.value()
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_hamming_score() {
        let device = Default::default();
        let mut metric = HammingScore::<TestBackend>::new();

        let x = Tensor::from_data(
            [
                [0.32, 0.52, 0.38, 0.68, 0.61], // with x > 0.5: [0, 1, 0, 1, 1]
                [0.43, 0.31, 0.21, 0.63, 0.53], //               [0, 0, 0, 1, 1]
                [0.44, 0.25, 0.71, 0.39, 0.73], //               [0, 0, 1, 0, 1]
                [0.49, 0.37, 0.68, 0.39, 0.31], //               [0, 0, 1, 0, 0]
            ],
            &device,
        );
        let y = Tensor::from_data(
            [
                [0, 1, 0, 1, 1],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 0, 1],
                [0, 0, 1, 0, 0],
            ],
            &device,
        );

        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y.clone()),
            &MetricMetadata::fake(),
        );
        assert_eq!(100.0, metric.value().current());

        // Invert all targets: y = (1 - y)
        let y = y.neg().add_scalar(1);
        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y), // invert targets (1 - y)
            &MetricMetadata::fake(),
        );
        assert_eq!(0.0, metric.value().current());

        // Invert 5 target values -> 1 - (5/20) = 0.75
        let y = Tensor::from_data(
            [
                [0, 1, 1, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 1, 1, 0, 0],
            ],
            &device,
        );
        let _entry = metric.update(
            &HammingScoreInput::new(x, y), // 5 of the 20 targets disagree with the predictions
            &MetricMetadata::fake(),
        );
        assert_eq!(75.0, metric.value().current());
    }

    #[test]
    fn test_hamming_score_with_sigmoid() {
        let device = Default::default();
        let mut metric = HammingScore::<TestBackend>::new().with_sigmoid(true);

        // Logits: sigmoid -> [[0.55, 0.43], [0.95, 0.05]] -> [[1, 0], [1, 0]].
        // Without the sigmoid, 0.2 would fall below the 0.5 threshold (75%).
        let x = Tensor::from_data([[0.2, -0.3], [3.0, -3.0]], &device);
        let y = Tensor::from_data([[1, 0], [1, 0]], &device);

        let _entry = metric.update(&HammingScoreInput::new(x, y), &MetricMetadata::fake());
        assert_eq!(100.0, metric.value().current());
    }

    #[test]
    fn test_hamming_score_with_threshold() {
        let device = Default::default();
        let mut metric = HammingScore::<TestBackend>::new().with_threshold(0.7);

        // With x > 0.7: [[0, 1], [1, 0]].
        // With the default 0.5 threshold, 0.6 would be predicted positive (75%).
        let x = Tensor::from_data([[0.6, 0.8], [0.75, 0.1]], &device);
        let y = Tensor::from_data([[0, 1], [1, 0]], &device);

        let _entry = metric.update(&HammingScoreInput::new(x, y), &MetricMetadata::fake());
        assert_eq!(100.0, metric.value().current());
    }

    #[test]
    fn test_parameterized_unique_name() {
        let metric_a = HammingScore::<TestBackend>::new().with_threshold(0.5);
        let metric_b = HammingScore::<TestBackend>::new().with_threshold(0.75);
        let metric_c = HammingScore::<TestBackend>::new().with_threshold(0.5);

        assert_ne!(metric_a.name(), metric_b.name());
        assert_eq!(metric_a.name(), metric_c.name());
    }
}