Skip to main content

burn_train/metric/
hamming.rs

1use core::marker::PhantomData;
2use std::sync::Arc;
3
4use super::state::{FormatOptions, NumericMetricState};
5use super::{MetricMetadata, SerializedEntry};
6use crate::metric::{
7    Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,
8};
9use burn_core::tensor::{ElementConversion, Int, Tensor, activation::sigmoid, backend::Backend};
10
11/// The hamming score, sometimes referred to as multi-label or label-based accuracy.
12#[derive(Clone)]
13pub struct HammingScore<B: Backend> {
14    name: MetricName,
15    state: NumericMetricState,
16    threshold: f32,
17    sigmoid: bool,
18    _b: PhantomData<B>,
19}
20
21/// The [hamming score](HammingScore) input type.
22#[derive(new)]
23pub struct HammingScoreInput<B: Backend> {
24    outputs: Tensor<B, 2>,
25    targets: Tensor<B, 2, Int>,
26}
27
28impl<B: Backend> HammingScore<B> {
29    /// Creates the metric.
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    fn update_name(&mut self) {
35        self.name = Arc::new(format!("Hamming Score @ Threshold({})", self.threshold));
36    }
37
38    /// Sets the threshold.
39    pub fn with_threshold(mut self, threshold: f32) -> Self {
40        self.threshold = threshold;
41        self.update_name();
42        self
43    }
44
45    /// Sets the sigmoid activation function usage.
46    pub fn with_sigmoid(mut self, sigmoid: bool) -> Self {
47        self.sigmoid = sigmoid;
48        self.update_name();
49        self
50    }
51}
52
53impl<B: Backend> Default for HammingScore<B> {
54    /// Creates a new metric instance with default values.
55    fn default() -> Self {
56        let threshold = 0.5;
57        let name = Arc::new(format!("Hamming Score @ Threshold({})", threshold));
58
59        Self {
60            name,
61            state: NumericMetricState::default(),
62            threshold,
63            sigmoid: false,
64            _b: PhantomData,
65        }
66    }
67}
68
69impl<B: Backend> Metric for HammingScore<B> {
70    type Input = HammingScoreInput<B>;
71
72    fn update(
73        &mut self,
74        input: &HammingScoreInput<B>,
75        _metadata: &MetricMetadata,
76    ) -> SerializedEntry {
77        let [batch_size, _n_classes] = input.outputs.dims();
78
79        let targets = input.targets.clone();
80
81        let mut outputs = input.outputs.clone();
82
83        if self.sigmoid {
84            outputs = sigmoid(outputs);
85        }
86
87        let score = outputs
88            .greater_elem(self.threshold)
89            .equal(targets.bool())
90            .float()
91            .mean()
92            .into_scalar()
93            .elem::<f64>();
94
95        self.state.update(
96            100.0 * score,
97            batch_size,
98            FormatOptions::new(self.name()).unit("%").precision(2),
99        )
100    }
101
102    fn clear(&mut self) {
103        self.state.reset()
104    }
105
106    fn name(&self) -> MetricName {
107        self.name.clone()
108    }
109
110    fn attributes(&self) -> MetricAttributes {
111        NumericAttributes {
112            unit: Some("%".to_string()),
113            higher_is_better: true,
114        }
115        .into()
116    }
117}
118
119impl<B: Backend> Numeric for HammingScore<B> {
120    fn value(&self) -> NumericEntry {
121        self.state.current_value()
122    }
123
124    fn running_value(&self) -> NumericEntry {
125        self.state.running_value()
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_hamming_score() {
        let device = Default::default();
        let mut metric = HammingScore::<TestBackend>::new();

        let x = Tensor::from_data(
            [
                [0.32, 0.52, 0.38, 0.68, 0.61], // with x > 0.5: [0, 1, 0, 1, 1]
                [0.43, 0.31, 0.21, 0.63, 0.53], //               [0, 0, 0, 1, 1]
                [0.44, 0.25, 0.71, 0.39, 0.73], //               [0, 0, 1, 0, 1]
                [0.49, 0.37, 0.68, 0.39, 0.31], //               [0, 0, 1, 0, 0]
            ],
            &device,
        );
        let y = Tensor::from_data(
            [
                [0, 1, 0, 1, 1],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 0, 1],
                [0, 0, 1, 0, 0],
            ],
            &device,
        );

        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y.clone()),
            &MetricMetadata::fake(),
        );
        assert_eq!(100.0, metric.value().current());

        // Invert all targets: y = (1 - y)
        let y = y.neg().add_scalar(1);
        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y), // invert targets (1 - y)
            &MetricMetadata::fake(),
        );
        assert_eq!(0.0, metric.value().current());

        // Flip 5 of the 20 target values -> 1 - (5/20) = 0.75
        let y = Tensor::from_data(
            [
                [0, 1, 1, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 1, 1, 0, 0],
            ],
            &device,
        );
        let _entry = metric.update(
            &HammingScoreInput::new(x, y), // 5 targets disagree with the predictions
            &MetricMetadata::fake(),
        );
        assert_eq!(75.0, metric.value().current());
    }

    #[test]
    fn test_hamming_score_threshold() {
        let device = Default::default();
        let x = Tensor::from_data([[0.55, 0.65, 0.75, 0.85]], &device);
        let y = Tensor::from_data([[0, 0, 1, 1]], &device);

        // Default threshold 0.5: every output is predicted positive -> 2/4 agree.
        let mut metric = HammingScore::<TestBackend>::new();
        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y.clone()),
            &MetricMetadata::fake(),
        );
        assert_eq!(50.0, metric.value().current());

        // Threshold 0.7: only 0.75 and 0.85 are positive -> 4/4 agree.
        let mut metric = HammingScore::<TestBackend>::new().with_threshold(0.7);
        let _entry = metric.update(&HammingScoreInput::new(x, y), &MetricMetadata::fake());
        assert_eq!(100.0, metric.value().current());
    }

    #[test]
    fn test_hamming_score_sigmoid() {
        let device = Default::default();
        // Logits: sigmoid(x) > 0.5 exactly when x > 0.
        let x = Tensor::from_data([[-2.0, 3.0, -1.0, 0.5]], &device);
        let y = Tensor::from_data([[0, 1, 0, 1]], &device);

        // Without sigmoid the raw 0.5 is not strictly greater than 0.5 -> 3/4 agree.
        let mut metric = HammingScore::<TestBackend>::new();
        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y.clone()),
            &MetricMetadata::fake(),
        );
        assert_eq!(75.0, metric.value().current());

        // With sigmoid, sigmoid(0.5) ~= 0.62 > 0.5 -> 4/4 agree.
        let mut metric = HammingScore::<TestBackend>::new().with_sigmoid(true);
        let _entry = metric.update(&HammingScoreInput::new(x, y), &MetricMetadata::fake());
        assert_eq!(100.0, metric.value().current());
    }

    #[test]
    fn test_parameterized_unique_name() {
        let metric_a = HammingScore::<TestBackend>::new().with_threshold(0.5);
        let metric_b = HammingScore::<TestBackend>::new().with_threshold(0.75);
        let metric_c = HammingScore::<TestBackend>::new().with_threshold(0.5);

        assert_ne!(metric_a.name(), metric_b.name());
        assert_eq!(metric_a.name(), metric_c.name());
    }

    #[test]
    fn test_default_name_matches_default_threshold() {
        let default_metric = HammingScore::<TestBackend>::new();
        let explicit_metric = HammingScore::<TestBackend>::new().with_threshold(0.5);

        assert_eq!(default_metric.name(), explicit_metric.name());
    }
}