// burn 0.3.0
//
// BURN: Burn Unstoppable Rusty Neurons
// Documentation
use crate::tensor::backend::Backend;
use crate::train::metric;
use burn_tensor::Tensor;

/// Bundles the tensors produced by a classification step so that metrics can
/// be updated from a single value.
///
/// A `new` constructor taking the fields in declaration order is generated by
/// `#[derive(new)]`.
#[derive(new)]
pub struct ClassificationOutput<B>
where
    B: Backend,
{
    /// Loss tensor; this is the value forwarded to `metric::LossMetric`.
    pub loss: Tensor<B, 1>,
    /// Model output; forwarded to `metric::AccuracyMetric` as the first tuple element.
    // NOTE(review): presumably shaped (batch, classes) — confirm against the caller.
    pub output: Tensor<B, 2>,
    /// Expected targets; forwarded to `metric::AccuracyMetric` as the second tuple element.
    pub targets: Tensor<B, 2>,
}

/// Lets `metric::LossMetric` consume a [`ClassificationOutput`] by delegating to
/// its `Metric<Tensor<B, 1>>` implementation.
impl<B> metric::Metric<ClassificationOutput<B>> for metric::LossMetric
where
    B: Backend,
{
    /// Forwards the `loss` tensor of `item` to the tensor-based `update` and
    /// returns whatever state that call produces.
    fn update(&mut self, item: &ClassificationOutput<B>) -> metric::MetricStateDyn {
        // Name the target impl explicitly so it is obvious this is delegation,
        // not a recursive call into this same impl.
        <Self as metric::Metric<Tensor<B, 1>>>::update(self, &item.loss)
    }

    /// Delegates `clear` to the tensor-based implementation.
    fn clear(&mut self) {
        <Self as metric::Metric<Tensor<B, 1>>>::clear(self);
    }
}

/// Lets `metric::AccuracyMetric` consume a [`ClassificationOutput`] by delegating
/// to its `Metric<(Tensor<B, 2>, Tensor<B, 2>)>` implementation.
impl<B> metric::Metric<ClassificationOutput<B>> for metric::AccuracyMetric
where
    B: Backend,
{
    /// Forwards `(output, targets)` from `item` to the tuple-based `update` and
    /// returns whatever state that call produces.
    fn update(&mut self, item: &ClassificationOutput<B>) -> metric::MetricStateDyn {
        // The delegated impl takes a reference to a tuple of owned tensors, so
        // the pair has to be assembled from clones of the borrowed fields.
        let output_and_targets = (item.output.clone(), item.targets.clone());
        <Self as metric::Metric<(Tensor<B, 2>, Tensor<B, 2>)>>::update(self, &output_and_targets)
    }

    /// Delegates `clear` to the tuple-based implementation.
    fn clear(&mut self) {
        <Self as metric::Metric<(Tensor<B, 2>, Tensor<B, 2>)>>::clear(self);
    }
}