burn_train/learner/
regression.rs1use crate::metric::processor::ItemLazy;
2use crate::metric::{Adaptor, LossInput};
3use burn_core::tensor::backend::Backend;
4use burn_core::tensor::{Tensor, Transaction};
5use burn_ndarray::NdArray;
6
7#[derive(new)]
9pub struct RegressionOutput<B: Backend> {
10 pub loss: Tensor<B, 1>,
12
13 pub output: Tensor<B, 2>,
15
16 pub targets: Tensor<B, 2>,
18}
19
20impl<B: Backend> Adaptor<LossInput<B>> for RegressionOutput<B> {
21 fn adapt(&self) -> LossInput<B> {
22 LossInput::new(self.loss.clone())
23 }
24}
25
26impl<B: Backend> ItemLazy for RegressionOutput<B> {
27 type ItemSync = RegressionOutput<NdArray>;
28
29 fn sync(self) -> Self::ItemSync {
30 let [output, loss, targets] = Transaction::default()
31 .register(self.output)
32 .register(self.loss)
33 .register(self.targets)
34 .execute()
35 .try_into()
36 .expect("Correct amount of tensor data");
37
38 let device = &Default::default();
39
40 RegressionOutput {
41 output: Tensor::from_data(output, device),
42 loss: Tensor::from_data(loss, device),
43 targets: Tensor::from_data(targets, device),
44 }
45 }
46}