// burn_train/learner/classification.rs
use crate::metric::{
2 AccuracyInput, Adaptor, AurocInput, ConfusionStatsInput, HammingScoreInput, LossInput,
3 PerplexityInput, TopKAccuracyInput, processor::ItemLazy,
4};
5use burn_core::tensor::backend::Backend;
6use burn_core::tensor::{Int, Tensor, Transaction};
7use burn_flex::Flex;
8
/// Simple single-label classification output, adaptable to the common
/// classification metrics (accuracy, AUROC, loss, top-k, perplexity,
/// confusion stats).
#[derive(new)]
pub struct ClassificationOutput<B: Backend> {
    /// The loss tensor.
    pub loss: Tensor<B, 1>,

    /// The model output (class scores); dim 1 is the class axis — see the
    /// `ConfusionStatsInput` adaptor, which reads `num_classes` from it.
    pub output: Tensor<B, 2>,

    /// The target class indices, one per example.
    pub targets: Tensor<B, 1, Int>,
}
31
32impl<B: Backend> ItemLazy for ClassificationOutput<B> {
33 type ItemSync = ClassificationOutput<Flex>;
35
36 fn sync(self) -> Self::ItemSync {
37 let [output, loss, targets] = Transaction::default()
38 .register(self.output)
39 .register(self.loss)
40 .register(self.targets)
41 .execute()
42 .try_into()
43 .expect("Correct amount of tensor data");
44
45 let device = &Default::default();
46
47 ClassificationOutput {
48 output: Tensor::from_data(output, device),
49 loss: Tensor::from_data(loss, device),
50 targets: Tensor::from_data(targets, device),
51 }
52 }
53}
54
55impl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {
56 fn adapt(&self) -> AccuracyInput<B> {
57 AccuracyInput::new(self.output.clone(), self.targets.clone())
58 }
59}
60
61impl<B: Backend> Adaptor<AurocInput<B>> for ClassificationOutput<B> {
62 fn adapt(&self) -> AurocInput<B> {
63 AurocInput::new(self.output.clone(), self.targets.clone())
64 }
65}
66
67impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
68 fn adapt(&self) -> LossInput<B> {
69 LossInput::new(self.loss.clone())
70 }
71}
72
73impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for ClassificationOutput<B> {
74 fn adapt(&self) -> TopKAccuracyInput<B> {
75 TopKAccuracyInput::new(self.output.clone(), self.targets.clone())
76 }
77}
78
79impl<B: Backend> Adaptor<PerplexityInput<B>> for ClassificationOutput<B> {
80 fn adapt(&self) -> PerplexityInput<B> {
81 PerplexityInput::new(self.output.clone(), self.targets.clone())
82 }
83}
84
85impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
86 fn adapt(&self) -> ConfusionStatsInput<B> {
87 let [_, num_classes] = self.output.dims();
88 if num_classes > 1 {
89 ConfusionStatsInput::new(
90 self.output.clone(),
91 self.targets.clone().one_hot(num_classes).bool(),
92 )
93 } else {
94 ConfusionStatsInput::new(
95 self.output.clone(),
96 self.targets.clone().unsqueeze_dim(1).bool(),
97 )
98 }
99 }
100}
101
/// Multi-label classification output, adaptable to the multi-label metrics
/// (Hamming score, loss, confusion stats).
#[derive(new)]
pub struct MultiLabelClassificationOutput<B: Backend> {
    /// The loss tensor.
    pub loss: Tensor<B, 1>,

    /// The model output, one score per (example, label) pair.
    pub output: Tensor<B, 2>,

    /// The integer targets, same 2-D layout as `output` — converted to bool
    /// by the `ConfusionStatsInput` adaptor below.
    pub targets: Tensor<B, 2, Int>,
}
121
122impl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {
123 type ItemSync = MultiLabelClassificationOutput<Flex>;
125
126 fn sync(self) -> Self::ItemSync {
127 let [output, loss, targets] = Transaction::default()
128 .register(self.output)
129 .register(self.loss)
130 .register(self.targets)
131 .execute()
132 .try_into()
133 .expect("Correct amount of tensor data");
134
135 let device = &Default::default();
136
137 MultiLabelClassificationOutput {
138 output: Tensor::from_data(output, device),
139 loss: Tensor::from_data(loss, device),
140 targets: Tensor::from_data(targets, device),
141 }
142 }
143}
144
145impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
146 fn adapt(&self) -> HammingScoreInput<B> {
147 HammingScoreInput::new(self.output.clone(), self.targets.clone())
148 }
149}
150
151impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
152 fn adapt(&self) -> LossInput<B> {
153 LossInput::new(self.loss.clone())
154 }
155}
156
157impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
158 fn adapt(&self) -> ConfusionStatsInput<B> {
159 ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
160 }
161}