burn_train/learner/
classification.rs1use crate::metric::TopKAccuracyInput;
2use crate::metric::{
3 AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput, PerplexityInput,
4 processor::ItemLazy,
5};
6use burn_core::tensor::backend::Backend;
7use burn_core::tensor::{Int, Tensor, Transaction};
8use burn_ndarray::NdArray;
9
10#[derive(new)]
12pub struct ClassificationOutput<B: Backend> {
13 pub loss: Tensor<B, 1>,
15
16 pub output: Tensor<B, 2>,
18
19 pub targets: Tensor<B, 1, Int>,
21}
22
23impl<B: Backend> ItemLazy for ClassificationOutput<B> {
24 type ItemSync = ClassificationOutput<NdArray>;
25
26 fn sync(self) -> Self::ItemSync {
27 let [output, loss, targets] = Transaction::default()
28 .register(self.output)
29 .register(self.loss)
30 .register(self.targets)
31 .execute()
32 .try_into()
33 .expect("Correct amount of tensor data");
34
35 let device = &Default::default();
36
37 ClassificationOutput {
38 output: Tensor::from_data(output, device),
39 loss: Tensor::from_data(loss, device),
40 targets: Tensor::from_data(targets, device),
41 }
42 }
43}
44
45impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
46 fn adapt(&self) -> AccuracyInput<B> {
47 AccuracyInput::new(self.output.clone(), self.targets.clone())
48 }
49}
50
51impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
52 fn adapt(&self) -> LossInput<B> {
53 LossInput::new(self.loss.clone())
54 }
55}
56
57impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for ClassificationOutput<B> {
58 fn adapt(&self) -> TopKAccuracyInput<B> {
59 TopKAccuracyInput::new(self.output.clone(), self.targets.clone())
60 }
61}
62
63impl<B: Backend> Adaptor<PerplexityInput<B>> for ClassificationOutput<B> {
64 fn adapt(&self) -> PerplexityInput<B> {
65 PerplexityInput::new(self.output.clone(), self.targets.clone())
66 }
67}
68
69impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
70 fn adapt(&self) -> ConfusionStatsInput<B> {
71 let [_, num_classes] = self.output.dims();
72 if num_classes > 1 {
73 ConfusionStatsInput::new(
74 self.output.clone(),
75 self.targets.clone().one_hot(num_classes).bool(),
76 )
77 } else {
78 ConfusionStatsInput::new(
79 self.output.clone(),
80 self.targets.clone().unsqueeze_dim(1).bool(),
81 )
82 }
83 }
84}
85
86#[derive(new)]
88pub struct MultiLabelClassificationOutput<B: Backend> {
89 pub loss: Tensor<B, 1>,
91
92 pub output: Tensor<B, 2>,
94
95 pub targets: Tensor<B, 2, Int>,
97}
98
99impl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {
100 type ItemSync = MultiLabelClassificationOutput<NdArray>;
101
102 fn sync(self) -> Self::ItemSync {
103 let [output, loss, targets] = Transaction::default()
104 .register(self.output)
105 .register(self.loss)
106 .register(self.targets)
107 .execute()
108 .try_into()
109 .expect("Correct amount of tensor data");
110
111 let device = &Default::default();
112
113 MultiLabelClassificationOutput {
114 output: Tensor::from_data(output, device),
115 loss: Tensor::from_data(loss, device),
116 targets: Tensor::from_data(targets, device),
117 }
118 }
119}
120
121impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
122 fn adapt(&self) -> HammingScoreInput<B> {
123 HammingScoreInput::new(self.output.clone(), self.targets.clone())
124 }
125}
126
127impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
128 fn adapt(&self) -> LossInput<B> {
129 LossInput::new(self.loss.clone())
130 }
131}
132
133impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
134 fn adapt(&self) -> ConfusionStatsInput<B> {
135 ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
136 }
137}