// burn_train/metric/loss.rs

use std::marker::PhantomData;
use std::sync::Arc;

use burn_core::tensor::Tensor;
use burn_core::tensor::backend::Backend;

use super::MetricEntry;
use super::MetricMetadata;
use super::state::FormatOptions;
use super::state::NumericMetricState;
use crate::metric::MetricName;
use crate::metric::{Metric, Numeric};

12/// The loss metric.
13#[derive(Clone)]
14pub struct LossMetric<B: Backend> {
15    name: Arc<String>,
16    state: NumericMetricState,
17    _b: B,
18}
19
20/// The [loss metric](LossMetric) input type.
21#[derive(new)]
22pub struct LossInput<B: Backend> {
23    tensor: Tensor<B, 1>,
24}
25
26impl<B: Backend> Default for LossMetric<B> {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl<B: Backend> LossMetric<B> {
33    /// Create the metric.
34    pub fn new() -> Self {
35        Self {
36            name: Arc::new("Loss".to_string()),
37            state: NumericMetricState::default(),
38            _b: Default::default(),
39        }
40    }
41}
42
43impl<B: Backend> Metric for LossMetric<B> {
44    type Input = LossInput<B>;
45
46    fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
47        let [batch_size] = loss.tensor.dims();
48        let loss = loss
49            .tensor
50            .clone()
51            .mean()
52            .into_data()
53            .iter::<f64>()
54            .next()
55            .unwrap();
56
57        self.state.update(
58            loss,
59            batch_size,
60            FormatOptions::new(self.name()).precision(2),
61        )
62    }
63
64    fn clear(&mut self) {
65        self.state.reset()
66    }
67
68    fn name(&self) -> MetricName {
69        self.name.clone()
70    }
71}
72
73impl<B: Backend> Numeric for LossMetric<B> {
74    fn value(&self) -> super::NumericEntry {
75        self.state.value()
76    }
77}