nove_metric 0.1.2

An easy-to-use, lightweight deep learning library built on top of Candle tensors.
Documentation
use nove_tensor::{DType, Tensor};

use crate::{EvaluationMetric, Metric, MetricError, MetricValue};

/// Accuracy metric.
///
/// Tracks the fraction of samples whose predicted class (the argmax over the class
/// dimension of the model output) equals the target class index.
///
/// The value is cumulative. Every call to `evaluate` adds its batch to running totals.
/// The reported accuracy is `correct_samples / total_samples` over all batches seen since
/// construction or since the last call to `clear`.
///
/// # Notes
/// * It should be used for binary/multi-class classification tasks.
/// * It requires the output of the model to have the shape `(batch_size, num_classes)` and
///   the target with dtype `u32` to have the shape `(batch_size,)`, which contains the real class indices (not one-hot encoding).
/// * For binary classification the output must still have one column per class.
///   A single-column `(batch_size, 1)` output would always predict class `0`.
///
/// # Fields
/// * `name` - The name of the metric (`"Accuracy"` when built with `new`).
/// * `value` - The current value of the metric. `evaluate` stores the running accuracy here
///   as a `MetricValue::Scalar`.
/// * `total_samples` - Number of samples (rows of `output`) accumulated since the last reset.
/// * `correct_samples` - Number of those samples whose predicted class matched the target.
///
/// # Examples
/// ```
/// use nove::tensor::{Device, Tensor};
/// use nove::metric::{MetricValue, AccuracyMetric, EvaluationMetric, Metric};
///
/// let device = Device::cpu();
/// let output = Tensor::from_data(vec![
///     vec![0.1f64, 0.2f64, 0.7f64],
///     vec![0.3f64, 0.4f64, 0.3f64],
///     vec![0.6f64, 0.1f64, 0.3f64],
/// ], &device, false).unwrap();
/// let target = Tensor::from_data(vec![2u32, 1u32, 1u32], &device, false).unwrap();
///
/// let mut metric = AccuracyMetric::new();
/// metric.evaluate(&output, &target).unwrap();
/// let accuracy = metric.value().unwrap();
/// assert_eq!(accuracy, MetricValue::Scalar(0.6666666666666666));
/// ```
#[derive(Debug, Clone)]
pub struct AccuracyMetric {
    name: String,
    value: MetricValue,
    total_samples: usize,
    correct_samples: usize,
}

impl AccuracyMetric {
    /// Creates an empty accuracy metric named `"Accuracy"`.
    ///
    /// The value starts at `MetricValue::Scalar(0.0)` and both sample counters start at zero.
    /// As a result, the first call to `evaluate` reports the accuracy of that batch alone.
    pub fn new() -> Self {
        Self {
            name: "Accuracy".to_string(),
            value: MetricValue::Scalar(0.0),
            total_samples: 0,
            correct_samples: 0,
        }
    }
}

impl Default for AccuracyMetric {
    /// Equivalent to [`AccuracyMetric::new`].
    ///
    /// Provided so the metric works with `Default`-based construction
    /// (`..Default::default()`, `Option::unwrap_or_default`, generic code bounded on `Default`).
    fn default() -> Self {
        Self::new()
    }
}

impl Metric for AccuracyMetric {
    /// Returns an owned copy of the metric's name.
    fn name(&self) -> Result<String, MetricError> {
        let label = self.name.to_string();
        Ok(label)
    }

    /// Returns a copy of the currently stored metric value.
    fn value(&self) -> Result<MetricValue, MetricError> {
        let current = MetricValue::clone(&self.value);
        Ok(current)
    }

    /// Replaces the stored value with `value`.
    ///
    /// Only the reported value changes. The running sample counters used by
    /// `evaluate` are left untouched.
    fn update(&mut self, value: MetricValue) -> Result<(), MetricError> {
        self.value = value;
        Ok(())
    }

    /// Resets the metric to its freshly-constructed state.
    ///
    /// The value goes back to `Scalar(0.0)` and both sample counters go to zero.
    /// The name is kept.
    fn clear(&mut self) -> Result<(), MetricError> {
        let Self {
            value,
            total_samples,
            correct_samples,
            ..
        } = self;
        *value = MetricValue::Scalar(0.0);
        *total_samples = 0;
        *correct_samples = 0;
        Ok(())
    }
}

impl EvaluationMetric for AccuracyMetric {
    /// Folds one batch into the running accuracy.
    ///
    /// The predicted class of each row of `output` is the argmax along dim 1, which is the
    /// class axis per the struct-level docs. Each prediction is compared with `target`. The
    /// number of matches is added to the correct-sample counter, and the batch size
    /// (`output`'s first dimension) is added to the total-sample counter. The stored value
    /// then becomes `correct / total` over everything accumulated so far, or `0.0` while
    /// the total is still zero.
    ///
    /// # Errors
    /// Propagates any error from the tensor operations (argmax, comparison, dtype
    /// conversion, shape lookup, sum, scalar extraction).
    fn evaluate(&mut self, output: &Tensor, target: &Tensor) -> Result<(), MetricError> {
        // NOTE(review): `(1, false)` presumably means (dim = 1, keepdim = false), so the
        // predictions line up with a `(batch_size,)` target. Confirm against nove_tensor.
        let predicted = output.argmax((1, false))?;
        let matches = predicted.eq(&target)?;
        // Cast the boolean-like comparison result to f64 so it can be summed.
        let hits = matches.to_dtype(&DType::F64)?;

        let rows = output.shape()?.dims()[0];
        let hit_total = hits.sum(None)?;
        let hit_count = hit_total.to_scalar::<f64>()? as usize;

        self.total_samples += rows;
        self.correct_samples += hit_count;

        let accuracy = match self.total_samples {
            0 => 0.0,
            seen => (self.correct_samples as f64) / (seen as f64),
        };
        self.value = MetricValue::Scalar(accuracy);

        Ok(())
    }
}