// burn 0.3.0
//
// BURN: Burn Unstoppable Rusty Neurons
// Documentation
use super::RunningMetricResult;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use crate::train::metric::{Metric, MetricStateDyn, Numeric};

/// Accumulator for classification accuracy, tracked both per batch and cumulatively.
pub struct AccuracyMetric {
    /// Cumulative number of samples counted so far.
    count: usize,
    /// Cumulative number of samples counted as correct so far.
    total: usize,
    /// Accuracy of the most recent batch, stored as a fraction.
    current: f64,
}

impl AccuracyMetric {
    pub fn new() -> Self {
        Self {
            count: 0,
            current: 0.0,
            total: 0,
        }
    }
}

impl Default for AccuracyMetric {
    fn default() -> Self {
        Self::new()
    }
}

impl Numeric for AccuracyMetric {
    /// Returns the stored `current` fraction scaled by 100, i.e. expressed as a percentage.
    fn value(&self) -> f64 {
        let fraction = self.current;
        fraction * 100.0
    }
}

/// Accuracy computed from batches of `(outputs, targets)` rank-2 tensor pairs.
///
/// NOTE(review): both tensors are presumably laid out as `(batch, classes)` with the targets
/// one-hot encoded; confirm against the callers that build the batch.
impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B, 2>)> for AccuracyMetric {
    /// Folds one batch into the metric and returns a snapshot of the updated state.
    ///
    /// A sample counts as correct when the argmax along dimension 1 of `outputs` equals the
    /// argmax along dimension 1 of `targets`. The returned state carries the metric name, a
    /// human-readable summary (percentages with two decimals), and the raw running and
    /// current accuracies as fractions.
    ///
    /// An empty batch yields a current accuracy of `0.0` instead of `NaN`, and the running
    /// accuracy is likewise `0.0` while no sample has been counted.
    fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B, 2>)) -> MetricStateDyn {
        /// Fraction `correct / seen`, defined as `0.0` when nothing was seen so that a
        /// `0 / 0` division never leaks `NaN` into the reported values.
        fn ratio(correct: usize, seen: usize) -> f64 {
            if seen == 0 {
                0.0
            } else {
                correct as f64 / seen as f64
            }
        }

        let (outputs, targets) = batch;

        // Both index tensors are moved to the backend's default device before comparison.
        let predicted_classes = outputs.argmax(1).to_device(B::Device::default());
        let target_classes = targets.argmax(1).to_device(B::Device::default());

        // Number of samples in this batch: the size of the first dimension.
        let count_current = target_classes.shape().dims[0];

        // Number of positions where the predicted and target indices agree.
        let total_current = predicted_classes
            .equal(&target_classes)
            .to_int()
            .sum()
            .to_data()
            .value[0] as usize;

        self.count += count_current;
        self.total += total_current;
        self.current = ratio(total_current, count_current);

        let running = ratio(self.total, self.count);

        Box::new(RunningMetricResult {
            name: String::from("Accuracy"),
            formatted: format!(
                "running {:.2} % current {:.2} %",
                100.0 * running,
                100.0 * self.current
            ),
            raw_running: running.to_string(),
            raw_current: self.current.to_string(),
        })
    }

    /// Resets the metric to its freshly constructed state (zero counters, `0.0` accuracy).
    fn clear(&mut self) {
        *self = Self::new();
    }
}