burn_train/learner/
classification.rs1use crate::metric::TopKAccuracyInput;
2use crate::metric::{
3 AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput, processor::ItemLazy,
4};
5use burn_core::tensor::backend::Backend;
6use burn_core::tensor::{Int, Tensor, Transaction};
7use burn_ndarray::NdArray;
8
9#[derive(new)]
11pub struct ClassificationOutput<B: Backend> {
12 pub loss: Tensor<B, 1>,
14
15 pub output: Tensor<B, 2>,
17
18 pub targets: Tensor<B, 1, Int>,
20}
21
22impl<B: Backend> ItemLazy for ClassificationOutput<B> {
23 type ItemSync = ClassificationOutput<NdArray>;
24
25 fn sync(self) -> Self::ItemSync {
26 let [output, loss, targets] = Transaction::default()
27 .register(self.output)
28 .register(self.loss)
29 .register(self.targets)
30 .execute()
31 .try_into()
32 .expect("Correct amount of tensor data");
33
34 let device = &Default::default();
35
36 ClassificationOutput {
37 output: Tensor::from_data(output, device),
38 loss: Tensor::from_data(loss, device),
39 targets: Tensor::from_data(targets, device),
40 }
41 }
42}
43
44impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
45 fn adapt(&self) -> AccuracyInput<B> {
46 AccuracyInput::new(self.output.clone(), self.targets.clone())
47 }
48}
49
50impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
51 fn adapt(&self) -> LossInput<B> {
52 LossInput::new(self.loss.clone())
53 }
54}
55
56impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for ClassificationOutput<B> {
57 fn adapt(&self) -> TopKAccuracyInput<B> {
58 TopKAccuracyInput::new(self.output.clone(), self.targets.clone())
59 }
60}
61
62impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
63 fn adapt(&self) -> ConfusionStatsInput<B> {
64 let [_, num_classes] = self.output.dims();
65 if num_classes > 1 {
66 ConfusionStatsInput::new(
67 self.output.clone(),
68 self.targets.clone().one_hot(num_classes).bool(),
69 )
70 } else {
71 ConfusionStatsInput::new(
72 self.output.clone(),
73 self.targets.clone().unsqueeze_dim(1).bool(),
74 )
75 }
76 }
77}
78
79#[derive(new)]
81pub struct MultiLabelClassificationOutput<B: Backend> {
82 pub loss: Tensor<B, 1>,
84
85 pub output: Tensor<B, 2>,
87
88 pub targets: Tensor<B, 2, Int>,
90}
91
92impl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {
93 type ItemSync = MultiLabelClassificationOutput<NdArray>;
94
95 fn sync(self) -> Self::ItemSync {
96 let [output, loss, targets] = Transaction::default()
97 .register(self.output)
98 .register(self.loss)
99 .register(self.targets)
100 .execute()
101 .try_into()
102 .expect("Correct amount of tensor data");
103
104 let device = &Default::default();
105
106 MultiLabelClassificationOutput {
107 output: Tensor::from_data(output, device),
108 loss: Tensor::from_data(loss, device),
109 targets: Tensor::from_data(targets, device),
110 }
111 }
112}
113
114impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
115 fn adapt(&self) -> HammingScoreInput<B> {
116 HammingScoreInput::new(self.output.clone(), self.targets.clone())
117 }
118}
119
120impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
121 fn adapt(&self) -> LossInput<B> {
122 LossInput::new(self.loss.clone())
123 }
124}
125
126impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
127 fn adapt(&self) -> ConfusionStatsInput<B> {
128 ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
129 }
130}