burn_train/metric/hamming.rs

1use core::marker::PhantomData;
2
3use super::state::{FormatOptions, NumericMetricState};
4use super::{MetricEntry, MetricMetadata};
5use crate::metric::{Metric, Numeric};
6use burn_core::tensor::{activation::sigmoid, backend::Backend, ElementConversion, Int, Tensor};
7
8/// The hamming score, sometimes referred to as multi-label or label-based accuracy.
9pub struct HammingScore<B: Backend> {
10    state: NumericMetricState,
11    threshold: f32,
12    sigmoid: bool,
13    _b: PhantomData<B>,
14}
15
16/// The [hamming score](HammingScore) input type.
17#[derive(new)]
18pub struct HammingScoreInput<B: Backend> {
19    outputs: Tensor<B, 2>,
20    targets: Tensor<B, 2, Int>,
21}
22
23impl<B: Backend> HammingScore<B> {
24    /// Creates the metric.
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Sets the threshold.
30    pub fn with_threshold(mut self, threshold: f32) -> Self {
31        self.threshold = threshold;
32        self
33    }
34
35    /// Sets the sigmoid activation function usage.
36    pub fn with_sigmoid(mut self, sigmoid: bool) -> Self {
37        self.sigmoid = sigmoid;
38        self
39    }
40}
41
42impl<B: Backend> Default for HammingScore<B> {
43    /// Creates a new metric instance with default values.
44    fn default() -> Self {
45        Self {
46            state: NumericMetricState::default(),
47            threshold: 0.5,
48            sigmoid: false,
49            _b: PhantomData,
50        }
51    }
52}
53
54impl<B: Backend> Metric for HammingScore<B> {
55    const NAME: &'static str = "Hamming Score";
56
57    type Input = HammingScoreInput<B>;
58
59    fn update(&mut self, input: &HammingScoreInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
60        let [batch_size, _n_classes] = input.outputs.dims();
61
62        let targets = input.targets.clone();
63
64        let mut outputs = input.outputs.clone();
65
66        if self.sigmoid {
67            outputs = sigmoid(outputs);
68        }
69
70        let score = outputs
71            .greater_elem(self.threshold)
72            .equal(targets.bool())
73            .float()
74            .mean()
75            .into_scalar()
76            .elem::<f64>();
77
78        self.state.update(
79            100.0 * score,
80            batch_size,
81            FormatOptions::new(Self::NAME).unit("%").precision(2),
82        )
83    }
84
85    fn clear(&mut self) {
86        self.state.reset()
87    }
88}
89
90impl<B: Backend> Numeric for HammingScore<B> {
91    fn value(&self) -> f64 {
92        self.state.value()
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_hamming_score() {
        let device = Default::default();
        let mut metric = HammingScore::<TestBackend>::new();

        let x = Tensor::from_data(
            [
                [0.32, 0.52, 0.38, 0.68, 0.61], // with x > 0.5: [0, 1, 0, 1, 1]
                [0.43, 0.31, 0.21, 0.63, 0.53], //               [0, 0, 0, 1, 1]
                [0.44, 0.25, 0.71, 0.39, 0.73], //               [0, 0, 1, 0, 1]
                [0.49, 0.37, 0.68, 0.39, 0.31], //               [0, 0, 1, 0, 0]
            ],
            &device,
        );
        let y = Tensor::from_data(
            [
                [0, 1, 0, 1, 1],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 0, 1],
                [0, 0, 1, 0, 0],
            ],
            &device,
        );

        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y.clone()),
            &MetricMetadata::fake(),
        );
        assert_eq!(100.0, metric.value());

        // Invert all targets: y = (1 - y)
        let y = y.neg().add_scalar(1);
        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y), // invert targets (1 - y)
            &MetricMetadata::fake(),
        );
        assert_eq!(0.0, metric.value());

        // Invert 5 target values -> 1 - (5/20) = 0.75
        let y = Tensor::from_data(
            [
                [0, 1, 1, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 1, 1, 0, 0],
            ],
            &device,
        );
        let _entry = metric.update(
            &HammingScoreInput::new(x, y), // 5 of the 20 targets disagree with the predictions
            &MetricMetadata::fake(),
        );
        assert_eq!(75.0, metric.value());
    }

    #[test]
    fn test_hamming_score_with_threshold() {
        let device = Default::default();

        let x = Tensor::from_data(
            [
                [0.10, 0.40, 0.60, 0.90], // x > 0.5: [0, 0, 1, 1]   x > 0.8: [0, 0, 0, 1]
                [0.20, 0.45, 0.70, 0.95], //          [0, 0, 1, 1]            [0, 0, 0, 1]
            ],
            &device,
        );
        let y = Tensor::from_data([[0, 0, 1, 1], [0, 0, 1, 1]], &device);

        // Default threshold (0.5): every prediction matches its target.
        let mut metric = HammingScore::<TestBackend>::new();
        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y.clone()),
            &MetricMetadata::fake(),
        );
        assert_eq!(100.0, metric.value());

        // Threshold 0.8: the third label becomes negative in both rows -> 6/8 = 0.75.
        let mut metric = HammingScore::<TestBackend>::new().with_threshold(0.8);
        let _entry = metric.update(&HammingScoreInput::new(x, y), &MetricMetadata::fake());
        assert_eq!(75.0, metric.value());
    }

    #[test]
    fn test_hamming_score_with_sigmoid() {
        let device = Default::default();

        // Raw logits: sigmoid(x) > 0.5 exactly when x > 0,
        // so the predictions are [[0, 1], [1, 0]].
        let x = Tensor::from_data([[-1.0, 2.0], [3.0, -0.5]], &device);
        let y = Tensor::from_data([[0, 1], [1, 1]], &device);

        let mut metric = HammingScore::<TestBackend>::new().with_sigmoid(true);
        let _entry = metric.update(&HammingScoreInput::new(x, y), &MetricMetadata::fake());
        // 3 of the 4 predictions agree with the targets.
        assert_eq!(75.0, metric.value());
    }
}
155}