use core::marker::PhantomData;
use super::MetricMetadata;
use super::state::{FormatOptions, NumericMetricState};
use crate::metric::{Metric, MetricAttributes, MetricName, Numeric, SerializedEntry};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{ElementConversion, Int, Tensor};
/// Classification accuracy metric.
///
/// Compares the argmax of model outputs against integer class targets and
/// aggregates the percentage of correct predictions through
/// [`NumericMetricState`]. A padding class can optionally be excluded from
/// the score (see `with_pad_token`).
#[derive(Clone)]
pub struct AccuracyMetric<B: Backend> {
    // Display name reported by `name()` ("Accuracy").
    name: MetricName,
    // Accumulates per-batch values into current/running entries.
    state: NumericMetricState,
    // When set, targets equal to this class index are ignored by `update`.
    pad_token: Option<usize>,
    // Ties the metric to a backend type without storing any backend data.
    _b: PhantomData<B>,
}
/// Input consumed by [`AccuracyMetric::update`].
///
/// `outputs` carries per-class scores with shape `[batch_size, num_classes]`
/// (rank 2, reduced with `argmax(1)`); `targets` carries the matching class
/// indices with shape `[batch_size]`.
#[derive(new)]
pub struct AccuracyInput<B: Backend> {
    outputs: Tensor<B, 2>,
    targets: Tensor<B, 1, Int>,
}
impl<B: Backend> Default for AccuracyMetric<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> AccuracyMetric<B> {
pub fn new() -> Self {
Self {
name: MetricName::new("Accuracy".to_string()),
state: Default::default(),
pad_token: Default::default(),
_b: PhantomData,
}
}
pub fn with_pad_token(mut self, index: usize) -> Self {
self.pad_token = Some(index);
self
}
}
impl<B: Backend> Metric for AccuracyMetric<B> {
    type Input = AccuracyInput<B>;

    /// Scores one batch and folds the result into the running state.
    ///
    /// Accuracy is `correct / batch_size`, or — when a pad token is set —
    /// `correct (non-pad) / (batch_size - num_pad)`. The value is reported
    /// as a percentage with two decimal places.
    fn update(&mut self, input: &AccuracyInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {
        let targets = input.targets.clone();
        let outputs = input.outputs.clone();
        let [batch_size, _n_classes] = outputs.dims();

        // Predicted class = index of the highest score in each row.
        let predictions = outputs.argmax(1).reshape([batch_size]);

        let accuracy = if let Some(pad_token) = self.pad_token {
            // Zero out matches at padded positions and shrink the
            // denominator by the number of pads so they cannot affect
            // the score in either direction.
            let pad_mask = targets.clone().equal_elem(pad_token as i64);
            let correct = predictions
                .equal(targets)
                .float()
                .mask_fill(pad_mask.clone(), 0);
            let pad_count = pad_mask.float().sum();
            // NOTE(review): a batch consisting entirely of pad tokens
            // divides by zero here and yields NaN — confirm callers never
            // produce one.
            let ratio = correct.sum() / (pad_count.neg() + batch_size as f32);
            ratio.into_scalar().elem::<f64>()
        } else {
            let correct = predictions.equal(targets).int().sum();
            correct.into_scalar().elem::<f64>() / batch_size as f64
        };

        // The sample count passed to the state still includes padded
        // samples, preserving the existing aggregation semantics.
        self.state.update(
            100.0 * accuracy,
            batch_size,
            FormatOptions::new(self.name()).unit("%").precision(2),
        )
    }

    /// Resets the accumulated state (e.g. between epochs).
    fn clear(&mut self) {
        self.state.reset()
    }

    fn name(&self) -> MetricName {
        self.name.clone()
    }

    /// Declares this metric as a percentage where higher values are better.
    fn attributes(&self) -> MetricAttributes {
        super::NumericAttributes {
            unit: Some("%".to_string()),
            higher_is_better: true,
        }
        .into()
    }
}
impl<B: Backend> Numeric for AccuracyMetric<B> {
    /// Value computed from the most recent `update` call.
    fn value(&self) -> super::NumericEntry {
        self.state.current_value()
    }

    /// Aggregated value over every batch since the last `clear`.
    fn running_value(&self) -> super::NumericEntry {
        self.state.running_value()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_accuracy_without_padding() {
        let device = Default::default();
        let mut metric = AccuracyMetric::<TestBackend>::new();

        // Argmax per row: 2, 1, 0, 1 versus targets 2, 2, 1, 1 —
        // rows 0 and 3 are hits, so accuracy is 2/4 = 50%.
        let input = AccuracyInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.2, 0.8],
                    [1.0, 2.0, 0.5],
                    [0.4, 0.1, 0.2],
                    [0.6, 0.7, 0.2],
                ],
                &device,
            ),
            Tensor::from_data([2, 2, 1, 1], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(50.0, metric.value().current());
    }

    #[test]
    fn test_accuracy_with_padding() {
        let device = Default::default();
        let mut metric = AccuracyMetric::<TestBackend>::new().with_pad_token(3);

        // The last three targets are the pad class (3) and must be ignored,
        // leaving 2 correct predictions out of 4 valid samples = 50%.
        let input = AccuracyInput::new(
            Tensor::from_data(
                [
                    [0.0, 0.2, 0.8, 0.0],
                    [1.0, 2.0, 0.5, 0.0],
                    [0.4, 0.1, 0.2, 0.0],
                    [0.6, 0.7, 0.2, 0.0],
                    [0.0, 0.1, 0.2, 5.0],
                    [0.0, 0.1, 0.2, 0.0],
                    [0.6, 0.0, 0.2, 0.0],
                ],
                &device,
            ),
            Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device),
        );

        let _entry = metric.update(&input, &MetricMetadata::fake());
        assert_eq!(50.0, metric.value().current());
    }
}