// burn_train/learner/regression.rs

1use crate::metric::processor::ItemLazy;
2use crate::metric::{Adaptor, LossInput};
3use burn_core::tensor::backend::Backend;
4use burn_core::tensor::{Tensor, Transaction};
5use burn_ndarray::NdArray;
6
7/// Simple regression output adapted for multiple metrics.
8#[derive(new)]
9pub struct RegressionOutput<B: Backend> {
10    /// The loss.
11    pub loss: Tensor<B, 1>,
12
13    /// The output.
14    pub output: Tensor<B, 2>,
15
16    /// The targets.
17    pub targets: Tensor<B, 2>,
18}
19
20impl<B: Backend> Adaptor<LossInput<B>> for RegressionOutput<B> {
21    fn adapt(&self) -> LossInput<B> {
22        LossInput::new(self.loss.clone())
23    }
24}
25
26impl<B: Backend> ItemLazy for RegressionOutput<B> {
27    type ItemSync = RegressionOutput<NdArray>;
28
29    fn sync(self) -> Self::ItemSync {
30        let [output, loss, targets] = Transaction::default()
31            .register(self.output)
32            .register(self.loss)
33            .register(self.targets)
34            .execute()
35            .try_into()
36            .expect("Correct amount of tensor data");
37
38        let device = &Default::default();
39
40        RegressionOutput {
41            output: Tensor::from_data(output, device),
42            loss: Tensor::from_data(loss, device),
43            targets: Tensor::from_data(targets, device),
44        }
45    }
46}