burn_train/metric/
loss.rs1use std::sync::Arc;
2
3use super::MetricEntry;
4use super::MetricMetadata;
5use super::state::FormatOptions;
6use super::state::NumericMetricState;
7use crate::metric::MetricName;
8use crate::metric::{Metric, Numeric};
9use burn_core::tensor::Tensor;
10use burn_core::tensor::backend::Backend;
11
12#[derive(Clone)]
14pub struct LossMetric<B: Backend> {
15 name: Arc<String>,
16 state: NumericMetricState,
17 _b: B,
18}
19
20#[derive(new)]
22pub struct LossInput<B: Backend> {
23 tensor: Tensor<B, 1>,
24}
25
26impl<B: Backend> Default for LossMetric<B> {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl<B: Backend> LossMetric<B> {
33 pub fn new() -> Self {
35 Self {
36 name: Arc::new("Loss".to_string()),
37 state: NumericMetricState::default(),
38 _b: Default::default(),
39 }
40 }
41}
42
43impl<B: Backend> Metric for LossMetric<B> {
44 type Input = LossInput<B>;
45
46 fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
47 let [batch_size] = loss.tensor.dims();
48 let loss = loss
49 .tensor
50 .clone()
51 .mean()
52 .into_data()
53 .iter::<f64>()
54 .next()
55 .unwrap();
56
57 self.state.update(
58 loss,
59 batch_size,
60 FormatOptions::new(self.name()).precision(2),
61 )
62 }
63
64 fn clear(&mut self) {
65 self.state.reset()
66 }
67
68 fn name(&self) -> MetricName {
69 self.name.clone()
70 }
71}
72
73impl<B: Backend> Numeric for LossMetric<B> {
74 fn value(&self) -> super::NumericEntry {
75 self.state.value()
76 }
77}