use nove_tensor::{DType, Tensor};
use crate::{EvaluationMetric, Metric, MetricError, MetricValue};
/// Classification accuracy accumulated over every batch passed to `evaluate`.
///
/// Keeps running totals so the reported value covers all batches seen since
/// construction (or the last `clear`), not just the most recent one.
#[derive(Clone, Debug)]
pub struct AccuracyMetric {
    /// Display name returned by `Metric::name`.
    name: String,
    /// Most recently computed (or externally set) metric value.
    value: MetricValue,
    /// Running count of samples scored so far.
    total_samples: usize,
    /// Running count of samples whose prediction matched the target.
    correct_samples: usize,
}
impl AccuracyMetric {
    /// Creates an accuracy metric that has not yet seen any samples.
    ///
    /// The metric is named `"Accuracy"`, its value starts at
    /// `MetricValue::Scalar(0.0)`, and both running counters start at zero.
    pub fn new() -> Self {
        Self {
            name: "Accuracy".to_string(),
            value: MetricValue::Scalar(0.0),
            total_samples: 0,
            correct_samples: 0,
        }
    }
}

impl Default for AccuracyMetric {
    /// Equivalent to [`AccuracyMetric::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Metric for AccuracyMetric {
    /// Returns the name assigned when the metric was constructed.
    fn name(&self) -> Result<String, MetricError> {
        let name = self.name.to_owned();
        Ok(name)
    }

    /// Returns a copy of the current metric value.
    fn value(&self) -> Result<MetricValue, MetricError> {
        let current = MetricValue::clone(&self.value);
        Ok(current)
    }

    /// Overwrites the reported value.
    ///
    /// Only `value` is replaced; the running sample counters are not touched,
    /// so a later `evaluate` call recomputes `value` from those counters.
    fn update(&mut self, value: MetricValue) -> Result<(), MetricError> {
        self.value = value;
        Ok(())
    }

    /// Resets the value to `Scalar(0.0)` and zeroes both running counters.
    fn clear(&mut self) -> Result<(), MetricError> {
        let Self {
            value,
            total_samples,
            correct_samples,
            ..
        } = self;
        *value = MetricValue::Scalar(0.0);
        *total_samples = 0;
        *correct_samples = 0;
        Ok(())
    }
}
impl EvaluationMetric for AccuracyMetric {
    /// Scores one batch and folds it into the running accuracy.
    ///
    /// Predictions are taken as `output.argmax((1, false))`, compared with
    /// `target` via `Tensor::eq`, converted to `F64` and summed to get the
    /// number of correct samples in this batch. The batch size is read from
    /// dimension 0 of `output`. Afterwards `value` is
    /// `correct_samples / total_samples` over every batch since construction
    /// (or the last `clear`), or `0.0` if no samples have been counted.
    ///
    /// # Errors
    ///
    /// Propagates any error from the tensor operations (argmax, comparison,
    /// dtype conversion, shape query, sum, scalar extraction).
    fn evaluate(&mut self, output: &Tensor, target: &Tensor) -> Result<(), MetricError> {
        // NOTE(review): assumes `output` is (batch, classes, ...) and `target`
        // holds class indices shaped like the argmax result — confirm with callers.
        let correct = output
            .argmax((1, false))?
            .eq(target)?
            .to_dtype(&DType::F64)?;
        // Read after argmax so that (presumably) a rank < 2 `output` has already
        // produced an error instead of reaching this index.
        let batch_size = output.shape()?.dims()[0];
        // The summed comparison mask should be a whole number; round before
        // casting so the conversion never truncates toward zero.
        let correct_count = correct.sum(None)?.to_scalar::<f64>()?.round() as usize;

        self.total_samples += batch_size;
        self.correct_samples += correct_count;

        let new_acc = if self.total_samples == 0 {
            0.0
        } else {
            self.correct_samples as f64 / self.total_samples as f64
        };
        self.value = MetricValue::Scalar(new_acc);
        Ok(())
    }
}