use core::marker::PhantomData;
use super::state::{FormatOptions, NumericMetricState};
use super::{MetricEntry, MetricMetadata};
use crate::metric::{Metric, Numeric};
use burn_core::tensor::{activation::sigmoid, backend::Backend, ElementConversion, Int, Tensor};
/// Hamming score metric for multi-label classification.
///
/// For each batch, the reported value is the percentage of individual label
/// positions at which the thresholded prediction agrees with the target.
pub struct HammingScore<B: Backend> {
    /// Decision threshold: an output strictly greater than this value is
    /// treated as a positive prediction.
    threshold: f32,
    /// When `true`, `sigmoid` is applied to the outputs before thresholding.
    sigmoid: bool,
    /// Metric state updated on every batch and reset by `clear`.
    state: NumericMetricState,
    /// Ties the metric to its backend type without storing a value of it.
    _b: PhantomData<B>,
}
/// Input for [`HammingScore`]: model outputs paired with multi-label targets.
///
/// The derived `new` takes `(outputs, targets)` in field order.
#[derive(new)]
pub struct HammingScoreInput<B: Backend> {
    /// Model outputs; the metric reads their dims as `[batch_size, n_classes]`.
    /// If the metric's sigmoid option is enabled, `sigmoid` is applied before
    /// thresholding; otherwise the values are thresholded as-is.
    outputs: Tensor<B, 2>,
    /// Target labels, compared element-wise against the thresholded outputs
    /// after conversion with `.bool()`. Assumes the same shape as `outputs`
    /// and presumably 0 = negative / nonzero = positive — TODO confirm.
    targets: Tensor<B, 2, Int>,
}
impl<B: Backend> HammingScore<B> {
    /// Creates a Hamming score metric with the default configuration
    /// (threshold `0.5`, sigmoid disabled).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this metric with its decision threshold set to `threshold`.
    ///
    /// Outputs strictly greater than the threshold count as positive predictions.
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self { threshold, ..self }
    }

    /// Returns this metric with sigmoid activation of the outputs enabled
    /// (`true`) or disabled (`false`) before thresholding.
    pub fn with_sigmoid(self, sigmoid: bool) -> Self {
        Self { sigmoid, ..self }
    }
}
impl<B: Backend> Default for HammingScore<B> {
    /// Builds a metric that thresholds outputs at `0.5` and does not apply
    /// a sigmoid first, starting from a fresh metric state.
    fn default() -> Self {
        let threshold = 0.5;
        let sigmoid = false;

        Self {
            threshold,
            sigmoid,
            state: NumericMetricState::default(),
            _b: PhantomData,
        }
    }
}
impl<B: Backend> Metric for HammingScore<B> {
    const NAME: &'static str = "Hamming Score";

    type Input = HammingScoreInput<B>;

    /// Scores one batch and records it in the metric state.
    ///
    /// Optionally applies `sigmoid` to the outputs. It then marks every
    /// element strictly above the threshold as a positive prediction and
    /// compares it element-wise with the boolean-converted targets. The mean
    /// agreement over all label positions is reported as a percentage,
    /// weighted by the batch size.
    fn update(&mut self, input: &HammingScoreInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
        let [batch_size, _] = input.outputs.dims();

        let raw = input.outputs.clone();
        let activated = if self.sigmoid { sigmoid(raw) } else { raw };

        let predicted = activated.greater_elem(self.threshold);
        let expected = input.targets.clone().bool();

        // Fraction of (sample, label) positions where prediction == target.
        let agreement = predicted.equal(expected).float().mean();
        let fraction = agreement.into_scalar().elem::<f64>();

        self.state.update(
            100.0 * fraction,
            batch_size,
            FormatOptions::new(Self::NAME).unit("%").precision(2),
        )
    }

    /// Resets the accumulated metric state.
    fn clear(&mut self) {
        self.state.reset()
    }
}
impl<B: Backend> Numeric for HammingScore<B> {
    /// Returns the value held by the internal metric state. The tests in this
    /// module observe it to be the percentage from the most recent `update`
    /// call, not a running average.
    fn value(&self) -> f64 {
        self.state.value()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Four samples of five label scores, shared by the tests below.
    fn sample_outputs() -> Tensor<TestBackend, 2> {
        let device = Default::default();
        Tensor::from_data(
            [
                [0.32, 0.52, 0.38, 0.68, 0.61],
                [0.43, 0.31, 0.21, 0.63, 0.53],
                [0.44, 0.25, 0.71, 0.39, 0.73],
                [0.49, 0.37, 0.68, 0.39, 0.31],
            ],
            &device,
        )
    }

    /// Targets that agree with `sample_outputs` at every position when
    /// thresholded at the default 0.5. The batch holds 8 positive labels.
    fn sample_targets() -> Tensor<TestBackend, 2, Int> {
        let device = Default::default();
        Tensor::from_data(
            [
                [0, 1, 0, 1, 1],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 0, 1],
                [0, 0, 1, 0, 0],
            ],
            &device,
        )
    }

    /// Compares scores with a tolerance. Fractions such as 17/20 or 8/20
    /// are not exactly representable after the backend's mean reduction.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-3,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_hamming_score() {
        let device = Default::default();
        let mut metric = HammingScore::<TestBackend>::new();
        let x = sample_outputs();
        let y = sample_targets();

        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y.clone()),
            &MetricMetadata::fake(),
        );
        assert_eq!(100.0, metric.value());

        // Flipping every target label makes every prediction wrong.
        let y = y.neg().add_scalar(1);
        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y),
            &MetricMetadata::fake(),
        );
        assert_eq!(0.0, metric.value());

        // 15 of the 20 label positions agree with the 0.5-thresholded outputs.
        let y = Tensor::from_data(
            [
                [0, 1, 1, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 0, 0, 0, 1],
                [0, 1, 1, 0, 0],
            ],
            &device,
        );
        let _entry = metric.update(&HammingScoreInput::new(x, y), &MetricMetadata::fake());
        assert_eq!(75.0, metric.value());
    }

    #[test]
    fn test_hamming_score_with_threshold() {
        // At 0.4, the first-column scores 0.43, 0.44 and 0.49 become positive
        // while their targets are 0, so each of rows 1-3 loses one match: 17/20.
        let mut metric = HammingScore::<TestBackend>::new().with_threshold(0.4);
        let _entry = metric.update(
            &HammingScoreInput::new(sample_outputs(), sample_targets()),
            &MetricMetadata::fake(),
        );
        assert_close(metric.value(), 85.0);
    }

    #[test]
    fn test_hamming_score_with_sigmoid() {
        // Every sample score is > 0, so sigmoid maps it above 0.5 and every
        // label is predicted positive. Only the 8 positive targets match: 8/20.
        let mut metric = HammingScore::<TestBackend>::new().with_sigmoid(true);
        let _entry = metric.update(
            &HammingScoreInput::new(sample_outputs(), sample_targets()),
            &MetricMetadata::fake(),
        );
        assert_close(metric.value(), 40.0);
    }
}