// burn_train/learner/classification.rs
use crate::metric::{
    processor::ItemLazy, AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput,
};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor, Transaction};
use burn_ndarray::NdArray;

8#[derive(new)]
10pub struct ClassificationOutput<B: Backend> {
11 pub loss: Tensor<B, 1>,
13
14 pub output: Tensor<B, 2>,
16
17 pub targets: Tensor<B, 1, Int>,
19}
20
21impl<B: Backend> ItemLazy for ClassificationOutput<B> {
22 type ItemSync = ClassificationOutput<NdArray>;
23
24 fn sync(self) -> Self::ItemSync {
25 let [output, loss, targets] = Transaction::default()
26 .register(self.output)
27 .register(self.loss)
28 .register(self.targets)
29 .execute()
30 .try_into()
31 .expect("Correct amount of tensor data");
32
33 let device = &Default::default();
34
35 ClassificationOutput {
36 output: Tensor::from_data(output, device),
37 loss: Tensor::from_data(loss, device),
38 targets: Tensor::from_data(targets, device),
39 }
40 }
41}
42
43impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
44 fn adapt(&self) -> AccuracyInput<B> {
45 AccuracyInput::new(self.output.clone(), self.targets.clone())
46 }
47}
48
49impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
50 fn adapt(&self) -> LossInput<B> {
51 LossInput::new(self.loss.clone())
52 }
53}
54
55impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
56 fn adapt(&self) -> ConfusionStatsInput<B> {
57 let [_, num_classes] = self.output.dims();
58 if num_classes > 1 {
59 ConfusionStatsInput::new(
60 self.output.clone(),
61 self.targets.clone().one_hot(num_classes).bool(),
62 )
63 } else {
64 ConfusionStatsInput::new(
65 self.output.clone(),
66 self.targets.clone().unsqueeze_dim(1).bool(),
67 )
68 }
69 }
70}
71
72#[derive(new)]
74pub struct MultiLabelClassificationOutput<B: Backend> {
75 pub loss: Tensor<B, 1>,
77
78 pub output: Tensor<B, 2>,
80
81 pub targets: Tensor<B, 2, Int>,
83}
84
85impl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {
86 type ItemSync = MultiLabelClassificationOutput<NdArray>;
87
88 fn sync(self) -> Self::ItemSync {
89 let [output, loss, targets] = Transaction::default()
90 .register(self.output)
91 .register(self.loss)
92 .register(self.targets)
93 .execute()
94 .try_into()
95 .expect("Correct amount of tensor data");
96
97 let device = &Default::default();
98
99 MultiLabelClassificationOutput {
100 output: Tensor::from_data(output, device),
101 loss: Tensor::from_data(loss, device),
102 targets: Tensor::from_data(targets, device),
103 }
104 }
105}
106
107impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
108 fn adapt(&self) -> HammingScoreInput<B> {
109 HammingScoreInput::new(self.output.clone(), self.targets.clone())
110 }
111}
112
113impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
114 fn adapt(&self) -> LossInput<B> {
115 LossInput::new(self.loss.clone())
116 }
117}
118
119impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
120 fn adapt(&self) -> ConfusionStatsInput<B> {
121 ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
122 }
123}